diff --git a/.bazelrc b/.bazelrc index fb938169b3c0..755572f21355 100644 --- a/.bazelrc +++ b/.bazelrc @@ -31,6 +31,7 @@ build -c opt build --output_filter=DONT_MATCH_ANYTHING build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. +build --copt=-DNB_DOMAIN=jax # ############################################################################# # Platform Specific configs below. These are automatically picked up by Bazel @@ -97,6 +98,7 @@ build:windows --incompatible_strict_action_env=true # ############################################################################# build:nonccl --define=no_nccl_support=true +build --repo_env USE_PYWRAP_RULES=1 build:posix --copt=-fvisibility=hidden build:posix --copt=-Wno-sign-compare build:posix --cxxopt=-std=c++17 @@ -130,19 +132,21 @@ build:clang --copt=-Wno-gnu-offsetof-extensions build:clang --copt=-Qunused-arguments # Error on struct/class mismatches, since this causes link failures on Windows. build:clang --copt=-Werror=mismatched-tags +# Required when building with clang>=19, see jax-ml/jax#27091 +build:clang --copt=-Wno-error=c23-extensions # Configs for CUDA build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120" build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda # Default hermetic CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" build:cuda --@local_config_cuda//cuda:include_cuda_libs=true # This config is used for building targets with CUDA libraries from stubs. @@ -253,6 +257,10 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm # Mac Arm64 CI configs build:ci_darwin_arm64 --macos_minimum_os=11.0 +# Clang 19 requires `-Wno-error=c23-extensions` but this flag is not supported +# on Apple Clang in XCode 16.0 so we suppress unknown warning option errors +# on Mac CI builds. 
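(As an aside on the CUDA settings earlier in this hunk: the hermetic CUDA/cuDNN pins and compute capabilities can also be overridden per invocation. A minimal sketch using the same --repo_env knobs this file sets; the values and the bare //... target pattern are illustrative, not part of the change:)

    # Sketch: one-off build against the hermetic CUDA 12.8.0 / cuDNN 9.8.0 toolchain,
    # emitting cubin for Hopper ("sm_90") plus PTX ("compute_90") for forward compatibility.
    bazel build --config=cuda \
        --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \
        --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \
        --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_90,compute_90" \
        //...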
+build:ci_darwin_arm64 --copt=-Wno-unknown-warning-option build:ci_darwin_arm64 --config=macos_cache_push build:ci_darwin_arm64 --verbose_failures=true build:ci_darwin_arm64 --color=yes @@ -260,8 +268,8 @@ build:ci_darwin_arm64 --color=yes # Windows x86 CI configs build:ci_windows_amd64 --config=avx_windows build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true -build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" -build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" +build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain" +build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE build:ci_windows_amd64 --color=yes @@ -329,9 +337,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst build:rbe_windows_amd64 --config=rbe # Set the host, execution, and target platform -build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe build:rbe_windows_amd64 --enable_runfiles diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 628310519b66..1f8c2b2ac254 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -24,7 +24,7 @@ body: [issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues - [Raw report]: http://github.com/jax-ml/jax/issues/new + [Raw report]: https://github.com/jax-ml/jax/issues/new?template=none - type: textarea attributes: label: Description diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index d6816d492d1d..ef5084960b30 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -46,7 +46,10 @@ jobs: enable-x_64: 1 - python: "3.13" enable-x_64: 0 - name: "Bazel CPU tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" + # Only test a single Python version on Arm64 as we don't run the tests. + - python: "3.10" + runner: "linux-arm64-c4a-16" + name: "Bazel CPU ${{ (contains(matrix.runner, 'linux-arm64') && 'build only' || 'tests') }} (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" # End Presubmit Naming Check github-cpu-presubmits steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -54,5 +57,7 @@ jobs: uses: google-ml-infra/actions/ci_connection@main with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel CPU Tests with RBE + # Since we do not have a Linux Arm64 RBE pool, we do not run the tests on Arm64. Instead, we + # cross-compile the tests on the Linux x86 RBE pool. 
+ - name: ${{ (contains(matrix.runner, 'linux-arm64') && 'Build' || 'Run') }} Bazel CPU Tests with RBE run: ./ci/run_bazel_test_cpu_rbe.sh \ No newline at end of file diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 0b0e1cb62497..ff1cf9900ce3 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -47,7 +47,7 @@ jobs: # Explicitly set the shell to bash shell: bash runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest" env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} @@ -79,6 +79,7 @@ jobs: continue-on-error: true run: >- mkdir -p $(pwd)/dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ diff --git a/.github/workflows/bazel_optional_b200.yml b/.github/workflows/bazel_optional_b200.yml new file mode 100644 index 000000000000..6335fbacaf2c --- /dev/null +++ b/.github/workflows/bazel_optional_b200.yml @@ -0,0 +1,62 @@ +name: CI - Bazel Optional B200 CUDA tests +on: + # Runs on PR if label "CI Optional GPU Presubmit" is present. + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + pull_request: + branches: + - main + types: [ labeled, synchronize ] + schedule: + - cron: "0 */2 * * *" # Run once every 2 hours +permissions: + contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. 
+ cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} +jobs: + run_tests: + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + runs-on: linux-x86-a4-224-b200-1gpu + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest' + name: "Bazel single B200 CUDA tests" +# End Presubmit Naming Check github-cuda-presubmits + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CUDA Tests + run: | + nvidia-smi + bazel test --config=ci_linux_x86_64_cuda \ + --config=resultstore \ + --config=rbe_cache \ + --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \ + --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \ + --repo_env=HERMETIC_PYTHON_VERSION="3.13" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_output=errors \ + --test_env=JAX_ACCELERATOR_COUNT=1 \ + --test_env=JAX_TESTS_PER_ACCELERATOR=32 \ + --local_test_jobs=32 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="1" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index c2e7acb91f7a..37a791784506 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -16,8 +16,8 @@ on: default: "linux-x86-n2-16" options: - "linux-x86-n2-16" - - "linux-arm64-c4a-64" - - "windows-x86-n2-64" + - "linux-arm64-t2a-48" + - "windows-x86-n2-16" artifact: description: "Which JAX artifact to build?" 
type: choice @@ -119,11 +119,11 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Enable RBE if building on Linux x86 - if: contains(inputs.runner, 'linux-x86') + - name: Enable RBE if building on Linux x86 or Windows x86 + if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV - - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86 - if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86') + - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 + if: contains(inputs.runner, 'linux-arm64') run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV # Halt for testing - name: Wait For Connection diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index f43407af2ed9..7bb31f9d0327 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -31,11 +31,11 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: 3.11 - run: python -m pip install pre-commit - - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + - uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} @@ -70,7 +70,7 @@ jobs: apt update apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -108,7 +108,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -140,9 +140,9 @@ jobs: - name: Image Setup run: | apt update - apt install -y libssl-dev libsqlite3-dev + apt install -y libssl-dev libsqlite3-dev build-essential - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -151,7 +151,7 @@ jobs: uv pip install --system -r docs/requirements.txt - name: Render documentation run: | - sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html + sphinx-build -j auto --color -W --keep-going -b html docs docs/build/html jax2tf_test: name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})" @@ -168,7 +168,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: 
actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -201,7 +201,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: 3.12 - name: Install JAX diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 099f4ad5c520..b50b07d5cc4a 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -26,7 +26,6 @@ jobs: matrix: jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ - # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml index a92e3cc19313..40c99735c2de 100644 --- a/.github/workflows/cloud-tpu-ci-presubmit.yml +++ b/.github/workflows/cloud-tpu-ci-presubmit.yml @@ -62,4 +62,5 @@ jobs: python: "3.10" libtpu-version-type: "nightly" gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }} + halt-for-connection: ${{ inputs.halt-for-connection || false }} # End Presubmit Naming Check github-tpu-presubmits \ No newline at end of file diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml new file mode 100644 index 000000000000..1110cbad9475 --- /dev/null +++ b/.github/workflows/community_release_actions.yml @@ -0,0 +1,34 @@ +name: Release Actions + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + discord_release: + if: github.repository_owner == 'jax-ml' + runs-on: ubuntu-latest + steps: + - name: Get release URL + id: get-release-url + run: | + URL="https://docs.jax.dev/en/latest/changelog.html" + echo "::set-output name=URL::$URL" + - name: Get content + uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1 + id: get-content + with: + stringToTruncate: | + JAX [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released! + + ${{ github.event.release.body }} + maxLength: 2000 + truncationSymbol: "..." + - name: Discord Webhook Action + uses: tsickert/discord-webhook@b217a69502f52803de774ded2b1ab7c282e99645 # v7.0.0 + with: + webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }} + content: ${{ steps.get-content.outputs.string }} diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 2b97c5a05c1c..7df4228dd2a3 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,11 +28,11 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. 
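(One note on the new release-notification workflow above: the `::set-output` command it uses is deprecated in GitHub Actions. A minimal sketch of the `$GITHUB_OUTPUT` equivalent, the same pattern the TSAN workflow later in this patch uses for its date step:)

    # Sketch: expose the changelog URL as a step output without ::set-output.
    URL="https://docs.jax.dev/en/latest/changelog.html"
    echo "URL=$URL" >> "$GITHUB_OUTPUT"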
- ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04 + ref: 'c48410f96fc58e02eea844e6b7f6cc01680f77ce' # Latest commit as of 2025-04-02 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml new file mode 100644 index 000000000000..1042388fe9c6 --- /dev/null +++ b/.github/workflows/k8s.yaml @@ -0,0 +1,109 @@ +name: Distributed run using K8s Jobset + +on: + push: + branches: + - main + paths: + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' + pull_request: + branches: + - main + paths: + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -ex -o pipefail {0} + +jobs: + + distributed-initialize: + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4 + with: + path: jax + + - name: Start Minikube cluster + uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/setup-minikube@v0.0.19 + + - name: Install K8s Jobset + run: | + kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml + + - name: Build image + run: | + cat > Dockerfile <> $GITHUB_ENV + else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index a105a2feb347..0b56635a8aac 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -54,6 +54,11 @@ on: # - "pypi_latest": Use the latest libtpu wheel from PyPI. # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel. default: "nightly" + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: false + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -110,7 +115,11 @@ jobs: run: | mkdir -p $(pwd)/dist gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." 
+ else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions_3.13.txt similarity index 83% rename from .github/workflows/tsan-suppressions.txt rename to .github/workflows/tsan-suppressions_3.13.txt index 7b713b2da194..dac134bf5169 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -2,14 +2,11 @@ # are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. race:llvm::RuntimeDyldELF::registerEHFrames -# https://github.com/python/cpython/issues/128050 -race:partial_vectorcall_fallback - # https://github.com/openxla/xla/issues/20686 race:dnnl_sgemm -# https://github.com/python/cpython/issues/128130 -race_top:run_eval_code_obj +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback # Likely only happens when the process is crashing. race:dump_traceback @@ -18,18 +15,22 @@ race:dump_traceback # Fixed in Python 3.14, but not backported to 3.13. race:immortalize_interned race:_PyUnicode_InternMortal +race:_PyUnicode_InternImmortal # https://github.com/python/cpython/issues/128144 # Fixed in Python 3.14, but not backported to 3.13. race_top:PyMember_GetOne -# https://github.com/python/cpython/issues/129547 -race:type_get_annotations - +# https://github.com/python/cpython/issues/131680 +# Fixed in Python 3.14, but not backported to 3.13. +race_top:new_reference +race:_Py_IsOwnedByCurrentThread # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi @@ -65,3 +66,11 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/130547 # race:split_keys_entry_added + +# https://github.com/python/cpython/issues/129547 +# Maybe fixed? +# race:type_get_annotations + +# https://github.com/python/cpython/issues/132013 +# Fixed on 3.14 and not backported to 3.13 +race_top:frozenset_hash \ No newline at end of file diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt new file mode 100644 index 000000000000..6e1d34e6db65 --- /dev/null +++ b/.github/workflows/tsan-suppressions_3.14.txt @@ -0,0 +1,29 @@ +# false-positive caused because we haven't tsan-instrumented libgcc_s. Multiple threads +# are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. +race:llvm::RuntimeDyldELF::registerEHFrames + +# https://github.com/openxla/xla/issues/20686 +race:dnnl_sgemm + +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback + +# Likely only happens when the process is crashing. +race:dump_traceback + +# https://github.com/python/cpython/issues/129748 +race:mi_block_set_nextx + +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj + +# https://github.com/python/cpython/issues/132214 +race_top:update_one_slot + +# Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. 
+race:heevd_ffi +race:gesdd_ffi +race:dscal_k_ +race:scal_k_ +race:gemm_beta +race:gemm_oncopy diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 7d93707e4e92..ef1cf99d6d74 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -13,6 +13,8 @@ on: - main paths: - '**/workflows/tsan.yaml' + - '**/workflows/tsan-suppressions*.txt' + - '**/workflows/requirements_lock_3_13_ft.patch' jobs: tsan: @@ -21,6 +23,16 @@ jobs: image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 strategy: fail-fast: false + matrix: + include: + - name-prefix: "with 3.13" + python-version: "3.13" + github_branch: "3.13" + requirements_lock_name: "requirements_lock_3_13_ft" + - name-prefix: "with 3.14" + python-version: "3.14" + github_branch: "main" + requirements_lock_name: "requirements_lock_3_14_ft" defaults: run: shell: bash -l {0} @@ -43,22 +55,33 @@ jobs: with: repository: python/cpython path: cpython - ref: "3.13" + ref: ${{ matrix.github_branch }} - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: numpy/numpy path: numpy submodules: true + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: ${{ matrix.python-version == '3.14' }} + with: + repository: scipy/scipy + path: scipy + submodules: true - - name: Restore cached TSAN CPython + - name: Get year & week number + id: get-date + run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT + shell: bash -l {0} + + - name: Restore cached TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - - name: Build CPython with enabled TSAN + - name: Build TSAN CPython ${{ matrix.python-version }} if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' run: | cd cpython @@ -72,27 +95,22 @@ jobs: # Create archive to be used with bazel as hermetic python: cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan - - name: Save TSAN CPython + - name: Save TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-save if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} - - - name: Get year & week number - id: get-date - run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT - shell: bash -l {0} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - name: Restore cached TSAN Numpy id: cache-numpy-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build TSAN 
Numpy wheel if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' @@ -113,7 +131,8 @@ jobs: python3 -m pip install uv~=0.5.30 # Make sure to install a compatible Cython version (master branch is best for now) - python3 -m uv pip install -r requirements/build_requirements.txt -U git+https://github.com/cython/cython + NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython + python3 -m uv pip install -r requirements/build_requirements.txt CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized @@ -142,11 +161,87 @@ jobs: - name: Save TSAN Numpy wheel id: cache-numpy-tsan-save if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + - name: Restore cached Scipy + if: ${{ matrix.python-version == '3.14' }} + id: cache-scipy-restore + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + - name: Build Scipy wheel + if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + run: | + # Install scipy dependencies: + apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + + cd scipy + + # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz + if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then + echo "Extract cpython from python-tsan.tgz" + pushd . + ls ${GITHUB_WORKSPACE}/python-tsan.tgz + cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz + ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/ + popd + fi + + export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH + + python3 -m pip install uv~=0.5.30 + # Make sure to install a compatible Cython version (master branch is best for now) + NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython + python3 -m uv pip install -U --pre numpy --extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/ + python3 -m uv pip install pythran pybind11 meson-python ninja + + python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)" + + export CC=clang-18 + export CXX=clang++-18 + python3 -m pip wheel --wheel-dir dist -vvv . --no-build-isolation --no-deps -Csetup-args=-Dbuildtype=debugoptimized + + python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)" + + # Create simple index and copy the wheel + mkdir -p ${GITHUB_WORKSPACE}/wheelhouse/scipy + + scipy_whl_name=($(cd dist && ls scipy*.whl)) + if [ -z "${scipy_whl_name}" ]; then exit 1; fi + + echo "Built TSAN Scipy wheel: ${scipy_whl_name}" + + cp dist/${scipy_whl_name} ${GITHUB_WORKSPACE}/wheelhouse/scipy + + # Recreate wheelhouse index with Numpy and Scipy + cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/index.html + + numpy>
+            <a href="scipy/">scipy></a><br/>
+          </body></html>
+          EOF
+
+          cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/scipy/index.html
+          <html><body>
+            <a href="${scipy_whl_name}">${scipy_whl_name}</a><br/>
+ + EOF + + - name: Save Scipy wheel + id: cache-scipy-save + if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build Jax and run tests timeout-minutes: 120 @@ -163,7 +258,7 @@ jobs: python3 -VV python3 build/build.py build --configure_only \ - --python_version=3.13-ft \ + --python_version=${{ matrix.python-version }}-ft \ --bazel_options=--repo_env=HERMETIC_PYTHON_URL="file://${GITHUB_WORKSPACE}/python-tsan.tgz" \ --bazel_options=--repo_env=HERMETIC_PYTHON_SHA256=${PYTHON_SHA256} \ --bazel_options=--repo_env=HERMETIC_PYTHON_PREFIX="cpython-tsan/" \ @@ -173,18 +268,32 @@ jobs: --bazel_options=--copt=-g \ --clang_path=/usr/bin/clang-18 - # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch - cat .github/workflows/requirements_lock_3_13_ft.patch - git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1 + if [ "${{ matrix.python-version }}" == "3.13" ]; then + # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - # Display the content for debugging in logs - cat build/requirements_lock_3_13_ft.txt | head -15 - # Check the patch - cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" - if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi - cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)" - if [ "$?" == "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi + sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/${{ matrix.requirements_lock_name }}.patch + cat .github/workflows/${{ matrix.requirements_lock_name }}.patch + git apply .github/workflows/${{ matrix.requirements_lock_name }}.patch || exit 1 + + # Display the content for debugging in logs + cat build/${{ matrix.requirements_lock_name }}.txt | head -15 + # Check the patch + cat build/${{ matrix.requirements_lock_name }}.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" + if [ "$?" == "1" ]; then echo "Could not find the patch in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi + cat build/${{ matrix.requirements_lock_name }}.txt | grep -E "(numpy==)" + if [ "$?" 
== "0" ]; then "Found original numpy dependency in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi + + else + # Patch build/requirements_lock_3_14_ft.txt to use TSAN instrumented NumPy and Scipy + + sed -i "s|--extra-index-url.*|--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" build/${{ matrix.requirements_lock_name }}.txt + + # We should install jpeg dev package to be able to build Pillow from source: + apt-get install -y libjpeg-dev --no-install-recommends + + # Install scipy runtime dependencies (in case we restore scipy wheel from cache): + apt-get install -y libopenblas-dev liblapack-dev --no-install-recommends + fi echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" @@ -200,17 +309,22 @@ jobs: # Check numpy version ./bazel cquery @pypi_numpy//:* | grep whl + if [ "${{ matrix.python-version }}" == "3.14" ]; then + # Check scipy version + ./bazel cquery @pypi_scipy//:* | grep whl + fi + # Build JAX and run tests ./bazel test \ --test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \ --test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \ --test_env=JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS \ --test_env=PYTHON_GIL=0 \ - --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions.txt \ + --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions_${{ matrix.python-version }}.txt \ --test_env=JAX_TEST_NUM_THREADS=8 \ --test_output=errors \ --local_test_jobs=32 \ - --test_timeout=600 \ + --test_timeout=1800 \ --config=resultstore \ --config=rbe_cache \ //tests:cpu_tests diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 5132a12cf16f..ba2c750f8a8a 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -33,7 +33,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index ecdf43b133cc..175fc2f22d4a 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -19,7 +19,7 @@ name: CI - Wheel Tests (Continuous) on: schedule: - - cron: "0 */2 * * *" # Run once every 2 hours + - cron: "0 */3 * * *" # Run once every 3 hours workflow_dispatch: # allows triggering the workflow run manually concurrency: @@ -44,7 +44,7 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Runner OS and Python values need to match the matrix stategy in the CPU tests job - runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-64"] + runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-16"] artifact: ["jaxlib"] python: ["3.10"] # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the @@ -111,28 +111,21 @@ jobs: matrix: # Python values need to match the matrix stategy in the artifact build jobs above # See exlusions for what is fully tested - runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"] + runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] python: ["3.10",] - 
cuda: ["12.1","12.3","12.8"] + cuda: ["12.1", "12.8"] enable-x64: [1, 0] exclude: - # L4 does not run on cuda 12.8 but tests other configs - - runner: "linux-x86-g2-48-l4-4gpu" - cuda: "12.8" - # H100 runs only a single config, CUDA 12.3 Enable x64 1 - - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: "12.8" + # H100 runs only a single config, CUDA 12.8 Enable x64 1 - runner: "linux-x86-a3-8g-h100-8gpu" cuda: "12.1" - runner: "linux-x86-a3-8g-h100-8gpu" enable-x64: "0" # B200 runs only a single config, CUDA 12.8 Enable x64 1 - - runner: "linux-x86-a4-224-b200-1gpu" - enable-x64: "0" - runner: "linux-x86-a4-224-b200-1gpu" cuda: "12.1" - runner: "linux-x86-a4-224-b200-1gpu" - cuda: "12.3" + enable-x64: "0" name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})" with: @@ -148,7 +141,7 @@ jobs: # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we # still want to run the tests for other platforms. if: ${{ !cancelled() }} - needs: [build-jaxlib-artifact, build-cuda-artifacts] + needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] uses: ./.github/workflows/bazel_cuda_non_rbe.yml strategy: fail-fast: false # don't cancel all jobs on failure diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index adb678be9d9d..f6d2aa9b97c6 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -1,12 +1,14 @@ # CI - Wheel Tests (Nightly/Release) # -# This workflow builds JAX artifacts and runs CPU/CUDA tests. +# This workflow is used to test the JAX wheels that was built by internal CI jobs. # -# It orchestrates the following: -# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel that was +# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the JAX wheels that was # built by internal CI jobs and runs CPU tests. -# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA -# artifacts that were built by internal CI jobs and runs the CUDA tests. +# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the JAX wheels that was +# built by internal CI jobs and runs CUDA tests. +# 3. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the JAX wheels that was +# built by internal CI jobs and runs TPU tests. +# 4. verify-release-wheels-install: Verifies that JAX's release wheels can be installed. name: CI - Wheel Tests (Nightly/Release) on: @@ -17,6 +19,11 @@ on: required: true default: 'gs://jax-nightly-release-transient/nightly/latest' type: string + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: true + default: '0' + type: string concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -41,6 +48,7 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-cuda: @@ -52,7 +60,7 @@ jobs: # that build the wheels. 
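(For reference, the new download-jax-only-from-gcs input above would typically be exercised via a manual dispatch. A hedged sketch with the GitHub CLI, assuming the workflow exposes these inputs through workflow_dispatch as the hunk suggests:)

    # Sketch: trigger the nightly/release wheel tests for a jax-only release.
    gh workflow run "CI - Wheel Tests (Nightly/Release)" \
        -f gcs_download_uri="gs://jax-nightly-release-transient/nightly/latest" \
        -f download-jax-only-from-gcs=1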
runner: ["linux-x86-g2-48-l4-4gpu"] python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] - cuda: ["12.3", "12.1"] + cuda: ["12.1", "12.8"] enable-x64: [0] name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: @@ -60,6 +68,7 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-tpu: @@ -79,7 +88,7 @@ jobs: exclude: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8. + # Run a single Python version for v4-8 - tpu-specs: type: "v4-8" python: "3.10" @@ -98,4 +107,89 @@ jobs: python: ${{ matrix.python }} run-full-tpu-test-suite: "1" libtpu-version-type: ${{ matrix.libtpu-version-type }} - gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} + gcs_download_uri: ${{inputs.gcs_download_uri}} + + verify-release-wheels-install: + if: ${{ startsWith(github.ref_name, 'release/')}} + defaults: + run: + # Set the shell to bash as GitHub actions runs with /bin/sh by default + shell: bash + runs-on: linux-x86-n2-16 + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.10", "3.13", "3.13-nogil"] + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + + # Verifies that JAX's release wheels can be installed + name: "Verify release wheels install (Python ${{ matrix.python }})" + + env: + PYTHON: "python${{ matrix.python }}" + + steps: + - name: Download release wheels from GCS + run: | + mkdir -p $(pwd)/dist + final_gcs_download_uri=${{ inputs.gcs_download_uri }} + + # Get the major and minor version of Python. 
+ # E.g if python=3.10, then python_major_minor=310 + # E.g if python=3.13-nogil, then python_major_minor=313t + python_major_minor=${{ matrix.python }} + python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.') + python_major_minor="cp${python_major_minor%t}-cp${python_major_minor}-" + + gsutil -m cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/ + + jax_wheel=$(ls dist/jax*py3*none*any.whl 2>/dev/null) + echo "JAX_WHEEL=$jax_wheel" >> $GITHUB_ENV + + if [[ "${{ inputs.download-jax-only-from-gcs }}" != "1" ]]; then + gsutil -m cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*pjrt*linux*x86_64*.whl" $(pwd)/dist/ + + jaxlib_wheel=$(ls dist/jaxlib*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_plugin_wheel=$(ls dist/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_pjrt_wheel=$(ls dist/jax*cuda*pjrt*linux*x86_64*.whl 2>/dev/null) + + echo "JAXLIB_WHEEL=$jaxlib_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PLUGIN_WHEEL=$jax_cuda_plugin_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PJRT_WHEEL=$jax_cuda_pjrt_wheel" >> $GITHUB_ENV + fi + - name: Verify JAX CPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_cpu && source ~/test_cpu/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL + fi + - name: Verify JAX TPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_tpu && source ~/test_tpu/bin/activate + + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[tpu] + else + uv pip install $JAX_WHEEL[tpu] $JAXLIB_WHEEL + fi + - name: Verify JAX CUDA packages can be installed (Nvidia Pip Packages) + run: | + $PYTHON -m uv venv ~/test_cuda_pip && source ~/test_cuda_pip/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda] + else + uv pip install $JAX_WHEEL[cuda] $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL[with-cuda] + fi + - name: Verify JAX CUDA packages can be installed (CUDA local) + run: | + $PYTHON -m uv venv ~/test_cuda_local && source ~/test_cuda_local/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda12-local] + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL + fi \ No newline at end of file diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 444bc83f2889..a2b3aeddc24a 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -38,7 +38,7 @@ jobs: JAXLIB_RELEASE: true run: | python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt \ + python -m uv pip install -r build/test-requirements.txt ` --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH python.exe build\build.py build --wheels=jaxlib ` @@ -58,7 +58,7 @@ jobs: 
JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | - python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib \ + python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib ` -e ${{ github.workspace }} echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" pytest -n auto --tb=short tests examples diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index fc2b63396f56..5a435023ffda 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -35,7 +35,7 @@ jobs: with: path: jax - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 27ccc6d831f3..89ce80d9a815 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,6 +15,7 @@ repos: - id: check-merge-conflict - id: check-toml - id: check-yaml + exclude: examples/k8s/svc-acct.yaml - id: end-of-file-fixer # only include python files files: \.py$ diff --git a/.readthedocs.yml b/.readthedocs.yml index 6f807aa82377..3b7ba275a0d6 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -9,6 +9,12 @@ build: os: "ubuntu-22.04" tools: python: "3.10" + jobs: + post_checkout: + # Skip building PRs unless tagged with the "documentation" label. + - | + [ "${READTHEDOCS_VERSION_TYPE}" != "external" ] && echo "Building latest" && exit 0 + (curl -sL https://api.github.com/repos/jax-ml/jax/issues/${READTHEDOCS_VERSION}/labels | grep -q "https://api.github.com/repos/jax-ml/jax/labels/documentation") && echo "Building PR with label" || exit 183 # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/BUILD.bazel b/BUILD.bazel index 33cbefd29f0b..8dbf2bed0902 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
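(Regarding the Read the Docs post_checkout hook added above: exit code 183 is the Read the Docs convention for cancelling a build rather than failing it, so unlabeled PRs are skipped cleanly. A condensed sketch of the same gate:)

    # Sketch: build "latest" unconditionally; for PRs, require the "documentation" label.
    if [ "${READTHEDOCS_VERSION_TYPE}" != "external" ]; then
        echo "Building latest" && exit 0
    fi
    curl -sL "https://api.github.com/repos/jax-ml/jax/issues/${READTHEDOCS_VERSION}/labels" \
        | grep -q "https://api.github.com/repos/jax-ml/jax/labels/documentation" \
        && echo "Building PR with label" || exit 183   # 183 = cancel, not fail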
+load( + "@xla//third_party/py:py_import.bzl", + "py_import", +) load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") load( "//jaxlib:jax.bzl", + "jax_source_package", "jax_wheel", + "py_deps", + "pytype_test", ) collect_data_files( @@ -41,6 +48,7 @@ transitive_py_deps( "//jax:sparse_test_util", "//jax:test_util", "//jax/_src/lib", + "//jax/_src/pallas/fuser", "//jax/_src/pallas/mosaic_gpu", "//jax/experimental/array_serialization:serialization", "//jax/experimental/jax2tf", @@ -65,20 +73,98 @@ py_binary( ], ) +WHEEL_SOURCE_FILES = [ + ":transitive_py_data", + ":transitive_py_deps", + "//jax:py.typed", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", +] + jax_wheel( name = "jax_wheel", - build_wheel_only = False, platform_independent = True, - source_files = [ - ":transitive_py_data", - ":transitive_py_deps", - "//jax:py.typed", - "AUTHORS", - "LICENSE", - "README.md", - "pyproject.toml", - "setup.py", - ], + source_files = WHEEL_SOURCE_FILES, + wheel_binary = ":build_wheel", + wheel_name = "jax", +) + +jax_wheel( + name = "jax_wheel_editable", + editable = True, + platform_independent = True, + source_files = WHEEL_SOURCE_FILES, wheel_binary = ":build_wheel", wheel_name = "jax", ) + +jax_source_package( + name = "jax_source_package", + source_files = WHEEL_SOURCE_FILES, + source_package_binary = ":build_wheel", + source_package_name = "jax", +) + +genrule( + name = "wheel_additives", + srcs = [ + "//jax:internal_export_back_compat_test_util", + "//jax:internal_test_harnesses", + "//jax:internal_test_util", + "//jax:internal_export_back_compat_test_data", + "//jax:experimental/pallas/ops/tpu/random/philox.py", + "//jax:experimental/pallas/ops/tpu/random/prng_utils.py", + "//jax:experimental/pallas/ops/tpu/random/threefry.py", + "//jax/experimental/mosaic/gpu/examples:flash_attention.py", + "//jax/experimental/mosaic/gpu/examples:matmul.py", + ], + outs = ["wheel_additives.zip"], + cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", + tools = ["@bazel_tools//tools/zip:zipper"], +) + +COMMON_DEPS = py_deps([ + "absl/testing", + "numpy", + "ml_dtypes", + "scipy", + "opt_einsum", + "hypothesis", + "cloudpickle", + "flatbuffers", +]) + +py_import( + name = "jax_py_import", + wheel = ":jax_wheel", + wheel_deps = [":wheel_additives"], + deps = COMMON_DEPS, +) + +# This target is used to add more sources to the jax wheel. +# This is needed for the tests that depend on jax and use modules that are not part of +# the jax wheel, but share the same package paths as the modules in the jax wheel. +py_import( + name = "jax_wheel_with_internal_test_util", + wheel = "@pypi_jax//:whl", + wheel_deps = [":wheel_additives"], + deps = COMMON_DEPS, +) + +pytype_test( + name = "jax_wheel_size_test", + srcs = ["//jaxlib/tools:wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_wheel)", + "--max-size-mib=5", + ], + data = [":jax_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) diff --git a/CHANGELOG.md b/CHANGELOG.md index c30877ecae14..7e027db7e32b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change log -Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). +Best viewed [here](https://docs.jax.dev/en/latest/changelog.html). For the changes specific to the experimental Pallas APIs, see {ref}`pallas-changelog`. @@ -16,6 +16,88 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. 
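(The BUILD.bazel hunk above adds wheel, editable-wheel, source-package, and py_import targets at the repository root. A sketch of how they might be exercised locally; the invocations are illustrative, and the size test is tagged manual/notap, so it must be requested explicitly:)

    # Sketch: build the new packaging targets and run the wheel size check.
    bazel build //:jax_wheel //:jax_wheel_editable //:jax_source_package
    bazel test //:jax_wheel_size_test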
## Unreleased +## JAX 0.6.0 (April 16, 2025) + +* Breaking changes + + * {func}`jax.numpy.array` no longer accepts `None`. This behavior was + deprecated since November 2023 and is now removed. + * Removed the `config.jax_data_dependent_tracing_fallback` config option, + which was added temporarily in v0.4.36 to allow users to opt out of the + new "stackless" tracing machinery. + * Removed the `config.jax_eager_pmap` config option. + * Disallow the calling of `lower` and `trace` AOT APIs on the result + of `jax.jit` if there have been subsequent wrappers applied. + Previously this worked, but silently ignored the wrappers. + The workaround is to apply `jax.jit` last among the wrappers, + and similarly for `jax.pmap`. + See {jax-issue}`#27873`. + * The `cuda12_pip` extra for `jax` has been removed; use `pip install jax[cuda12]` + instead. + +* Changes + * The minimum CuDNN version is v9.8. + * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain + supported. + * JAX package extras are now updated to use dash instead of underscore to + align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]` + to install JAX, run `pip install jax[cuda12-local]` instead. + * {func}`jax.jit` now requires `fun` to be passed by position, and additional + arguments to be passed by keyword. Doing otherwise will result in a + DeprecationWarning in v0.6.X, and an error in starting in v0.7.X. + +* Deprecations + + * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` + instead. + * Implemented host callback handlers for CPU and GPU devices using XLA's FFI + and removed existing CPU/GPU handlers using XLA's custom call. + * All APIs in `jax.lib.xla_extension` are now deprecated. + * `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, + which were accidental exports, have been removed. If needed, they are + available from `jax.extend.mlir`. + * `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by + {mod}`jax.ffi` should be used instead. + * The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no + longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a + callable. + * The following exports in `jax.lib.xla_client` are deprecated: + `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`, + `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`, `OpSharding`, + `Traceback`. + * The following internal APIs in `jax.util` are deprecated: + `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`, + `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`, + `toposort`, `unzip2`, `wrap_name`, and `wraps`. + * `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX + `Array` directly to the `from_dlpack` function of another framework. If you + need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an + array. + * `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and + `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0. + * Several previously-deprecated APIs have been removed, including: + * From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`, + `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, + `ops`, `register_custom_call_target`, `shape_from_pyval`, `Shape`, + `XlaComputation`. + * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`. 
+ * From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, + `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and + `jax.tree_unflatten`. Replacements can be found in {mod}`jax.tree` or + {mod}`jax.tree_util`. + * From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, + `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, + `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, + `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`, + `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, + `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, + `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most + have no public replacement, though a few are available at {mod}`jax.extend.core`. + * The `vectorized` argument to {func}`~jax.pure_callback` and + {func}`~jax.ffi.ffi_call`. Use the `vmap_method` parameter instead. + +## jax 0.5.3 (Mar 19, 2025) + * New Features * Added a `allow_negative_indices` option to {func}`jax.lax.dynamic_slice`, @@ -81,7 +163,7 @@ Patch release of 0.5.1 ## jax 0.5.0 (Jan 17, 2025) As of this release, JAX now uses -[effort-based versioning](https://jax.readthedocs.io/en/latest/jep/25516-effver.html). +[effort-based versioning](https://docs.jax.dev/en/latest/jep/25516-effver.html). Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this. @@ -172,7 +254,7 @@ to signify this. * New Features * {func}`jax.export.export` can be used for device-polymorphic export with shardings constructed with {func}`jax.sharding.AbstractMesh`. - See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export). + See the [jax.export documentation](https://docs.jax.dev/en/latest/export/export.html#device-polymorphic-export). * Added {func}`jax.lax.split`. This is a primitive version of {func}`jax.numpy.split`, added because it yields a more compact transpose during automatic differentiation. @@ -214,7 +296,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. * The deprecated module `jax.experimental.export` has been removed. It was replaced - by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) + by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27. @@ -252,7 +334,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. call that we guarantee export stability. This is because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the `disabled_checks` - parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). + parameter. 
See more details in the [documentation](https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for @@ -532,7 +614,7 @@ See the 0.4.33 release notes for more details. * Added an API for exporting and serializing JAX functions. This used to exist in `jax.experimental.export` (which is being deprecated), and will now live in `jax.export`. - See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html). + See the [documentation](https://docs.jax.dev/en/latest/export/index.html). * Deprecations * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed @@ -541,7 +623,7 @@ See the 0.4.33 release notes for more details. release. This previously was the case, but there was an inadvertent regression in the last several JAX releases. * `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. - See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export). + See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export). * Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays `x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`. * `jax.xla_computation` is deprecated and will be removed in a future release. @@ -753,7 +835,7 @@ See the 0.4.33 release notes for more details. deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the `spmd_axis_name` argument for expressing SPMD device-parallel computations. * The `jax.experimental.host_callback` module is deprecated. - Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). + Use instead the [new JAX external callbacks](https://docs.jax.dev/en/latest/notebooks/external_callbacks.html). Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the new callbacks. See {jax-issue}`#20385` for a discussion. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` @@ -1225,9 +1307,9 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * JAX now requires NumPy 1.22 or newer as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` @@ -1272,7 +1354,7 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html ## jax 0.4.13 (June 22, 2023) @@ -1451,7 +1533,7 @@ See the 0.4.33 release notes for more details. ## jax 0.4.7 (March 27, 2023) * Changes - * As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration + * As per https://docs.jax.dev/en/latest/jax_array_migration.html#jax-array-migration `jax.config.jax_array` cannot be disabled anymore. * `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore. 
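
To make the two migration notes above concrete (keyword-only optional arguments for `jax.numpy.ndarray.at`, and passing a dtype rather than an array to `astype`), here is a minimal illustrative sketch; the arrays and values are invented for the example and are not part of the changelog:

```python
import jax.numpy as jnp

x = jnp.arange(5.0)
y = jnp.zeros(3, dtype=jnp.float16)

# Optional arguments to .at[...].get() must now be passed by keyword.
v = x.at[2].get(indices_are_sorted=True)   # previously: x.at[2].get(True)

# Pass a dtype (not an array) to astype to avoid the deprecation warning.
z = x.astype(y.dtype)                      # previously: x.astype(y)
```
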
* {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization` @@ -1535,7 +1617,7 @@ Changes: on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. For more information see - [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). + [this section in autodidax](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc easier. You can disable it only via the environment variable i.e. @@ -1620,9 +1702,9 @@ Changes: simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking change to the `pjit` API. The [jax.Array migration - guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can + guide](https://docs.jax.dev/en/latest/jax_array_migration.html) can help you migrate your codebase to `jax.Array`. You can also look at the - [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) + [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial to understand the new concepts. * `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. @@ -1651,7 +1733,7 @@ Changes: * The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to - [GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for + [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html) for more details. * The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead. @@ -1765,7 +1847,7 @@ Changes: * Changes * Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See [the - overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs + overview](https://docs.jax.dev/en/latest/aot.html) and the API docs for {mod}`jax.stages`. * Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks and type annotations for array types in JAX. Notice that this included some subtle @@ -1786,7 +1868,7 @@ Changes: * Breaking changes * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see - [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). * Changes * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`). * Deprecations: @@ -1798,7 +1880,7 @@ Changes: * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main). 
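
As an illustration of the `jax.pure_callback` entry above, which enables calling back to pure Python functions from compiled code, here is a minimal sketch; the NumPy host function and shapes are invented for the example:

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
  # Ordinary NumPy code, executed outside the compiled computation.
  return np.sin(x)

@jax.jit
def f(x):
  # Tell JAX the shape/dtype the host function will return.
  out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(host_fn, out_spec, x)

print(f(jnp.arange(4.0)))
```
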
* Breaking changes * Support for NumPy 1.19 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to NumPy 1.20 or newer. * Changes * Added {mod}`jax.debug` that includes utilities for runtime value debugging such at {func}`jax.debug.print` and {func}`jax.debug.breakpoint`. @@ -1816,7 +1898,7 @@ Changes: {mod}`jax.example_libraries.optimizers`. * {func}`jax.checkpoint`, also known as {func}`jax.remat`, has a new implementation switched on by default, meaning the old implementation is - deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + deprecated; see [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.15 (July 22, 2022) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15). @@ -1948,7 +2030,7 @@ Changes: * {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input. * {func}`jax.scipy.cluster.vq.vq` has been added. * `jax.experimental.maps.mesh` has been deleted. - Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh + Please use `jax.experimental.maps.Mesh`. Please see https://docs.jax.dev/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information. * {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`) @@ -2064,7 +2146,7 @@ Changes: * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. * Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are optimized alternatives to `jax.lax.top_k`. @@ -2110,13 +2192,13 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0). * Changes - * jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jax version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jaxlib 0.3.0 (Feb 10, 2022) * Changes * Bazel 5.0.0 is now required to build jaxlib. - * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jax 0.2.28 (Feb 1, 2022) @@ -2138,7 +2220,7 @@ Changes: by default. * Breaking changes * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. 
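
Relating to the `jax.debug` entry earlier in this changelog, a minimal sketch of runtime value debugging under `jit`; the function below is illustrative only:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # Prints the runtime value, even though `x` is a tracer during tracing.
  jax.debug.print("x = {}", x)
  return x * 2

f(jnp.arange(3))
```
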
* Bug fixes * Fixed a bug where apparently identical pytreedef objects constructed by different routes @@ -2150,7 +2232,7 @@ Changes: * Breaking changes: * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. @@ -2277,7 +2359,7 @@ Changes: * Deprecations * The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are deprecated and will be removed in a future JAX release. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a `DeprecationWarning`. * New features: @@ -2341,7 +2423,7 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19). * Breaking changes: * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The `jit` decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common @@ -2362,10 +2444,10 @@ Changes: ## jaxlib 0.1.70 (Aug 9, 2021) * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback mechanism now uses one thread per local device for @@ -2379,7 +2461,7 @@ Changes: * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * The minimum jaxlib version is now 0.1.69. * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been @@ -2428,7 +2510,7 @@ Changes: * Breaking changes: * Support for NumPy 1.16 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). * Bug fixes: * Fixed bug that prevented round-tripping from JAX to TF and back: @@ -2968,7 +3050,7 @@ Changes: * Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`. * Experimental support for printing and calling host-side Python function from - compiled code. See [id_print and id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html) + compiled code. See [id_print and id_tap](https://docs.jax.dev/en/latest/jax.experimental.host_callback.html) ({jax-issue}`#3006`). 
* Notable changes: * The visibility of names exported from {mod}`jax.numpy` has been @@ -3040,7 +3122,7 @@ Changes: ## jax 0.1.63 (April 12, 2020) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63). -* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). +* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). * Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`. * Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`. * Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 314d4387a044..046d3df3195c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ # Contributing to JAX For information on how to contribute to JAX, see -[Contributing to JAX](https://jax.readthedocs.io/en/latest/contributing.html) +[Contributing to JAX](https://docs.jax.dev/en/latest/contributing.html) diff --git a/README.md b/README.md index 0aca7cf58e6e..00391f314044 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,8 @@ | [**Transformations**](#transformations) | [**Install guide**](#installation) | [**Neural net libraries**](#neural-network-libraries) -| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html) -| [**Reference docs**](https://jax.readthedocs.io/en/latest/) +| [**Change logs**](https://docs.jax.dev/en/latest/changelog.html) +| [**Reference docs**](https://docs.jax.dev/en/latest/) ## What is JAX? @@ -48,7 +48,7 @@ are instances of such transformations. Others are parallel programming of multiple accelerators, with more to come. This is a research project, not an official Google product. Expect -[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +[sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). Please help by trying it out, [reporting bugs](https://github.com/jax-ml/jax/issues), and letting us know what you think! @@ -83,15 +83,15 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra ## Quickstart: Colab in the Cloud Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: -- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://docs.jax.dev/en/latest/quickstart.html) - [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) **JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). 
For a deeper dive into JAX: -- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) -- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) +- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) +- [Common gotchas and sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html) - See the [full list of notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). @@ -105,7 +105,7 @@ Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). The most popular function is -[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) +[`grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad) for reverse-mode gradients: ```python @@ -129,13 +129,13 @@ print(grad(grad(grad(tanh)))(1.0)) ``` For more advanced autodiff, you can use -[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for +[`jax.vjp`](https://docs.jax.dev/en/latest/jax.html#jax.vjp) for reverse-mode vector-Jacobian products and -[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for +[`jax.jvp`](https://docs.jax.dev/en/latest/jax.html#jax.jvp) for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose those to make a function that efficiently computes [full Hessian -matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian): +matrices](https://docs.jax.dev/en/latest/_autosummary/jax.hessian.html#jax.hessian): ```python from jax import jit, jacfwd, jacrev @@ -160,15 +160,15 @@ print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) ``` See the [reference docs on automatic -differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) +differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation) and the [JAX Autodiff -Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) for more. ### Compilation with `jit` You can use XLA to compile your functions end-to-end with -[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), +[`jit`](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), used either as an `@jit` decorator or as a higher-order function. ```python @@ -189,12 +189,12 @@ You can mix `jit` and `grad` and any other JAX transformation however you like. Using `jit` puts constraints on the kind of Python control flow the function can use; see -the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html) +the tutorial on [Control Flow and Logical Operators with JIT](https://docs.jax.dev/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` -[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is +[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) is the vectorizing map. 
It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a @@ -259,7 +259,7 @@ differentiation for fast Jacobian and Hessian matrix calculations in ### SPMD programming with `pmap` For parallel programming of multiple accelerators, like multiple GPUs, use -[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap). +[`pmap`](https://docs.jax.dev/en/latest/jax.html#parallelization-pmap). With `pmap` you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. Applying `pmap` will mean that the function you write is compiled by XLA (similarly to `jit`), then @@ -284,7 +284,7 @@ print(pmap(jnp.mean)(result)) ``` In addition to expressing pure maps, you can use fast [collective communication -operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) +operations](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) between devices: ```python @@ -341,20 +341,20 @@ for more. For a more thorough survey of current gotchas, with examples and explanations, we highly recommend reading the [Gotchas -Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). Some standouts: 1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. 1. [In-place mutating updates of - arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. + arrays](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://docs.jax.dev/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). + different](https://docs.jax.dev/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution - operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), + operators](https://docs.jax.dev/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. 1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and [to enable - double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) + double-precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at startup (or set the environment variable `JAX_ENABLE_X64=True`). 
On TPU, JAX uses 32-bit values by default for everything _except_ internal @@ -368,14 +368,14 @@ Some standouts: and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`. 1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://jax.readthedocs.io/en/latest/control-flow.html). + flow](https://docs.jax.dev/en/latest/control-flow.html). You'll always get loud errors if something goes wrong. You might have to use [`jit`'s `static_argnums` - parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), + parameter](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), [structured control flow - primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) + primitives](https://docs.jax.dev/en/latest/jax.lax.html#control-flow-operators) like - [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), + [`lax.scan`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), or just use `jit` on smaller subfunctions. ## Installation @@ -403,7 +403,7 @@ Some standouts: | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | | Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). | -See [the documentation](https://jax.readthedocs.io/en/latest/installation.html) +See [the documentation](https://docs.jax.dev/en/latest/installation.html) for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions. @@ -417,7 +417,7 @@ for training neural networks in JAX. If you want a fully featured library for ne training with examples and how-to guides, try [Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html). -Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem) +Check out the [JAX Ecosystem section](https://docs.jax.dev/en/latest/#ecosystem) on the JAX documentation site for a list of JAX-based network libraries, which includes [Optax](https://github.com/deepmind/optax) for gradient processing and optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and @@ -452,7 +452,7 @@ paper. ## Reference documentation For details about the JAX API, see the -[reference documentation](https://jax.readthedocs.io/). +[reference documentation](https://docs.jax.dev/). For getting started as a JAX developer, see the -[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). +[developer documentation](https://docs.jax.dev/en/latest/developer.html). 
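
Expanding on two of the gotchas listed earlier in the README above, enabling 64-bit precision at startup and marking an argument static so plain Python control flow works under `jit`, here is a minimal sketch; the function and values are illustrative only:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Gotcha: JAX defaults to 32-bit values; opt in to 64-bit at startup.
jax.config.update("jax_enable_x64", True)

# Gotcha: a Python `if` on a traced value fails under jit; marking the
# argument static gives the Python code a concrete value to branch on.
@partial(jax.jit, static_argnums=(1,))
def scale(x, double):
  return x * 2.0 if double else x

print(scale(jnp.arange(3.0), True).dtype)  # float64 once x64 is enabled
```
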
diff --git a/WORKSPACE b/WORKSPACE index 129488281ea9..5c093ec2228f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,8 +14,10 @@ python_init_repositories( "3.12": "//build:requirements_lock_3_12.txt", "3.13": "//build:requirements_lock_3_13.txt", "3.13-ft": "//build:requirements_lock_3_13_ft.txt", + "3.14-ft": "//build:requirements_lock_3_14_ft.txt", }, local_wheel_inclusion_list = [ + "jax-*", "jaxlib*", "jax_cuda*", "jax-cuda*", diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index cabebce2227c..a62b78d66ced 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -847,7 +847,7 @@ def safe_map(state): args = tuple(list(range(state.range(0))) for _ in range(state.range(1))) def f(*args): return tuple(args) while state: - jax.util.safe_map(f, *args) + jax._src.util.safe_map(f, *args) @google_benchmark.register @google_benchmark.option.arg_names(['arg_lengths', 'num_args']) @@ -855,7 +855,7 @@ def f(*args): return tuple(args) def safe_zip(state): args = tuple(list(range(state.range(0))) for _ in range(state.range(1))) while state: - jax.util.safe_zip(*args) + jax._src.util.safe_zip(*args) @google_benchmark.register diff --git a/benchmarks/tracing_benchmark.py b/benchmarks/tracing_benchmark.py new file mode 100644 index 000000000000..e06ad538d476 --- /dev/null +++ b/benchmarks/tracing_benchmark.py @@ -0,0 +1,76 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Benchmarks for Jax tracing.""" + +import google_benchmark +import jax +from jax import random +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +import numpy as np + + +def make_mqa_splash_attention_fn_and_args(): + seed = 0 + key = random.key(seed) + k1, k2, k3 = random.split(key, 3) + + q_seq_len = 1024 + kv_seq_len = 1024 + num_q_heads = 2 + head_dim_qk = 128 + head_dim_v = 128 + dtype = np.dtype("float32") + + q = random.uniform(k1, (num_q_heads, q_seq_len, head_dim_qk), dtype=dtype) + k = random.uniform(k2, (kv_seq_len, head_dim_qk), dtype=dtype) + v = random.uniform(k3, (kv_seq_len, head_dim_v), dtype=dtype) + + mask = mask_lib.NumpyMask( + mask_lib.make_random_mask((q_seq_len, kv_seq_len), sparsity=0.5, seed=0) + ) + mask = mask_lib.MultiHeadMask(tuple(mask for _ in range(num_q_heads))) + block_sizes = splash.BlockSizes.get_default() + + return ( + jax.jit( + splash.make_splash_mqa_single_device(mask, block_sizes=block_sizes) + ) + ), (q, k, v) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_trace(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + + while state: + _ = attn.trace(q, k, v) + jax.clear_caches() + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_lower(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + traced = attn.trace(q, k, v) + + while state: + _ = traced.lower(lowering_platforms=("tpu",)) + jax.clear_caches() + + +if __name__ == "__main__": + google_benchmark.main() diff --git a/build/build.py b/build/build.py index d38b911bb904..87aa36aeba8b 100755 --- a/build/build.py +++ b/build/build.py @@ -68,13 +68,20 @@ # rule as the default. 
WHEEL_BUILD_TARGET_DICT_NEW = { "jax": "//:jax_wheel", + "jax_editable": "//:jax_wheel_editable", + "jax_source_package": "//:jax_source_package", "jaxlib": "//jaxlib/tools:jaxlib_wheel", + "jaxlib_editable": "//jaxlib/tools:jaxlib_wheel_editable", "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel", + "jax-cuda-plugin_editable": "//jaxlib/tools:jax_cuda_plugin_wheel_editable", "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel", + "jax-cuda-pjrt_editable": "//jaxlib/tools:jax_cuda_pjrt_wheel_editable", "jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel", "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", } +_JAX_CUDA_VERSION = "12" + def add_global_arguments(parser: argparse.ArgumentParser): """Adds all the global arguments that applies to all the CLI subcommands.""" parser.add_argument( @@ -382,6 +389,11 @@ async def main(): arch = platform.machine() os_name = platform.system().lower() + custom_wheel_version_suffix = "" + wheel_build_date = "" + wheel_git_hash = "" + wheel_type = "snapshot" + args = parser.parse_args() logger.info("%s", BANNER) @@ -407,7 +419,7 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - if not args.use_new_wheel_build_rule or args.command == "requirements_update": + if args.command == "requirements_update" or not args.use_new_wheel_build_rule: bazel_command_base.append("run") else: bazel_command_base.append("build") @@ -489,6 +501,7 @@ async def main(): if args.use_clang: clang_path = args.clang_path or utils.get_clang_path_or_exit() clang_major_version = utils.get_clang_major_version(clang_path) + clangpp_path = utils.get_clangpp_path(clang_path) logging.debug( "Using Clang as the compiler, clang path: %s, clang version: %s", clang_path, @@ -498,6 +511,7 @@ async def main(): # Use double quotes around clang path to avoid path issues on Windows. wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=CXX=\"{clangpp_path}\"") wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") if clang_major_version >= 16: @@ -612,6 +626,17 @@ async def main(): ) for option in args.bazel_options: wheel_build_command_base.append(option) + + # Parse the build options for the wheel version suffix. + if "ML_WHEEL_TYPE" in option: + wheel_type = option.split("=")[-1] + if "ML_WHEEL_VERSION_SUFFIX" in option: + custom_wheel_version_suffix = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_BUILD_DATE" in option: + wheel_build_date = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_GIT_HASH" in option: + wheel_git_hash = option.split("=")[-1][:9] + if "cuda" in args.wheels: wheel_build_command_base.append("--config=cuda_libraries_from_stubs") @@ -659,8 +684,13 @@ async def main(): ) # Append the build target to the Bazel command. 
- build_target = wheel_build_targets[wheel] + if args.use_new_wheel_build_rule and args.editable: + build_target = wheel_build_targets[wheel + "_editable"] + else: + build_target = wheel_build_targets[wheel] wheel_build_command.append(build_target) + if args.use_new_wheel_build_rule and wheel == "jax" and not args.editable: + wheel_build_command.append(wheel_build_targets["jax_source_package"]) if not args.use_new_wheel_build_rule: wheel_build_command.append("--") @@ -692,6 +722,54 @@ async def main(): if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") + if args.use_new_wheel_build_rule: + output_path = args.output_path + jax_bazel_dir = os.path.join("bazel-bin", "dist") + jaxlib_and_plugins_bazel_dir = os.path.join( + "bazel-bin", "jaxlib", "tools", "dist" + ) + for wheel in args.wheels.split(","): + if wheel == "jax": + bazel_dir = jax_bazel_dir + else: + bazel_dir = jaxlib_and_plugins_bazel_dir + if "cuda" in wheel: + wheel_dir = wheel.replace("cuda", f"cuda{_JAX_CUDA_VERSION}").replace( + "-", "_" + ) + else: + wheel_dir = wheel + + if args.editable: + src_dir = os.path.join(bazel_dir, wheel_dir) + dst_dir = os.path.join(output_path, wheel_dir) + utils.copy_dir_recursively(src_dir, dst_dir) + else: + wheel_version_suffix = "dev0+selfbuilt" + if wheel_type == "release": + wheel_version_suffix = custom_wheel_version_suffix + elif wheel_type in ["nightly", "custom"]: + wheel_version_suffix = f".dev{wheel_build_date}" + if wheel_type == "custom": + wheel_version_suffix += ( + f"+{wheel_git_hash}{custom_wheel_version_suffix}" + ) + if wheel in ["jax", "jax-cuda-pjrt"]: + python_tag = "py" + else: + python_tag = "cp" + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}-{python_tag}*.whl", + ) + if wheel == "jax": + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}.tar.gz", + ) + # Exit with success if all wheels in the list were built successfully. 
sys.exit(0) diff --git a/build/collect-profile-requirements.txt b/build/collect-profile-requirements.txt index da25d4b6ffe1..a7d57dd2c4ef 100644 --- a/build/collect-profile-requirements.txt +++ b/build/collect-profile-requirements.txt @@ -1,4 +1,4 @@ tensorflow -tensorboard-plugin-profile +tensorboard-plugin-profile<=2.19.0 # Needed for the profile plugin to work without error protobuf diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt index ff43f91ba90f..d0dda5cf526c 100644 --- a/build/gpu-test-requirements.txt +++ b/build/gpu-test-requirements.txt @@ -5,7 +5,7 @@ nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux" nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux" nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux" nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux" -nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux" +nvidia-cudnn-cu12>=9.8,<10.0 ; sys_platform == "linux" nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux" nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux" nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux" diff --git a/build/requirements.in b/build/requirements.in index ec7fc71b07e1..b023cedfbd19 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -13,7 +13,8 @@ numpy~=2.1.0; python_version>="3.13" # # runtime deps # -scipy>=1.13.1 +scipy>=1.13.1; python_version<="3.12" +scipy>=1.15.2; python_version>="3.13" ml_dtypes>=0.4.0 opt_einsum diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 6ed6b59aa584..8bf5293bd948 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -410,10 +410,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 8446e8361505..487346ab6d12 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -405,10 +405,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; 
sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 0436ab6dd486..e2f76cab8abc 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -405,10 +405,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index e74d40b798f4..403d0ad8a061 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -460,10 +460,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index e7a2968e981e..507e896ab8db 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -413,10 +413,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - 
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ @@ -596,47 +596,53 @@ rich==13.9.4 \ --hash=sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098 \ --hash=sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90 # via -r build/test-requirements.txt -scipy==1.15.0 \ - --hash=sha256:0e5b34f8894f9904cc578008d1a9467829c1817e9f9cb45e6d6eeb61d2ab7731 \ - --hash=sha256:0fcb16eb04d84670722ce8d93b05257df471704c913cb0ff9dc5a1c31d1e9422 \ - --hash=sha256:129f899ed275c0515d553b8d31696924e2ca87d1972421e46c376b9eb87de3d2 \ - --hash=sha256:161f80a98047c219c257bf5ce1777c574bde36b9d962a46b20d0d7e531f86863 \ - --hash=sha256:1b29e4fc02e155a5fd1165f1e6a73edfdd110470736b0f48bcbe48083f0eee37 \ - --hash=sha256:1e2448acd79c6374583581a1ded32ac71a00c2b9c62dfa87a40e1dd2520be111 \ - --hash=sha256:2240e1fd0782e62e1aacdc7234212ee271d810f67e9cd3b8d521003a82603ef8 \ - --hash=sha256:300742e2cc94e36a2880ebe464a1c8b4352a7b0f3e36ec3d2ac006cdbe0219ac \ - --hash=sha256:327163ad73e54541a675240708244644294cb0a65cca420c9c79baeb9648e479 \ - --hash=sha256:351899dd2a801edd3691622172bc8ea01064b1cada794f8641b89a7dc5418db6 \ - --hash=sha256:35c68f7044b4e7ad73a3e68e513dda946989e523df9b062bd3cf401a1a882192 \ - --hash=sha256:36be480e512d38db67f377add5b759fb117edd987f4791cdf58e59b26962bee4 \ - --hash=sha256:37ce9394cdcd7c5f437583fc6ef91bd290014993900643fdfc7af9b052d1613b \ - --hash=sha256:46e91b5b16909ff79224b56e19cbad65ca500b3afda69225820aa3afbf9ec020 \ - --hash=sha256:4e08c6a36f46abaedf765dd2dfcd3698fa4bd7e311a9abb2d80e33d9b2d72c34 \ - --hash=sha256:52475011be29dfcbecc3dfe3060e471ac5155d72e9233e8d5616b84e2b542054 \ - --hash=sha256:5972e3f96f7dda4fd3bb85906a17338e65eaddfe47f750e240f22b331c08858e \ - --hash=sha256:5abbdc6ede5c5fed7910cf406a948e2c0869231c0db091593a6b2fa78be77e5d \ - --hash=sha256:5beb0a2200372b7416ec73fdae94fe81a6e85e44eb49c35a11ac356d2b8eccc6 \ - --hash=sha256:61513b989ee8d5218fbeb178b2d51534ecaddba050db949ae99eeb3d12f6825d \ - --hash=sha256:6d26f17c64abd6c6c2dfb39920f61518cc9e213d034b45b2380e32ba78fde4c0 \ - --hash=sha256:6f376d7c767731477bac25a85d0118efdc94a572c6b60decb1ee48bf2391a73b \ - --hash=sha256:767e8cf6562931f8312f4faa7ddea412cb783d8df49e62c44d00d89f41f9bbe8 \ - --hash=sha256:82bff2eb01ccf7cea8b6ee5274c2dbeadfdac97919da308ee6d8e5bcbe846443 \ - --hash=sha256:952d2e9eaa787f0a9e95b6e85da3654791b57a156c3e6609e65cc5176ccfe6f2 \ - --hash=sha256:9c8254fe21dd2c6c8f7757035ec0c31daecf3bb3cffd93bc1ca661b731d28136 \ - --hash=sha256:aeac60d3562a7bf2f35549bdfdb6b1751c50590f55ce7322b4b2fc821dc27fca \ - --hash=sha256:b1432102254b6dc7766d081fa92df87832ac25ff0b3d3a940f37276e63eb74ff \ - --hash=sha256:bdca4c7bb8dc41307e5f39e9e5d19c707d8e20a29845e7533b3bb20a9d4ccba0 \ - --hash=sha256:c9624eeae79b18cab1a31944b5ef87aa14b125d6ab69b71db22f0dbd962caf1e \ - --hash=sha256:ccb6248a9987193fe74363a2d73b93bc2c546e0728bd786050b7aef6e17db03c \ - 
--hash=sha256:cd9d9198a7fd9a77f0eb5105ea9734df26f41faeb2a88a0e62e5245506f7b6df \ - --hash=sha256:d13bbc0658c11f3d19df4138336e4bce2c4fbd78c2755be4bf7b8e235481557f \ - --hash=sha256:d35aef233b098e4de88b1eac29f0df378278e7e250a915766786b773309137c4 \ - --hash=sha256:de112c2dae53107cfeaf65101419662ac0a54e9a088c17958b51c95dac5de56d \ - --hash=sha256:e9baff912ea4f78a543d183ed6f5b3bea9784509b948227daaf6f10727a0e2e5 \ - --hash=sha256:eb1533c59f0ec6c55871206f15a5c72d1fae7ad3c0a8ca33ca88f7c309bbbf8c \ - --hash=sha256:ec915cd26d76f6fc7ae8522f74f5b2accf39546f341c771bb2297f3871934a52 \ - --hash=sha256:fde0f3104dfa1dfbc1f230f65506532d0558d43188789eaf68f97e106249a913 \ - --hash=sha256:fe00169cf875bed0b3c40e4da45b57037dc21d7c7bf0c85ed75f210c281488f1 +scipy==1.15.2 ; python_version >= "3.13" \ + --hash=sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf \ + --hash=sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11 \ + --hash=sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37 \ + --hash=sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d \ + --hash=sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0 \ + --hash=sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8 \ + --hash=sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af \ + --hash=sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40 \ + --hash=sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9 \ + --hash=sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971 \ + --hash=sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d \ + --hash=sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737 \ + --hash=sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e \ + --hash=sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32 \ + --hash=sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53 \ + --hash=sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1 \ + --hash=sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d \ + --hash=sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e \ + --hash=sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776 \ + --hash=sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5 \ + --hash=sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462 \ + --hash=sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274 \ + --hash=sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301 \ + --hash=sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3 \ + --hash=sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58 \ + --hash=sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4 \ + --hash=sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa \ + --hash=sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9 \ + --hash=sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27 \ + --hash=sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9 \ + --hash=sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f \ + --hash=sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655 \ + --hash=sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20 \ + 
--hash=sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65 \ + --hash=sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93 \ + --hash=sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828 \ + --hash=sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd \ + --hash=sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f \ + --hash=sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec \ + --hash=sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb \ + --hash=sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6 \ + --hash=sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded \ + --hash=sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e \ + --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ + --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ + --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db # via -r build/requirements.in six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ @@ -658,7 +664,7 @@ zipp==3.21.0 \ --hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \ --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931 # via etils -# python 3.13t can compile 0.23.0 +# python 3.13t can't compile 0.23.0 # due to https://github.com/indygreg/python-zstandard/issues/231 # zstandard==0.23.0 \ # --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ diff --git a/build/requirements_lock_3_14_ft.txt b/build/requirements_lock_3_14_ft.txt new file mode 100644 index 000000000000..e50305f4fa48 --- /dev/null +++ b/build/requirements_lock_3_14_ft.txt @@ -0,0 +1,21 @@ +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +numpy + +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +scipy + +absl-py==2.1.0 + +attrs==24.3.0 + +hypothesis==6.123.9 + +sortedcontainers==2.4.0 + +flatbuffers==24.12.23 + +ml-dtypes==0.5.1 + +opt-einsum==3.4.0 diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index 08b6bd3ff8d6..8afe8b17252c 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -9,7 +9,7 @@ ARG ROCM_BUILD_NUM # manylinux base image. However, adding this does fix an issue where Bazel isn't able # to find them. 
RUN --mount=type=cache,target=/var/cache/dnf \ - dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 + dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index fd98bbb8ec04..a7ebdf86f916 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -226,7 +226,10 @@ def fix_wheel(path, jax_path): py_bin = "/opt/python/cp310-cp310/bin" env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) - cmd = ["pip", "install", "auditwheel>=6"] + # NOTE(mrodden): auditwheel 6.0 added lddtree module, but 6.3.0 changed + # the fuction to ldd and also changed its behavior + # constrain range to 6.0 to 6.2.x + cmd = ["pip", "install", "auditwheel>=6,<6.3"] subprocess.run(cmd, check=True, env=env) fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") diff --git a/build/tools/utils.py b/build/tools/utils.py index 7e375169827b..c52b89a1e6d2 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -14,6 +14,7 @@ # ============================================================================== # Helper script for tools/utilities used by the JAX build CLI. import collections +import glob import hashlib import logging import os @@ -201,6 +202,20 @@ def get_clang_major_version(clang_path): return major_version +def get_clangpp_path(clang_path): + clang_path = pathlib.Path(clang_path) + clang_exec_name = clang_path.name + clangpp_exec_name = clang_exec_name + if "clang++" not in clang_exec_name: + clangpp_exec_name = clang_exec_name.replace("clang", "clang++") + clangpp_path = clang_path.parent / clangpp_exec_name + if not clangpp_path.exists(): + raise FileNotFoundError( + f"Failed to get clang++ path from clang path: '{clang_path!s}'. " + f"Tried the path: '{clangpp_path!s}'." + ) + return str(clangpp_path) + def get_gcc_major_version(gcc_path: str): gcc_version_proc = subprocess.run( [gcc_path, "-dumpversion"], @@ -256,3 +271,28 @@ def _parse_string_as_bool(s): return False else: raise ValueError(f"Expected either 'true' or 'false'; got {s}") + + +def copy_dir_recursively(src, dst): + if os.path.exists(dst): + shutil.rmtree(dst) + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + relative_path = os.path.relpath(root, src) + dst_dir = os.path.join(dst, relative_path) + os.makedirs(dst_dir, exist_ok=True) + for f in files: + src_file = os.path.join(root, f) + dst_file = os.path.join(dst_dir, f) + shutil.copy2(src_file, dst_file) + logging.info("Editable wheel path: %s" % dst) + + +def copy_individual_files(src, dst, regex): + os.makedirs(dst, exist_ok=True) + for f in glob.glob(os.path.join(src, regex)): + dst_file = os.path.join(dst, os.path.basename(f)) + if os.path.exists(dst_file): + os.remove(dst_file) + shutil.copy2(f, dst_file) + logging.info("Distribution path: %s" % dst_file) diff --git a/build_wheel.py b/build_wheel.py index f8e1595d3c3a..793523e8e3b2 100644 --- a/build_wheel.py +++ b/build_wheel.py @@ -47,6 +47,25 @@ parser.add_argument( "--srcs", help="source files for the wheel", action="append" ) +parser.add_argument( + "--build-wheel-only", + default=False, + help=( + "Whether to build the wheel only. Optional." + ), +) +parser.add_argument( + "--build-source-package-only", + default=False, + help=( + "Whether to build the source package only. Optional." 
+ ), +) +parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' jax build instead of a wheel.", +) args = parser.parse_args() @@ -76,7 +95,11 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: """ for file in deps: - if not (file.startswith("bazel-out") or file.startswith("external")): + if not ( + file.startswith("bazel-out") + or file.startswith("external") + or file.startswith("jaxlib") + ): copy_file(file, srcs_dir) @@ -89,13 +112,18 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: try: os.makedirs(args.output_path, exist_ok=True) prepare_srcs(args.srcs, pathlib.Path(sources_path)) - build_utils.build_wheel( - sources_path, - args.output_path, - package_name="jax", - git_hash=args.jaxlib_git_hash, - build_wheel_only=False, - ) + package_name = "jax" + if args.editable: + build_utils.build_editable(sources_path, args.output_path, package_name) + else: + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + build_wheel_only=args.build_wheel_only, + build_source_package_only=args.build_source_package_only, + ) finally: if tmpdir: tmpdir.cleanup() diff --git a/ci/README.md b/ci/README.md index ea867df52f97..31af3ec0ef87 100644 --- a/ci/README.md +++ b/ci/README.md @@ -1,10 +1,254 @@ -# JAX continuous integration +# JAX Continuous Integration -> [!WARNING] -> This folder is still under construction. It is part of an ongoing -> effort to improve the structure of CI and build related files within the -> JAX repo. This warning will be removed when the contents of this -> directory are stable and appropriate documentation around its usage is in -> place. +This folder contains the configuration files and scripts used to build and test +JAX. It is typically used by continuous integration (CI) jobs to automate builds +and run comprehensive tests across various platforms and configurations. This +page provides an overview of the JAX CI system, its components, and the +different workflows it supports. -******************************************************************************** \ No newline at end of file +******************************************************************************** + +## JAX's CI System + +![Overview of JAX's CI System](jax_ci_system.png) + +JAX's CI system is composed of several interacting components and orchestrates +builds and tests using a hybrid approach, leveraging both an internal CI system +and GitHub Actions as well as an internal build orchestrator for managing +nightly and release flows. It encompasses several distinct workflows, including +comprehensive presubmit checks triggered on pull requests and branch pushes, +bi-hourly continuous builds, extensive nightly builds with broad platform +coverage, and a controlled release process that culminates in PyPI publication. + +These flows build four packages: `jax`, `jaxlib`, `jax-cuda-plugin`, +`jax-cuda-pjrt` and support a range of environments, including: + +* **Linux x86:** CPU, TPU, CUDA +* **Linux aarch64:** CPU, CUDA +* **Windows x86:** CPU +* **Mac Arm64:** CPU + +### Architecture Overview + +1. **Internal CI System:** An internal CI system is used for specific build and + test tasks, such as nightly builds, release candidate (RC) builds, and + Mac-specific testing. + +2. **GitHub Actions:** Used for presubmit checks, continuous integration builds + and tests, and nightly/release artifact testing. + +3. 
**Build Orchestrator:** An internal tool used to manage complex workflows + such as nightly / release flows, promoting RC builds to release, etc. + +4. **Artifact Storage:** + +* Google Cloud Storage (GCS) Buckets: Used for temporary storage of artifacts + between jobs in GitHub Actions workflows and for storing packages built + during nightly and release flows before testing. +* Artifact Registry: Used to store nightly packages, RC packages and final + releases. +* PyPI: Where final releases are published. + +### CI Workflows and Where They Run + +JAX's CI system consists of the following workflows: + +1. **Presubmits:** Presubmits are run in GitHub actions and are triggered on + pull requests that target the `main` branch and on pushes to the `main` and + `release` branch. JAX's presubmit run time SLO is about 10 minutes so these + are typically run using Bazel with remote build execution + ([RBE](https://bazel.build/remote/rbe)). RBE allows us to execute build and + test actions on a distributed system, separate from the local machine, + instead of solely on the local machine. This enables faster build and test + times by utilizing parallel computing resources and caching across a cluster + of machines. However, we also use Pytest in workflows where we are not able + to use RBE such as the TPU presubmit. In such presubmits, we usually run a + subset of tests to be able to satisfy the presubmit run time SLO. To see the + list of the presubmit workflows, + [click here](https://github.com/search?q=repo%3Ajax-ml%2Fjax+path%3A.github%2Fworkflows%2F+%28path%3A**%2F*.yml+OR+path%3A**%2F*.yaml%29+%22pull_request%22&type=code). + +2. **Continuous:** These jobs are run in GitHub actions and are scheduled to + run once every 2 hours on the `main` branch. It builds JAX packages and runs + a wide range of tests targeting different environments such as CPU, CUDA + (L4, H100, B200, etc), and TPU (v4-8, v5e-8, etc.). For more information, + see + [wheel_tests_continuous.yml](https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_continuous.yml) + ([An example run](https://github.com/jax-ml/jax/actions/workflows/wheel_tests_continuous.yml).) + +3. **Nightly Builds and Tests:** These jobs use an hybrid approach of both the + internal CI system and GitHub actions. The jobs are triggered once every + night by the internal build orchestrator tool. It first triggers the jobs in + the internal CI system to build the JAX packages for different + configurations (Python versions, CUDA versions, etc) and uploads them to a + staging bucket in GCS as well as to the nightly artifact registry. Next, + testing jobs are triggered that download the artifacts from the staging + bucket and run tests. Mac testing jobs are run in the internal CI system. + For non-Mac testing, a trigger job is run that invokes the + [wheel_tests_nightly_release.yml](https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_nightly_release.yml) + workflow in GitHub Actions. JAX's nightly artifacts can be found here: + [jax](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax), + [jaxlib](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jaxlib), + [jax-cuda-plugin](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda12-plugin), + [jax-cuda-pjrt](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda12-pjrt). + +4. 
**Release Builds and Tests:** Release flow is similar to the nightly flow + except for few differences. First, release process has to be triggered + manually in the internal build orchestrator and should be done only after a + release branch (E.g `release/0.5.3`) has been created. The build jobs build + two sets of artifacts for each package: 1. RC wheels 2. Final version + wheels. These two sets are pretty much the same package except for their + metadata and wheel tags. The RC wheels are then uploaded to the staging + bucket and release artifact registry. After the uploads are done, the test + jobs are triggered. As with the nightly flow, Mac test jobs are run in the + internal CI system while non-Mac test jobs are run in GitHub actions. To see + the GitHub actions run for a particular release, filter the workflow runs by + its branch name. + + +5. **Promote RC to Final and Publish to PyPI:** If the RC wheels pass all + testing, then we are ready to promote it as the final version and publish it + to PyPI. This entire flow is internal and is run in our internal CI system. + Final version of the packages are published to PyPI and JAX's release + artifact registry. JAX's release artifacts (RC and final versions) can be + found here: + [jax](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax), + [jaxlib](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jaxlib), + [jax-cuda-plugin](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax-cuda12-plugin), + [jax-cuda-pjrt](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax-cuda12-pjrt). + +### JAX's Official CI and Build/Test Scripts + +JAX's CI jobs (both internal and those on GitHub actions) run the scripts in +this folder. An overview of the different folders and their purpose is given +below: + +- **ci/**: Contains all build scripts, environment files, and utility scripts. +- **ci/utilities/**: Contains helper scripts used throughout the build/test + process. See + [README.md](https://github.com/jax-ml/jax/blob/main/ci/utilities/README.md) + for a brief overview of these utility scripts and their behavior. +- **ci/envs/**: Holds environment files that set `JAXCI` environment variables + that control build and test configurations. see + [README.md](https://github.com/jax-ml/jax/blob/main/ci/envs/README.md) to + see the complete list of these variables and their behavior. + +Every build script in this folder first source the `JAXCI` envs in +[default.env](https://github.com/jax-ml/jax/blob/main/ci/envs/default.env) and +then run the +[setup_build_environment.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/setup_build_environment.sh) +script to set up the build environment. + +A brief overview of each build script in this folder is given below: + +> [!NOTE] +> Both internal and GitHub action jobs run under the +> [ml-build](https://github.com/tensorflow/tensorflow/tree/master/ci/official/containers) +> Docker image which contains build tools such as Python, Bazelisk, LLVM/Clang, +> manylinux compliant libraries (in Linux images), etc. + +- **build_artifacts.sh:** These build the various JAX artifacts. We build + three different type of artifacts based on the type of job: Nightly, + RC/Release, or at HEAD. +- **run_bazel_test_cpu_rbe.sh/run_bazel_test_cuda_rbe.sh**: These run Bazel + tests with RBE on every GitHub PR. 
We test compatibility with both CPU and + CUDA. On platforms where RBE is not natively supported (e.g Linux Arm64), we + cross-compile the test targets for Linux Aarch64 on Linux x86. As the tests + still need to be run on the host machines and because running the tests on a + single machine can take a long time, we skip running them on these + platforms. +- **run_bazel_test_cuda_non_rbe.sh**: These run the following Bazel CUDA + tests: Single accelerator tests with one GPU apiece and Multi-accelerator + tests with all GPUs. These jobs depend on local JAX wheels and therefore + require that the following wheels to be present in the `../dist` folder: + `jax`, `jaxlib`, `jax-cuda-plugin`, and `jax-cuda-pjrt` wheels. In CI + builds, we first build these wheels from source and then run the `bazel + test` command. +- **run_pytest_*.sh**: These run tests with Pytests and use the JAX wheel + packages installed on the system. In CI builds, we build the wheels first + from source and then run the `pytest` commands. We test compatibility with + CPU, CUDA, and TPU. These are primarily run as part of the continuous and + nightly/release test jobs except for TPU which is also run as a presubmit + testing a subset of the tests. + +## Different Test Configurations + +JAX's CI Test jobs run under different test configurations. These configurations +are described briefly in the sections below. + +### XLA Versions + +JAX's CI builds rely on XLA, but use different versions depending on the type of +build. To ensure stability and reproducibility, nightly and release builds use a +pinned XLA version specified in the JAX +[workspace](https://github.com/jax-ml/jax/blob/34a2f0ca4a8f8a26d9a056f8785f412bd156dc23/third_party/xla/workspace.bzl#L24-L25). + +However, to keep JAX compatible with the latest XLA developments, presubmit and +postsubmit builds utilize the most recent XLA version. This is done by +overriding the default XLA dependency with a local copy of the XLA repository. +We do this by passing `--override_repository=xla=/path/to/local/xla` which +instructs Bazel to depend on the XLA in the local system instead of the version +in the workspace. + +The CI system uses the `JAXCI` environment variables to manage this process. +When running jobs that need to use XLA at head, we set `JAXCI_CLONE_MAIN_XLA=1`. +This clones the XLA repository at head and sets `JAXCI_XLA_GIT_DIR` to its path. +[JAX build CLI](https://github.com/jax-ml/jax/blob/main/build/build.py) +automatically adds the necessary Bazel flag (`--override_repository`) to point +to this local XLA version during the build process if `JAXCI_XLA_GIT_DIR` is +set. In jobs where the build CLI is not used such as the RBE presubmits, we +explicitly include `--override_repository=xla="${JAXCI_XLA_GIT_DIR}"` as part +of the test command. + +### Enabling/Disabling 64-bit Data Types + +By default, JAX enforces single-precision numbers to mitigate the Numpy API’s +tendency to aggressively promote operands to `double`. In order to use +double-precision numbers, we need to set the `JAX_ENABLE_X64` environment +variable. In CI, we test both configurations in presubmits and postsubmits by +using the `JAXCI_ENABLE_X64` environment variable. + + + +## [Googlers Only] Connecting to CI Runners for Debugging + +If you are a Googler, you can connect to one of the self-hosted runners we have +on GitHub to debug your workflow. For more information, see +go/ml-github-actions:connect. 
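
To make the two mechanisms above concrete, here is a rough sketch of how a test command could assemble the XLA override and the x64 setting from the `JAXCI` variables. This is an illustrative Python rendering only, not the actual logic in `build/build.py` or the `run_bazel_test_*.sh` scripts, and the `--config` name is just an example taken from the CUDA RBE script.

```python
# Illustrative sketch only: the real logic lives in build/build.py and the
# run_bazel_test_*.sh scripts; the --config value here is just an example.
import os
import shlex

def bazel_test_command(targets):
    cmd = ["bazel", "test", "--config=rbe_linux_x86_64_cuda"]
    xla_dir = os.environ.get("JAXCI_XLA_GIT_DIR")
    if xla_dir:
        # Depend on the local XLA checkout instead of the pinned WORKSPACE version.
        cmd.append(f"--override_repository=xla={xla_dir}")
    # Propagate the x64 setting into the test actions.
    cmd.append(f"--action_env=JAX_ENABLE_X64={os.environ.get('JAXCI_ENABLE_X64', '0')}")
    return cmd + list(targets)

print(shlex.join(bazel_test_command(["//tests:gpu_tests"])))
```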
+ +## Running These Scripts Locally on Your Machine + +> [!IMPORTANT] +> If you are a Linux / Windows user, you need to have Docker installed as a +> prerequisite. Additionally, if running on Windows, please run these commands +> in a bash environment as all the scripts are written in Shell. + +Follow the steps below to run a CI script locally on your machine. + +1. [Optional] Set `JAXCI` variables in your shell environment. See + [ci/envs/README.md](https://github.com/jax-ml/jax/blob/main/ci/envs/README.md) + for the list of `JAXCI` variables and their behavior. + +2. [Linux/Windows] + + Start the Docker container by running: + + ```bash + ./ci/utilities/run_docker_container.sh + ``` + + This will start a Docker container named "jax". Note that if you set any + `JAXCI` variables in step 1, they will also be be set in the container. + + Run the script under the Docker container. + + ```bash + # docker exec jax + docker exec jax ./ci/build_artifacts.sh jaxlib + ``` + +3. [Mac] Execute the build script directly. + + ```bash + # ./ + ./ci/build_artifacts.sh jaxlib + ``` diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 84b8d35a2a50..d7ffe82eb699 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -96,6 +96,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags # If building release artifacts, we also build a release candidate ("rc") @@ -105,18 +106,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION" fi - # Move the built artifacts from the Bazel cache directory to the output - # directory. - if [[ "$artifact" == "jax" ]]; then - mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR" - mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR" - else - mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR" - fi - # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then diff --git a/ci/envs/README.md b/ci/envs/README.md new file mode 100644 index 000000000000..6b5dc554d824 --- /dev/null +++ b/ci/envs/README.md @@ -0,0 +1,41 @@ +# JAXCI Environment Variables + +This docpage describes the various `JAXCI` environment variables that are used +in the CI scripts and their behaviors. These variables are used to control the +behavior of the CI scripts such as the Python version used, path to JAX/XLA +repo, if to clone XLA repo, etc. 
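
As a rough illustration of how a few of the defaults in the table below are derived, consider the sketch that follows. The authoritative implementation is the bash in [ci/envs/default.env](https://github.com/jax-ml/jax/blob/main/ci/envs/default.env); this Python snippet is only illustrative.

```python
# Rough Python rendering of a few defaults from ci/envs/default.env.
# The actual implementation is bash; this is only for illustration.
import os
import sys

hermetic_python = os.environ.get(
    "JAXCI_HERMETIC_PYTHON_VERSION",
    f"{sys.version_info.major}.{sys.version_info.minor}",  # falls back to the system Python
)
output_dir = os.environ.get("JAXCI_OUTPUT_DIR", os.path.join(os.getcwd(), "dist"))
jaxci_python = os.environ.get("JAXCI_PYTHON", f"python{hermetic_python}")

print(hermetic_python, output_dir, jaxci_python)
```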
+ +Name | Default Value | Behavior | Usage +------------------------------------------- | ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- +`JAXCI_JAX_GIT_DIR` | Present working directory: `$(pwd)` | Path to the JAX's Git directory. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_JAX_GIT_DIR&type=code) +`JAXCI_HERMETIC_PYTHON_VERSION` | System default | Controls the version of hermetic Python to use. This affects the Bazel commands only such as when building artifacts or when running the Bazel test scripts. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_HERMETIC_PYTHON_VERSION&type=code) +`JAXCI_XLA_GIT_DIR` | Unset | When using a local copy of XLA, this points to the root of the XLA git repoistory. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_GIT_DIR&type=code) +`JAXCI_CLONE_MAIN_XLA` | 0 | If set to 1, the XLA repository is cloned at HEAD and its path is set in `JAXCI_XLA_GIT_DIR` | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_CLONE_MAIN_XLA&type=code) +`JAXCI_XLA_COMMIT` | Unset | Allows overriding the XLA commit that is used when using a local copy of XLA. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_COMMIT&type=code) +`JAXCI_OUTPUT_DIR` | `$(pwd)/dist` | Controls the location where the artifacts are written to. The directory will be automatically created if it does not exist. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_OUTPUT_DIR&type=code) +`JAXCI_BUILD_ARTIFACT_WITH_RBE` | 0 | When set to 1, Bazel will use RBE to build the artifacts. Requires gcloud authentication and only certain platforms support RBE so this typically only set in CI builds | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_BUILD_ARTIFACT_WITH_RBE&type=code) +`JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE` | 0 | When set to 1, Bazel will also try to push new cache entries to the cache bucket. Since writes to the bucket require authentication, this flag is enabled only for CI builds. Note that the builds using RBE use the RBE cache and not Bazel's remote cache, therefore this variable is a no-op if `JAXCI_BUILD_ARTIFACT_WITH_RBE` is set to 1. When `JAXCI_BUILD_ARTIFACT_WITH_RBE` and `JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE` are both not set, Bazel will still read from the public cache bucket to try to speed up the build. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE&type=code) +`JAXCI_ARTIFACT_TYPE` | "default" | Controls the type of artifacts to build. Valid values are "default", "release", "nightly". This affects the wheel tag and metadata, see [ci/build_artifacts.sh](https://github.com/jax-ml/jax/blob/main/ci/build_artifacts.sh) to understand how. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ARTIFACT_TYPE&type=code) +`JAXCI_WHEEL_RC_VERSION` | Unset | During the release process, we build a Release Candidate (RC) wheel in addition to the release wheel. This environment variable sets the version of the RC wheel to build. Values are set internally. 
| [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_WHEEL_RC_VERSION&type=code) +`JAXCI_PYTHON` | `python${JAXCI_HERMETIC_PYTHON_VERSION}` | Points to the system Python binary to use. It used by scripts that make use of the system Python such as the Pytest scripts. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_PYTHON&type=code) +`JAXCI_ENABLE_X64` | 0 | By default, JAX enforces single-precision numbers to mitigate the Numpy API’s tendency to aggressively promote operands to `double`. When set to 1, the tests will use double-precision numbers. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ENABLE_X64&type=code) +`JAXCI_TPU_CORES` | Unset | Sets the number of TPU cores for the TPU machine type. Values are set in the workflow files. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_TPU_CORES&type=code) +`JAXCI_RUN_FULL_TPU_TEST_SUITE` | 0 | When set to 1, the full TPU test suite is run. Otherwise, a subset of tests is run. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_RUN_FULL_TPU_TEST_SUITE&type=code) +`JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI` | Unset | Used to control the installation of JAX [extras](https://github.com/jax-ml/jax/blob/7e42539653d33ec995487b683794c0bc86f7199b/setup.py#L64) from PyPI. See [ci/utilities/install_wheels_locally.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/install_wheels_locally.sh) for the list of valid values and their behavior. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI&type=code) + +## Docker Specific Environment Variables + +> [!NOTE] +> The following environment variables only affect the build if the +> [run_docker_container.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/run_docker_container.sh) +> script was invoked to start a Docker container and the build is running inside +> that container. Typically, this would be the internal CI builds and local +> builds. Note that while GitHub actions use the same Docker images, they do not +> invoke "run_docker_container.sh" as they leverage built-in containerization +> features to run jobs within a container. + +Name | Default Value | Behavior | Usage +----------------------- | ------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------- | ----- +`JAXCI_DOCKER_WORK_DIR` | "/jax" | The path on the container where the JAX Git repository is mounted to. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_WORK_DIR&type=code) +`JAXCI_DOCKER_ARGS` | Empty String | Space seprated string of additional arguments that will be passed when starting the Docker container | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_ARGS&type=code) +`JAXCI_DOCKER_IMAGE` | Depends on the system (see [ci/envs/docker.env](https://github.com/jax-ml/jax/blob/main/ci/envs/docker.env)) | Docker image to pull | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_IMAGE&type=code) diff --git a/ci/envs/default.env b/ci/envs/default.env index a5a5d56eb8b3..774464724646 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -13,9 +13,8 @@ # limitations under the License. # ============================================================================== # This file contains all the default values for the "JAXCI_" environment -# variables used in the CI scripts. 
These variables are used to control the -# behavior of the CI scripts such as the Python version used, path to JAX/XLA -# repo, if to clone XLA repo, etc. +# variables used in the CI scripts. See ci/envs/README.md for more details on +# the behavior of these variables and their usage in the CI scripts. # The path to the JAX git repository. export JAXCI_JAX_GIT_DIR=$(pwd) @@ -25,12 +24,10 @@ export JAXCI_JAX_GIT_DIR=$(pwd) export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')} # Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local -# copy of XLA instead of the pinned version in the WORKSPACE. When -# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically. +# copy of XLA instead of the pinned version in the WORKSPACE. export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-} -# If set to 1, the builds will clone the XLA repository at HEAD and set its -# path in JAXCI_XLA_GIT_DIR. +# If set to 1, the builds will clone the XLA repository at HEAD. export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} # Allows overriding the XLA commit that is used. @@ -39,49 +36,35 @@ export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} # Controls the location where the artifacts are written to. export JAXCI_OUTPUT_DIR="$(pwd)/dist" -# When enabled, artifacts will be built with RBE. Requires gcloud authentication -# and only certain platforms support RBE. Therefore, this flag is enabled only -# for CI builds where RBE is supported. +# Whether to use RBE to build the artifacts. export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} -# On platforms where RBE is not supported, we use Bazel remote cache to speed up -# builds. When this flag is enabled, Bazel will also try to push new cache -# entries to the bucket. Since writes to the bucket require authentication, this -# flag is enabled only for CI builds. +# Whether to write new cache entries to the remote cache bucket. export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0} -# Type of artifacts to build. Valid values are "default", "release", "nightly". -# This affects the wheel naming/tag. +# Controls the type of artifacts to build. Valid values are "default", "release", "nightly". export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"} -# When building release artifacts, we build a release candidate wheel ("rc" -# tagged wheel) in addition to the release wheel. This environment variable -# sets the version of the release candidate ("RC") artifact to build. +# Controls the version of the Release Candidate wheel to build during the +# release process. export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-} # ############################################################################# # Test script specific environment variables. # ############################################################################# -# Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override -# this value in the Github action workflow files. +# Whether to use double-precision numbers in the tests. export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} -# Pytest specific environment variables below. Used in run_pytest_*.sh scripts. -# Sets the number of TPU cores for the TPU machine type. These values are -# defined in the TPU GitHub Actions workflow. +# Sets the number of TPU cores for the TPU machine type. export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} -# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels -# on the system. 
By default, it is set to match the version of the hermetic -# Python used by Bazel for building the wheels. +# JAXCI_PYTHON points to the Python binary on the system that should be used +# for installing the JAX wheels on the system and running Pytest scripts. export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} # When set to 1, the full TPU test suite is run. Otherwise, a subset of tests # is run. export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0} -# We use this environment variable to control which additional wheels to install -# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest -# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the -# list of valid values and their behavior. +# Controls which additional extras to install from PyPI. export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""} \ No newline at end of file diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 82a76d33350c..a0f558520d45 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -41,5 +41,5 @@ fi # Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel # tests if [[ $os =~ "msys_nt" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows@sha256:6e2b299f12418d70ea522646b3dd618042a102f2ac2e4f8b1e423638549ea801" + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest" fi \ No newline at end of file diff --git a/ci/jax_ci_system.png b/ci/jax_ci_system.png new file mode 100644 index 000000000000..19efe62ae59e Binary files /dev/null and b/ci/jax_ci_system.png differ diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index 248111e0247a..d8cb190079e0 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -64,5 +64,7 @@ else --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - //tests:cpu_tests //tests:backend_independent_tests + //tests:cpu_tests //tests:backend_independent_tests \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test fi \ No newline at end of file diff --git a/ci/run_bazel_test_cuda_rbe.sh b/ci/run_bazel_test_cuda_rbe.sh index 17bd8d9db4f8..94c6a89fdb8c 100755 --- a/ci/run_bazel_test_cuda_rbe.sh +++ b/ci/run_bazel_test_cuda_rbe.sh @@ -48,4 +48,10 @@ bazel test --config=rbe_linux_x86_64_cuda \ --test_env=JAX_SKIP_SLOW_TESTS=true \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ - //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file + --@local_config_cuda//cuda:override_include_cuda_libs=true \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ + //jaxlib/tools:jax_cuda_plugin_wheel_size_test \ + //jaxlib/tools:jax_cuda_pjrt_wheel_size_test \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test \ No newline at end of file diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 43581ef2c96c..9de29691f753 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -26,13 +26,13 @@ set -exu -o history -o allexport # Source default JAXCI environment variables. source ci/envs/default.env +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + # Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. 
echo "Installing wheels locally..." source ./ci/utilities/install_wheels_locally.sh -# Set up the build environment. -source "ci/utilities/setup_build_environment.sh" - # Print all the installed packages echo "Installed packages:" "$JAXCI_PYTHON" -m uv pip list diff --git a/ci/utilities/README.md b/ci/utilities/README.md new file mode 100644 index 000000000000..35af5241767b --- /dev/null +++ b/ci/utilities/README.md @@ -0,0 +1,16 @@ +# JAX CI Utility Scripts + +This docpage gives a brief overview of the different utility scripts and what +they are used for. + +- **setup_build_environment.sh**: Sets up the build environment such as + cloning the latest XLA, adjusting file paths (for Windows), etc. +- **convert_msys_paths_to_win_paths.py**: Converts MSYS Linux-like paths + stored in env variables to Windows paths. +- **install_wheels_locally.sh**: Used by Pytest scripts to install JAX wheels + and any additional extras on the system. +- **run_auditwheel.sh**: Verifies that the Linux artifacts are "manylinux" + compliant. +- **run_docker_container.sh**: Runs a Docker container called "jax". Images + are read from the `JAXCI_DOCKER_IMAGE` environment variable in + [ci/envs/docker.env](https://github.com/jax-ml/jax/blob/main/ci/envs/docker.env). diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index f98f7658ad18..53f070d1e0e6 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -26,27 +26,34 @@ for i in "${!WHEELS[@]}"; do # Append [tpu] to the jax wheel name to download the latest libtpu wheel # from PyPI. WHEELS[$i]="${WHEELS[$i]}[tpu]" + elif [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "jax_cuda_pypi" ]]; then + # Append [cuda12-local] to the jax wheel name to download the latest + # release of JAX's CUDA plugin and PJRT packages from PyPI. This is used + # when running CUDA tests for a "jax" only release. + WHEELS[$i]="${WHEELS[$i]}[cuda12-local]" fi fi done -if [[ -z "${WHEELS[@]}" ]]; then - echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" - exit 1 -fi +if [[ -n "${WHEELS[@]}" ]]; then + echo "Installing the following wheels:" + echo "${WHEELS[@]}" -echo "Installing the following wheels:" -echo "${WHEELS[@]}" - -# Install `uv` if it's not already installed. `uv` is much faster than pip for -# installing Python packages. -if ! command -v uv >/dev/null 2>&1; then - pip install uv~=0.5.30 -fi + # Install `uv` if it's not already installed. `uv` is much faster than pip for + # installing Python packages. + if ! command -v uv >/dev/null 2>&1; then + pip install uv~=0.5.30 + fi -# On Windows, convert MSYS Linux-like paths to Windows paths. -if [[ $(uname -s) =~ "MSYS_NT" ]]; then - "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + # On Windows, convert MSYS Linux-like paths to Windows paths. + if [[ $(uname -s) =~ "MSYS_NT" ]]; then + "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + else + "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + fi else - "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + # Note that we don't exit here because the wheels may have been installed + # earlier in a different step in the CI job. + echo "INFO: No wheels found under $JAXCI_OUTPUT_DIR" + echo "INFO: Skipping local wheel installation." 
fi \ No newline at end of file diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh index 30b6a3b51865..b8f80c3e6778 100755 --- a/ci/utilities/run_auditwheel.sh +++ b/ci/utilities/run_auditwheel.sh @@ -26,6 +26,10 @@ if [[ -z "$WHEELS" ]]; then fi for wheel in $WHEELS; do + # Skip checking manylinux compliance for jax wheel. + if [[ "$wheel" =~ 'jax-' ]]; then + continue + fi printf "\nRunning auditwheel on the following wheel:" ls $wheel OUTPUT_FULL=$(python -m auditwheel show $wheel) diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index edaa71b93e85..5bc045d0f606 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -225,7 +225,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index d7ba5ed334f4..b69246c57e0b 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -315,7 +315,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index ea126ac4f1e7..8b16cd7694eb 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -59,7 +59,7 @@ "id": "2e_06-OAJNyi" }, "source": [ - "A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):" + "A basic starting point is expressing parallel maps with [`pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap):" ] }, { @@ -407,7 +407,7 @@ "source": [ "When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n", "\n", - "Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", + "Check [the JAX reference documentation](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", "\n", "Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:" ] diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index db3dc5f30814..6e5501584da0 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -4,7 +4,7 @@ The same JAX code that runs on CPU and GPU can also be run on TPU. Cloud TPUs have the advantage of quickly giving you access to multiple TPU accelerators, including in [Colab](https://research.google.com/colaboratory/). All of the example notebooks here use -[`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap) to run JAX +[`jax.pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap) to run JAX computation across multiple TPU cores from Colab. 
You can also run the same code directly on a [Cloud TPU VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). diff --git a/docs/README.md b/docs/README.md index 12e00425592f..54b8a67477b0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,2 +1,2 @@ To rebuild the documentation, -see [Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +see [Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/_static/pallas/gpu/grid_tiling_off.svg b/docs/_static/pallas/gpu/grid_tiling_off.svg new file mode 100644 index 000000000000..b11d85759ce4 --- /dev/null +++ b/docs/_static/pallas/gpu/grid_tiling_off.svg @@ -0,0 +1,175 @@ + + + + + A (6x16 tiles) + B (16x16 tiles) + C = A @ B (6x16 tiles) + + + + + + + + diff --git a/docs/_static/pallas/gpu/grid_tiling_on.svg b/docs/_static/pallas/gpu/grid_tiling_on.svg new file mode 100644 index 000000000000..9d24a8187179 --- /dev/null +++ b/docs/_static/pallas/gpu/grid_tiling_on.svg @@ -0,0 +1,183 @@ + + + + + A (6x16 tiles) + B (16x16 tiles) + C = A @ B (6x16 tiles) + + + + + + + + diff --git a/docs/_static/pallas/gpu/memory_spaces.svg b/docs/_static/pallas/gpu/memory_spaces.svg new file mode 100644 index 000000000000..73dc31a12406 --- /dev/null +++ b/docs/_static/pallas/gpu/memory_spaces.svg @@ -0,0 +1,96 @@ + + + + + + Faster / Smaller Capacity + + + Slower / Larger Capacity + + + + + + Registers (RMEM) + Fastest Latency & BW + Smallest Capacity + + Holds arrays (in Pallas). + Spills if full! + + + + + Tensor Memory (TMEM) + Fastest Latency & BW + Smallest Capacity + + Explicitly managed. + Blackwell specific. + + + + + + Shared Memory (SMEM) + Fast (close to compute) + Small Capacity (per SM) + Partitioned into private slices for each CUDA block/cluster. + + + + L2 Cache + Moderate Speed + Moderate Capacity (~100MBs) + Shared betwen SMs, not directly programmable. + + + + Global Memory (GMEM) + Slowest Latency & Bandwidth + Largest Capacity (GBs) + Main GPU memory (HBM/GDDR technology). 
+ + + + + diff --git a/docs/_static/pallas/gpu/nvidia_sm.svg b/docs/_static/pallas/gpu/nvidia_sm.svg new file mode 100644 index 000000000000..76b4edb2afad --- /dev/null +++ b/docs/_static/pallas/gpu/nvidia_sm.svg @@ -0,0 +1,99 @@ + + + + + Streaming Multiprocessor + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + Shared Memory / L1 Cache + + + diff --git a/docs/_static/pallas/pipelining_bandwidth_bound.svg b/docs/_static/pallas/pipelining_bandwidth_bound.svg new file mode 100644 index 000000000000..45b78a7ce35e --- /dev/null +++ b/docs/_static/pallas/pipelining_bandwidth_bound.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_compute_bound.svg b/docs/_static/pallas/pipelining_compute_bound.svg new file mode 100644 index 000000000000..cb3b58eaef99 --- /dev/null +++ b/docs/_static/pallas/pipelining_compute_bound.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_example.svg b/docs/_static/pallas/pipelining_example.svg new file mode 100644 index 000000000000..59ca5b433b11 --- /dev/null +++ b/docs/_static/pallas/pipelining_example.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_latency_multistage.svg b/docs/_static/pallas/pipelining_latency_multistage.svg new file mode 100644 index 000000000000..2c40f1692b9a --- /dev/null +++ b/docs/_static/pallas/pipelining_latency_multistage.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_mem_hierarchy.svg b/docs/_static/pallas/pipelining_mem_hierarchy.svg new file mode 100644 index 000000000000..d7a2e6cbabd8 --- /dev/null +++ b/docs/_static/pallas/pipelining_mem_hierarchy.svg @@ -0,0 +1,30 @@ + + + + + + + + + + + + Registers + SRAM/Caches + DRAM/HBM + Network + + Fastest + Fast + Slow + Slowest + + Lowest Capacity + Low Capacity + High Capacity + Highest Capacity + + diff --git a/docs/about.md b/docs/about.md index 58e1703842b9..baeed941c8c3 100644 --- a/docs/about.md +++ b/docs/about.md @@ -19,7 +19,7 @@ technology stack](#components). First, we design the `jax` module to be [composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations) and -[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so +[extensible](https://docs.jax.dev/en/latest/jax.extend.html), so that a wide variety of domain-specific libraries can thrive outside of it in a decentralized manner. Second, we lean heavily on a modular backend stack (compiler and runtime) to target different @@ -42,10 +42,10 @@ scale. JAX's day-to-day development takes place in the open on GitHub, using pull requests, the issue tracker, discussions, and [JAX Enhancement Proposals -(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading +(JEPs)](https://docs.jax.dev/en/latest/jep/index.html). Reading and participating in these is a good way to get involved. We also maintain [developer -notes](https://jax.readthedocs.io/en/latest/contributor_guide.html) +notes](https://docs.jax.dev/en/latest/contributor_guide.html) that cover JAX's internal design. 
The JAX core team determines whether to accept changes and @@ -56,7 +56,7 @@ intricate decision structure over time (e.g. with designated area owners) if/when it becomes useful to do so. For more see [contributing to -JAX](https://jax.readthedocs.io/en/latest/contributing.html). +JAX](https://docs.jax.dev/en/latest/contributing.html). (components)= ## A modular stack @@ -71,7 +71,7 @@ and (b) an advancing hardware landscape, we lean heavily on While the JAX core library focuses on the fundamentals, we want to encourage domain-specific libraries and tools to be built on top of JAX. Indeed, [many -libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have +libraries](https://docs.jax.dev/en/latest/#ecosystem) have emerged around JAX to offer higher-level features and extensions. How do we encourage such decentralized development? We guide it with @@ -80,11 +80,11 @@ building blocks (e.g. numerical primitives, NumPy operations, arrays, and transformations), encouraging auxiliary libraries to develop utilities as needed for their domain. In addition, JAX exposes a handful of more advanced APIs for -[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +[customization](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) and -[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries +[extensibility](https://docs.jax.dev/en/latest/jax.extend.html). Libraries can [lean on these -APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in +APIs](https://docs.jax.dev/en/latest/building_on_jax.html) in order to use JAX as an internal means of implementation, to integrate more with its transformations like autodiff, and more. diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md index eaa3bc7317c8..bef2fd088a3a 100644 --- a/docs/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -876,7 +876,7 @@ There are two ways to define differentiation rules in JAX: 1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). ### TL;DR: Custom JVPs with {func}`jax.custom_jvp` @@ -1608,7 +1608,7 @@ Array(-0.91113025, dtype=float32) #### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. 
Here's a contrived example with {func}`jax.custom_jvp`: diff --git a/docs/aot.md b/docs/aot.md index 1fcf11ab945d..8f68c2758148 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -26,7 +26,7 @@ are arrays, JAX does the following in order: carries out this specialization by a process that we call _tracing_. During tracing, JAX stages the specialization of `F` to a jaxpr, which is a function in the [Jaxpr intermediate - language](https://jax.readthedocs.io/en/latest/jaxpr.html). + language](https://docs.jax.dev/en/latest/jaxpr.html). 2. **Lower** this specialized, staged-out computation to the XLA compiler's input language, StableHLO. diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 749c5907bc6b..9dca1fc08f50 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -91,7 +91,7 @@ guarantees of the main JAX package. If you have code that uses `jax.extend`, we would strongly recommend CI tests against JAX's nightly releases, so as to catch potential changes before they are released. -For details on `jax.extend`, see the [`jax.extend` module docuementation](https://jax.readthedocs.io/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. +For details on `jax.extend`, see the [`jax.extend` module docuementation](https://docs.jax.dev/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. ## Numerics and randomness diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 7ec91affa05d..b6f12b624f8b 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -72,7 +72,7 @@ "outputs, we want to override primitive application and let different values\n", "flow through our program. For example, we might want to replace the\n", "application of every primitive with an application of [its JVP\n", - "rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n", + "rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html),\n", "and let primal-tangent pairs flow through our program. Moreover, we want to be\n", "able to compose multiple transformations, leading to stacks of interpreters." ] @@ -3620,7 +3620,7 @@ "source": [ "Notice that we're not currently supporting the case where the predicate value\n", "itself is batched. In mainline JAX, we handle this case by transforming the\n", - "conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).\n", + "conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html).\n", "That transformation is semantically correct so long as `true_fun` and\n", "`false_fun` do not involve any side-effecting primitives.\n", "\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 2d4d6cd528af..1c375e21227c 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -72,7 +72,7 @@ where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of [its JVP -rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters. 
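
(Editorial aside: as a concrete illustration of the primal-tangent pairs described above, here is what the behavior looks like in mainline JAX via `jax.jvp`; autodidax goes on to build this machinery from scratch rather than calling the library function.)

```python
# Mainline-JAX illustration of primal/tangent pairs; autodidax builds this
# machinery from scratch rather than calling jax.jvp.
import jax
import jax.numpy as jnp

def f(x):
    return x * jnp.sin(x)

primal_out, tangent_out = jax.jvp(f, (1.0,), (1.0,))
print(primal_out, tangent_out)  # f(1.0) and its directional derivative at 1.0
```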
@@ -2843,7 +2843,7 @@ print(out) Notice that we're not currently supporting the case where the predicate value itself is batched. In mainline JAX, we handle this case by transforming the -conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). That transformation is semantically correct so long as `true_fun` and `false_fun` do not involve any side-effecting primitives. diff --git a/docs/autodidax.py b/docs/autodidax.py index f8c6372fe30d..6329234224cb 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -62,7 +62,7 @@ # outputs, we want to override primitive application and let different values # flow through our program. For example, we might want to replace the # application of every primitive with an application of [its JVP -# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +# rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), # and let primal-tangent pairs flow through our program. Moreover, we want to be # able to compose multiple transformations, leading to stacks of interpreters. @@ -2837,7 +2837,7 @@ def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr): # Notice that we're not currently supporting the case where the predicate value # itself is batched. In mainline JAX, we handle this case by transforming the -# conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +# conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). # That transformation is semantically correct so long as `true_fun` and # `false_fun` do not involve any side-effecting primitives. # diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index 9416b16cde10..6d13f517f50b 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -45,8 +45,8 @@ Here are more specific examples of each pattern. ### Direct usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, -for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) -or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). +for example in [JAX Tutorials](https://docs.jax.dev/en/latest/tutorials.html) +or [Neural Network with JAX](https://docs.jax.dev/en/latest/notebooks/neural_network_with_tfds_data.html). This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number of dependencies in your codebase. diff --git a/docs/conf.py b/docs/conf.py index 45964b6d8d7e..87fef6337f29 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -80,6 +80,7 @@ def _do_not_evaluate_in_jax( "sphinx_remove_toctrees", 'sphinx_copybutton', 'jax_extensions', + 'jax_list_config_options', 'sphinx_design', 'sphinxext.rediraffe', ] @@ -132,6 +133,7 @@ def _do_not_evaluate_in_jax( # These are kept in sync using the jupytext pre-commit hook. 
'notebooks/*.md', 'pallas/quickstart.md', + 'pallas/pipelining.md', 'pallas/tpu/pipelining.md', 'pallas/tpu/distributed.md', 'pallas/tpu/sparse.md', @@ -222,17 +224,16 @@ def _do_not_evaluate_in_jax( 'jep/9407-type-promotion.*', # TODO(jakevdp): enable execution on the following if possible: 'notebooks/Distributed_arrays_and_automatic_parallelization.*', - 'notebooks/explicit-sharding.*', 'notebooks/autodiff_remat.*', # Fails on readthedocs with Kernel Died 'notebooks/convolutions.ipynb', # Requires accelerators 'pallas/quickstart.*', + 'pallas/pipelining.*', 'pallas/tpu/pipelining.*', 'pallas/tpu/distributed.*', 'pallas/tpu/sparse.*', 'pallas/tpu/matmul.*', - 'sharded-computation.*', 'distributed_data_loading.*' ] diff --git a/docs/config_options.rst b/docs/config_options.rst new file mode 100644 index 000000000000..a8ef4e93a834 --- /dev/null +++ b/docs/config_options.rst @@ -0,0 +1,66 @@ +.. _jax: + +.. This target is required to prevent the Sphinx build error "Unknown target name: jax". +.. The custom directive list_config_options imports JAX to extract real configuration +.. data, which causes Sphinx to look for a target named "jax". This dummy target +.. satisfies that requirement while allowing the actual JAX import to work. + +Configuration Options +===================== + +JAX provides various configuration options to customize its behavior. These options control everything from numerical precision to debugging features. + +How to Use Configuration Options +-------------------------------- + +JAX configuration options can be set in several ways: + +1. **Environment variables** (set before running your program): + + .. code-block:: bash + + export JAX_ENABLE_X64=True + python my_program.py + +2. **Runtime configuration** (in your Python code): + + .. code-block:: python + + import jax + jax.config.update("jax_enable_x64", True) + +3. **Command-line flags** (using Abseil): + + .. code-block:: python + + # In your code: + import jax + jax.config.parse_flags_with_absl() + + .. code-block:: bash + + # When running: + python my_program.py --jax_enable_x64=True + +Common Configuration Options +---------------------------- + +Here are some of the most frequently used configuration options: + +- ``jax_enable_x64`` -- Enable 64-bit floating-point precision +- ``jax_disable_jit`` -- Disable JIT compilation for debugging +- ``jax_debug_nans`` -- Check for and raise errors on NaNs +- ``jax_platforms`` -- Control which backends (CPU/GPU/TPU) JAX will initialize +- ``jax_numpy_rank_promotion`` -- Control automatic rank promotion behavior +- ``jax_default_matmul_precision`` -- Set default precision for matrix multiplication operations + +.. raw:: html + +
+ +All Configuration Options +------------------------- + +Below is a complete list of all available JAX configuration options: + +.. list_config_options:: diff --git a/docs/contributing.md b/docs/contributing.md index 99d78453c436..53a863fdcd8c 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,7 +6,7 @@ Everyone can contribute to JAX, and we value everyone's contributions. There are ways to contribute, including: - Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) -- Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) +- Improving or expanding JAX's [documentation](http://docs.jax.dev/) - Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) - Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) diff --git a/docs/control-flow.md b/docs/control-flow.md index 7cb959f3e434..8f59bd92add7 100644 --- a/docs/control-flow.md +++ b/docs/control-flow.md @@ -244,19 +244,19 @@ lax.cond(False, lambda x: x+1, lambda x: x-1, operand) `jax.lax` provides two other functions that allow branching on dynamic predicates: -- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is +- [`lax.select`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html) is like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays rather than as functions. -- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is +- [`lax.switch`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.switch.html) is like `lax.cond`, but allows switching between any number of callable choices. In addition, `jax.numpy` provides several numpy-style interfaces to these functions: -- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with +- [`jnp.where`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.where.html) with three arguments is the numpy-style wrapper of `lax.select`. -- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) +- [`jnp.piecewise`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.piecewise.html) is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. -- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has +- [`jnp.select`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.select.html) has an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to `lax.select`. diff --git a/docs/default_dtypes.md b/docs/default_dtypes.md new file mode 100644 index 000000000000..629f7fb5c314 --- /dev/null +++ b/docs/default_dtypes.md @@ -0,0 +1,82 @@ +(default-dtypes)= +# Default dtypes and the X64 flag +JAX strives to meet the needs of a range of numerical computing practitioners, who +sometimes have conflicting preferences. When it comes to default dtypes, there are +two different camps: + +- Classic scientific computing practitioners (i.e. users of tools like {mod}`numpy` or + {mod}`scipy`) tend to value accuracy of computations foremost: such users would + prefer that computations default to the **widest available representation**: e.g. + floating point values should default to `float64`, integers to `int64`, etc. +- AI researchers (i.e. 
folks implementing and training neural networks) tend to value + speed over accuracy, to the point where they have developed special data types like + [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) and others + which deliberately discard the least significant bits in order to speed up computation. + For these users, the mere presence of a float64 value in their computation can lead + to programs that are slow at best, and incompatible with their hardware at worst! + These users would prefer that computations default to `float32` or `int32`. + +The main mechanism JAX offers for this is the `jax_enable_x64` flag, which controls +whether 64-bit values can be created at all. By default this flag is set to `False` +(serving the needs of AI researchers and practitioners), but can be set to `True` +by users who value accuracy over computational speed. + +## Default setting: 32-bits everywhere +By default `jax_enable_x64` is set to False, and so {mod}`jax.numpy` array creation +functions will default to returning 32-bit values. + +For example: +```python +>>> import jax.numpy as jnp + +>>> jnp.arange(5) +Array([0, 1, 2, 3, 4], dtype=int32) + +>>> jnp.zeros(5) +Array([0., 0., 0., 0., 0.], dtype=float32) + +>>> jnp.ones(5, dtype=int) +Array([1, 1, 1, 1, 1], dtype=int32) + +``` + +Beyond defaults, because 64-bit values can be so poisonous to AI workflows, having +this flag set to False prevents you from creating 64-bit arrays at all! For example: +``` +>>> jnp.arange(5, dtype='float64') # doctest: +SKIP +UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be +truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the +JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. +Array([0., 1., 2., 3., 4.], dtype=float32) +``` + +## The X64 flag: enabling 64-bit values +To work in the "other mode" where functions default to producing 64-bit values, you can set the +`jax_enable_x64` flag to `True`: +```python +import jax +import jax.numpy as jnp + +jax.config.update('jax_enable_x64', True) + +print(repr(jnp.arange(5))) +print(repr(jnp.zeros(5))) +print(repr(jnp.ones(5, dtype=int))) +``` +``` +Array([0, 1, 2, 3, 4], dtype=int64) +Array([0., 0., 0., 0., 0.], dtype=float64) +Array([1, 1, 1, 1, 1], dtype=int64) +``` + +The X64 configuration can also be set via the `JAX_ENABLE_X64` shell environment variable, +for example: +```bash +$ JAX_ENABLE_X64=1 python main.py +``` +The X64 flag is intended as a **global setting** that should have one value for your whole +program, set at the top of your main file. A common feature request is for the flag to +be contextually configurable (e.g. enabling X64 just for one section of a long program): +this turns out to be difficult to implement within JAX's programming model, where code +execution may happen in a different context than code compilation. There is ongoing work +exploring the feasibility of relaxing this constraint, so stay tuned! diff --git a/docs/developer.md b/docs/developer.md index 0affbba9ed36..9edeaeac83f8 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -1,7 +1,7 @@ (building-from-source)= # Building from source - + First, obtain the JAX source code: @@ -526,23 +526,27 @@ bazel test //tests:cpu_tests //tests:backend_independent_tests `//tests:gpu_tests` and `//tests:tpu_tests` are also available, if you have the necessary hardware. 
-To use a preinstalled `jaxlib` instead of building it you first need to -make it available in the hermetic Python. To install a specific version of -`jaxlib` within hermetic Python run (using `jaxlib >= 0.4.26` as an example): +To use the preinstalled `jax` and `jaxlib` instead of building them you first +need to make them available in the hermetic Python. To install the specific +versions of `jax` and `jaxlib` within hermetic Python run (using `jax >= 0.4.26` +and `jaxlib >= 0.4.26` as an example): ``` +echo -e "\njax >= 0.4.26" >> build/requirements.in echo -e "\njaxlib >= 0.4.26" >> build/requirements.in python build/build.py requirements_update ``` -Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): +Alternatively, to install `jax` and `jaxlib` from the local wheels +(assuming Python 3.12): ``` +echo -e "\n$(realpath jax-0.4.26-py3-none-any.whl)" >> build/requirements.in echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in python build/build.py requirements_update --python_version=3.12 ``` -Once you have `jaxlib` installed hermetically, run: +Once you have `jax` and `jaxlib` installed hermetically, run: ``` bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests @@ -785,7 +789,7 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked #### Notebooks within the Sphinx build Some of the notebooks are built automatically as part of the pre-submit checks and -as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. +as part of the [Read the docs](https://docs.jax.dev/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else @@ -796,7 +800,7 @@ See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs ### Documentation building on `readthedocs.io` -JAX's auto-generated documentation is at . +JAX's auto-generated documentation is at . The documentation building is controlled for the entire project by the [readthedocs JAX settings](https://readthedocs.org/dashboard/jax). The current settings @@ -809,7 +813,7 @@ For each automated documentation build you can see the If you want to test the documentation generation on Readthedocs, you can push code to the `test-docs` branch. That branch is also built automatically, and you can -see the generated documentation [here](https://jax.readthedocs.io/en/test-docs/). If the documentation build +see the generated documentation [here](https://docs.jax.dev/en/test-docs/). If the documentation build fails you may want to [wipe the build environment for test-docs](https://docs.readthedocs.io/en/stable/guides/wipe-environment.html). For a local test, I was able to do it in a fresh directory by replaying the commands diff --git a/docs/export/export.md b/docs/export/export.md index 18cdcc6c51d0..63c0db14f905 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -161,7 +161,7 @@ e.g., the inference system.) What **matters is when the exporting and consuming components were built**, not the time when the exporting and the compilation happen. 
For external JAX users, it is -[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); +[possible to run JAX and jaxlib at different versions](https://docs.jax.dev/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); what matters is when the jaxlib release was built. To reduce chances of incompatibility, internal JAX users should: diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 9254030a4e1c..6b63a536ab48 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -86,7 +86,7 @@ matching the structure of the arguments passed to it. The polymorphic shapes specification can be a pytree prefix in cases where one specification should apply to multiple arguments, as in the above example. -See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). +See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A few examples of shape specifications: @@ -609,7 +609,7 @@ Division had remainder 1 when computing the value of 'd'. Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), . -Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details. +Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details. ``` diff --git a/docs/faq.rst b/docs/faq.rst index 44267f6f5f7d..f5d43d25afb6 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -4,7 +4,7 @@ Frequently asked questions (FAQ) .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html .. comment Some links referenced here. Use `JAX - The Sharp Bits`_ (underscore at the end) to reference -.. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html +.. _JAX - The Sharp Bits: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html We are collecting answers to frequently asked questions here. Contributions welcome! @@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of Python control flow such as ``for`` loops. For a handful of loop iterations, Python is OK, but if you need *many* loop iterations, you should rewrite your code to make use of JAX's -`structured control flow primitives `_ +`structured control flow primitives `_ (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can still use ``jit`` decorated functions *inside* the loop). @@ -454,8 +454,8 @@ performing matrix-matrix multiplication) to amortize the increased overhead of JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs). -.. _To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit -.. _Double (64 bit) precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision +.. _To JIT or not to JIT: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit +.. _Double (64 bit) precision: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision .. 
_`%time and %timeit magics`: https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-time .. _Colab: https://colab.research.google.com/ @@ -841,12 +841,12 @@ reducing :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` from the default of :code:`.75`, or setting :code:`XLA_PYTHON_CLIENT_PREALLOCATE=false`. For more details, please see the page on `JAX GPU memory allocation`_. -.. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables -.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html -.. _Pure callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp -.. _IO callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback +.. _JIT mechanics: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables +.. _External callbacks in JAX: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html +.. _Pure callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp +.. _IO callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback .. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function .. _algebraic_simplifier.cc: https://github.com/openxla/xla/blob/33f815e190982dac4f20d1f35adb98497a382377/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4851 -.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html +.. _JAX GPU memory allocation: https://docs.jax.dev/en/latest/gpu_memory_allocation.html .. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index b622fba9d5bc..f74ae9d58a78 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -439,7 +439,7 @@ "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", "Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", "\n", - "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", + "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", "In this case, we actually define two new FFI calls:\n", "\n", "1. 
`rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", @@ -785,7 +785,7 @@ "{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.\n", "We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:\n", "\n", - "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", + "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", "2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there.\n", "\n", "All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`:" diff --git a/docs/ffi.md b/docs/ffi.md index 4aa03c217855..97648c78e118 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -353,7 +353,7 @@ Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default supp As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule. -More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. +More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. @@ -591,7 +591,7 @@ If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative a {func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges. We won't go into too much detail on the caveats here, but the main issues that you should be aware of are: -1. 
`custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. +1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. 2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there. All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`: diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 6667589e7b72..be40dfc8004c 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -69,7 +69,7 @@ Common causes of OOM failures disabling the automatic remat pass produces different trade-offs between compute and memory. Note however, that the algorithm is basic and you can often get better trade-off between compute and memory by disabling the automatic remat pass and doing - it manually with `the jax.remat API `_ + it manually with `the jax.remat API `_ Experimental features diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index bf032dccff88..bade464d22a1 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -1,6 +1,6 @@ # GPU performance tips - + This document focuses on performance tips for neural network workloads @@ -58,7 +58,173 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta * **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False. -### Communication flags +## Communication tips + +### Auto and manual PGLE + +The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time +of compute and collectives, and the profile information is fed back into the XLA compiler +for better scheduling decisions. + +The Profile Guided Latency Estimator can be used manually or automatically. In the auto mode +JAX will collect profile information and recompile a module in a single run, while +in manual mode you need to run a task twice: the first time to collect and save profiles, +and the second to compile and run with the provided data. + +**Important**: the JAX profiler, which is used by both of the PGLE workflows documented +below, cannot co-exist with the NVIDIA Nsight Systems profiler. This limitation can be +avoided by using the JAX compilation cache, as described below. + +### Auto PGLE +The auto PGLE can be turned on by setting the following environment variables: + +Mandatory: +```bash +export JAX_ENABLE_PGLE=true + +# For JAX version <= 0.5.0 make sure to include: +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true" +``` + +Optional: +```bash +export JAX_PGLE_PROFILING_RUNS=3 +export JAX_PGLE_AGGREGATION_PERCENTILE=85 + +# Right now the auto PGLE profile collection doesn't work with command buffer.
+# If the command buffer is enabled, Auto PGLE will disable it during profile +# collection and enable it back after the recompilation. If you need consistent +# command buffer behavior with and without a PGLE profile, you can disable it +# manually: +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''" +``` + +Or this can be set from within JAX as follows: + +``` +import jax +from jax._src import config + +with config.enable_pgle(True), config.pgle_profiling_runs(1): + # Run with the profiler collecting performance information. + train_step() + # Automatically re-compile with PGLE profile results + train_step() + ... +``` + +You can control the number of reruns used to collect profile data by changing `JAX_PGLE_PROFILING_RUNS`. +Increasing this parameter would lead to better profile information, but it will also increase the +number of non-optimized training steps. + +Decreasing the `JAX_PGLE_AGGREGATION_PERCENTILE` parameter might help when performance between steps is too noisy, by filtering out non-relevant measurements. + +**Attention:** Auto PGLE doesn't work for pre-compiled modules. Since JAX needs to recompile the module during execution, Auto PGLE works neither for AoT compilation nor for the following case: + +``` +import jax +from jax._src import config + +train_step_compiled = train_step().lower().compile() + +with config.enable_pgle(True), config.pgle_profiling_runs(1): + train_step_compiled() + # No effect since module was pre-compiled. + train_step_compiled() +``` + +#### Collecting NVIDIA Nsight Systems profiles when using AutoPGLE +[jax#24910](https://github.com/jax-ml/jax/pull/24910) (JAX v0.5.1 and newer) added a +new JAX configuration option, `JAX_COMPILATION_CACHE_EXPECT_PGLE`, which tells JAX to +attempt to load PGLE-optimized compiled functions from the persistent compilation +cache. + +This allows a two-step process, where the first step writes a PGLE-optimized function +to the cache: +```bash +export JAX_ENABLE_COMPILATION_CACHE=yes # not strictly needed, on by default +export JAX_COMPILATION_CACHE_DIR=/root/jax_cache +JAX_ENABLE_PGLE=yes python my-model.py +``` +And the second step uses Nsight Systems and loads the PGLE-optimized function from the +cache: +```bash +JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py +``` +See also [this page]( +https://docs.jax.dev/en/latest/persistent_compilation_cache.html#pitfalls) for more +information about the persistent compilation cache and possible pitfalls. + +### Manual PGLE + +If you still want to use the manual Profile Guided Latency Estimator, the workflow in XLA/GPU is: + +- 1. Run your workload once, with async collectives and latency hiding scheduler enabled. + +You could do so by setting: + +```bash +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true" +``` + +- 2. Collect and post-process a profile using the JAX profiler, saving the extracted instruction latencies into a binary protobuf file.
+```python +import logging +import os +from etils import epath +import jax +from jax.experimental import profiler as exp_profiler + +# Define your profile directory +profile_dir = 'gs://my_bucket/profile' +jax.profiler.start_trace(profile_dir) + +# run your workflow +# for i in range(10): +# train_step() + +# Stop trace +jax.profiler.stop_trace() +profile_dir = epath.Path(profile_dir) +directories = profile_dir.glob('plugins/profile/*/') +directories = [d for d in directories if d.is_dir()] +rundir = directories[-1] +logging.info('rundir: %s', rundir) + +# Post process the profile +fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir)) + +# Save the profile proto to a file. +dump_dir = rundir / 'profile.pb' +dump_dir.parent.mkdir(parents=True, exist_ok=True) +dump_dir.write_bytes(fdo_profile) + +``` + +After this step, you will get a `profile.pb` file under the `rundir` printed in the code. + +- 3. Run the workload again, feeding that file into the compilation. + +You need to pass the `profile.pb` file to the `--xla_gpu_pgle_profile_file_or_directory_path` flag. + +```bash + export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb" +``` + +To enable logging in XLA and check whether the profile is good, set the logging level to include `INFO`: + +```bash +export TF_CPP_MIN_LOG_LEVEL=0 +``` + +Run the real workflow; if you find these lines in the running log, it means the profile is being used by the latency hiding scheduler: + +``` +2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb +2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator +``` + +#### Flags * **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. @@ -77,20 +243,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta By adjusting this factor, users can fine-tune the trade-off between memory efficiency and performance optimizations. -* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism, - this flag enables overlapping the (i+1)-th layer weight `AllGather` with the - i-th layer computation. It also enables overlapping (i+1)-th layer - weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default - value is False. **There are some bugs when this flag is turned on.** -* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when - performing [GSPMD pipelining](https://arxiv.org/abs/2105.04663). Setting a - nonzero threshold decomposes `CollectivePermute`s into - `CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that - computation can be performed between each corresponding - `ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the - threshold is 0 and there is no decomposition. Setting it to threshold > 0 such - as `--xla_gpu_collective_permute_decomposer_threshold=1024` can enable this - feature.
* **--xla_gpu_all_gather_combine_threshold_bytes** **--xla_gpu_reduce_scatter_combine_threshold_bytes** **--xla_gpu_all_reduce_combine_threshold_bytes** @@ -102,6 +254,227 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By default, the `combine_threshold_bytes` is set to 256. +### Pipeline Parallelism on GPU + +XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling technique +where the forward and backward pass are split into multiple pipeline stages. +Each device (or device group) processes the result of the previous +pipeline stage (or the pipeline input) and sends its partial result to the next +stage until the end of the pipeline is reached. This optimization works best +when the latency of the computation is larger than communication. At compile +time, the operations will be rearranged to overlap communication with +computation. + +For an optimized schedule, we recommend these XLA flags: +``` +--xla_gpu_enable_latency_hiding_scheduler=true +--xla_gpu_enable_command_buffer='' +--xla_disable_hlo_passes=collective-permute-motion +--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE +``` + +The following JAX example demonstrates a pattern where communication operations +are scheduled to overlap with computations. In this example we will illustrate +how to set up an optimized pipeline parallelism scheduling using 4 GPUs that +form a communication ring (device 0 -> device 1 -> device 2 -> device 3 -> +device 0). We refer to the pattern `0 -> 1 -> 2 -> 3` as the forward edge, and +`3 -> 0` as the back edge. + +``` +# Imports and setup +import functools +import jax +from jax import sharding +from jax.experimental import mesh_utils +import jax.numpy as jnp +import jax.random + +NUM_DEVICES = 4 +NUM_MICROBATCHES = 5 +NUM_CIRC_REPEATS = 2 +CONTRACTING_DIM_SIZE = 4096 +NON_CONTRACTING_DIM_SIZE = 8192 +COMPUTE_INTENSITY = 32 + +# Creates a collective permute for the "forward edge". +# 0->1, 1->2, ... (N-2)->(N-1) +def shift_right(arr): + padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) + # Use lax.slice to guarantee the gradient is a pad. + return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) + + +# Creates a collective permute for the "back edge". +# (N-1)->0 +def cycle_back(arr): + padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1) + return jax.lax.slice( + jnp.pad(arr, padding), + [NUM_DEVICES - 1] + [0] * (arr.ndim - 1), + (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:], + ) + + +def select_on_first_device(then_value, else_value): + assert then_value.shape == else_value.shape + is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0 + return jnp.where(is_first_device, then_value, else_value) + + +def select_on_last_device(then_value, else_value): + assert then_value.shape == else_value.shape + is_last_device = ( + jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1 + ) + return jnp.where(is_last_device, then_value, else_value) + + +def select_on_first_cycle(i, then_value, else_value): + assert then_value.shape == else_value.shape + is_first_cycle = i < NUM_MICROBATCHES + return jnp.where(is_first_cycle, then_value, else_value) + + +def while_body(carry, i): + """Body of the pipeline while loop.""" + weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry + + # Read input data from input buffer. 
+ input_data = jax.lax.dynamic_slice( + input_buffer, + (0, (i + 0) % NUM_MICROBATCHES, 0, 0), + (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE), + ) + + # Collective permute on the "forward edge" shifts data to the next stage. + fwd_edge_data = shift_right(fwd_edge_data) + + # Select compute argument based on device and pipeline cycle. + compute_argument = select_on_first_device( + select_on_first_cycle(i, input_data, bwd_edge_data), + fwd_edge_data, + ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)) + + # A few matmuls to simulate compute. + tmp = compute_argument + for _ in range(COMPUTE_INTENSITY): + tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,)))) + compute_result = tmp.reshape( + (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE) + ) + + # Read data from buffer to pass it to the first device of the pipeline on the + # "back edge". + bwd_edge_data = jax.lax.dynamic_slice( + output_buffer, + (0, (1 + i) % NUM_MICROBATCHES, 0, 0), + (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE), + ) + + # Collective permute on the "back edge" passes data to the first device. + bwd_edge_data = cycle_back(bwd_edge_data) + + # Update output buffer. We do this after reading from it to avoid the data + # dependency. + output_buffer = jax.lax.dynamic_update_slice( + output_buffer, + compute_result, + (0, (2 + i) % NUM_MICROBATCHES, 0, 0), + ) + + fwd_edge_data = compute_result + carry = ( + weights, + input_buffer, + output_buffer, + fwd_edge_data, + bwd_edge_data, + ) + return carry, i + + +@functools.partial(jax.jit, static_argnames=["mesh"]) +def entry_computation(weights, input_buffer, mesh): + + # Init output buffer. + output_buffer = jnp.zeros_like(input_buffer) + + # Init dummy data for forward and backward edge passed through the while loop. + dummy_data = jnp.zeros( + shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE) + ).astype(jnp.float32) + dummy_data = jax.device_put( + dummy_data, + sharding.NamedSharding( + mesh, sharding.PartitionSpec("the_one_and_only_axis") + ), + ) + + # Start pipeline. + carry = weights, input_buffer, output_buffer, dummy_data, dummy_data + num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1 + carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations)) + _, _, output_buffer, _, _ = carry + + return output_buffer + + +def main(_): + + # Expect constant number of devices. + assert NUM_DEVICES == jax.local_device_count() + + # Create mesh. + mesh = sharding.Mesh( + mesh_utils.create_device_mesh([NUM_DEVICES]), + axis_names=["the_one_and_only_axis"], + ) + + # Init weights. + weights = 1.0 / CONTRACTING_DIM_SIZE + weights = jax.lax.broadcast_in_dim( + weights, + shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE), + broadcast_dimensions=(), + ) + weights = jax.device_put( + weights, + sharding.NamedSharding( + mesh, sharding.PartitionSpec("the_one_and_only_axis") + ), + ) + + # Init random input and replicate it across all devices.
+ random_key = jax.random.key(0) + input_buffer = jax.random.uniform( + random_key, + shape=( + NUM_MICROBATCHES, + CONTRACTING_DIM_SIZE, + NON_CONTRACTING_DIM_SIZE, + ), + ) + input_buffer = jax.lax.broadcast_in_dim( + input_buffer, + shape=( + NUM_DEVICES, + NUM_MICROBATCHES, + CONTRACTING_DIM_SIZE, + NON_CONTRACTING_DIM_SIZE, + ), + broadcast_dimensions=[1, 2, 3], + ) + input_buffer = jax.device_put( + input_buffer, + sharding.NamedSharding( + mesh, sharding.PartitionSpec("the_one_and_only_axis") + ), + ) + + # Run computation. + output_buffer = entry_computation(weights, input_buffer, mesh) + print(f"output_buffer = \n{output_buffer}") +``` ## NCCL flags These Nvidia NCCL flag values may be useful for single-host multi-device diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 0938a5da944f..e4e842df49f0 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -341,7 +341,7 @@ def predict(params, x): return x ``` -By itself, {func}`jax.ad_checkpoint import.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint import.checkpoint_name` are considered saveable: +By itself, {func}`jax.ad_checkpoint.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint.checkpoint_name` are considered saveable: ```{code-cell} print_saved_residuals(loss, params, x, y) diff --git a/docs/index.rst b/docs/index.rst index ba8ebcbdd128..aafe38da9f4c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -165,6 +165,10 @@ maintains an up-to-date list. changelog glossary +.. toctree:: + :maxdepth: 2 + + config_options .. _Awesome JAX: https://github.com/n2cholas/awesome-jax .. _AXLearn: https://github.com/apple/axlearn diff --git a/docs/installation.md b/docs/installation.md index ee675dd1e586..500347e04ab1 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -158,7 +158,7 @@ pip install --upgrade pip # Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer. # Note: wheels only available on linux. -pip install --upgrade "jax[cuda12_local]" +pip install --upgrade "jax[cuda12-local]" ``` **These `pip` installations do not work with Windows, and may fail silently; refer to the table @@ -229,7 +229,7 @@ refer to JAX has experimental ROCm support. There are two ways to install JAX: * Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax-community/tags); or -* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). +* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://docs.jax.dev/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). 
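Whichever route you take, a quick sanity check that the ROCm installation is actually being used is to ask JAX which devices and backend it sees. This is a minimal sketch, not from the install guide itself; the exact device and platform strings printed for ROCm builds may vary by version:

```python
import jax

# On a working GPU install this should list accelerator devices rather than
# falling back to CPU, and the default backend should not be "cpu".
print(jax.devices())
print(jax.default_backend())
```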
(install-intel-gpu)= ## Intel GPU @@ -296,7 +296,7 @@ pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.co - NVIDIA GPU (CUDA 12): ```bash -pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ``` - NVIDIA GPU (CUDA 12) legacy: diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index abdc8be6d0a8..819d0418e894 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -21,7 +21,7 @@ kernelspec: A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide). -For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.core.Primitive("multiply_add")`, as demonstrated further below. +For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.extend.core.Primitive("multiply_add")`, as demonstrated further below. And JAX is able to take sequences of such primitive operations, and transform them via its composable transformations of Python functions, such as {func}`jax.jit`, {func}`jax.grad` and {func}`jax.vmap`. JAX implements these transforms in a *JAX-traceable* way. This means that when a Python function is executed, the only operations it applies to the data are either: @@ -171,7 +171,7 @@ The JAX traceability property is satisfied as long as the function is written in The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality. ```{code-cell} -from jax import core +from jax.extend import core multiply_add_p = core.Primitive("multiply_add") # Create the primitive @@ -300,7 +300,7 @@ def multiply_add_lowering(ctx, xc, yc, zc): return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result] # Now, register the lowering rule with JAX. 
-# For GPU, refer to the https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html +# For GPU, refer to the https://docs.jax.dev/en/latest/Custom_Operation_for_GPUs.html from jax.interpreters import mlir mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index adb13f89903d..339f07f4cdcc 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -40,6 +40,7 @@ Activation functions glu squareplus mish + identity Other functions --------------- @@ -53,3 +54,6 @@ Other functions standardize one_hot dot_product_attention + scaled_matmul + get_scaled_dot_general_config + scaled_dot_general diff --git a/docs/jax.rst b/docs/jax.rst index 98cd464cda15..2f16df613e5a 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -57,6 +57,7 @@ Configuration enable_custom_prng enable_custom_vjp_by_custom_transpose log_compiles + no_tracing numpy_rank_promotion transfer_guard diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index dcbb673997ad..3c436697e1be 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -69,6 +69,7 @@ jax.scipy.linalg lu lu_factor lu_solve + pascal polar qr rsf2csf diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 95d4a632a295..3cc1629b2068 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -1,3 +1,6 @@ +--- +orphan: true +--- (jax-array-migration)= # jax.Array migration @@ -24,7 +27,7 @@ the unified jax.Array After the migration is complete `jax.Array` will be the only type of array in JAX. -This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. +This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. ### How to enable jax.Array? diff --git a/docs/jep/10657-sequencing-effects.md b/docs/jep/10657-sequencing-effects.md index 5f7eb0da4c04..ac3024519101 100644 --- a/docs/jep/10657-sequencing-effects.md +++ b/docs/jep/10657-sequencing-effects.md @@ -47,7 +47,7 @@ g() In many cases, JAX will execute `f` and `g` *in parallel*, dispatching the computations onto different threads -- `g` might actually be executed before `f`. Parallel execution is a nice performance optimization, especially if copying -to and from a device is expensive (see the [asynchronous dispatch note](https://jax.readthedocs.io/en/latest/async_dispatch.html) for more details). +to and from a device is expensive (see the [asynchronous dispatch note](https://docs.jax.dev/en/latest/async_dispatch.html) for more details). 
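As a small illustration of that asynchronous dispatch (a sketch added here for clarity, not part of the original JEP text): calls into jitted code return before the device work finishes, and `block_until_ready()` is the explicit way to wait for a result.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x @ x

x = jnp.ones((2000, 2000))
y = f(x)               # returns almost immediately; the matmul is dispatched asynchronously
y.block_until_ready()  # explicitly block until the device computation has finished
```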
In practice, however, we often don't need to think about asynchronous dispatch because we're writing pure functions and only care about the inputs and outputs of functions -- we'll naturally block on future diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 7a20958c5cab..bf6123b2bc7f 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -35,7 +35,7 @@ def slice(operand: Array, start_indices: Sequence[int], For the purposes of static type checking, this use of `Array = Any` for array type annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer. -For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). +For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://docs.jax.dev/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). A benefit of this level of type annotation is that it is never wrong to annotate a value with `Any`, so it will provide a concrete benefit to developers and users in the form of documentation, without added complexity of satisfying the stricter needs of any particular static type checker. @@ -122,7 +122,7 @@ All told, the array-type-granularity challenge is less of an issue than the othe ### Challenge 5: imprecise APIs inherited from NumPy A large part of JAX’s user-facing API is inherited from NumPy within the {mod}`jax.numpy` submodule. -NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-eafp) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: +NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-EAFP) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: ```python def tile(A, reps): diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md index 63742bc852c6..fa6681551d17 100644 --- a/docs/jep/14273-shard-map.md +++ b/docs/jep/14273-shard-map.md @@ -4,7 +4,7 @@ *January 2023* **This was the design doc proposing `shard_map`. 
You may instead want -[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).** +[the up-to-date user docs](https://docs.jax.dev/en/latest/notebooks/shard_map.html).** ## Motivation @@ -18,7 +18,7 @@ We need great APIs for both, and rather than being mutually exclusive alternatives, they need to compose with each other. With `pjit` (now just `jit`) we have [a next-gen -API](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +API](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) for the first school. But we haven't quite leveled-up the second school. `pmap` follows the second school, but over time we found it has [fatal flaws](#why-dont-pmap-or-xmap-already-solve-this). `xmap` solved those flaws, diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index a5625abf8930..a821405c399e 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -14,13 +14,13 @@ import jax.extend as jex Several projects depend on JAX's codebase internals, often to use its core machinery (e.g. to write a -[transformation over its IR](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) +[transformation over its IR](https://docs.jax.dev/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) or to extend it (e.g. to [define new primitives](https://github.com/dfm/extending-jax)). Two challenges for these dependencies are (a) that our internals aren't all solidly designed for external use, and (b) that circumventing JAX's public API is -[unsupported](https://jax.readthedocs.io/en/latest/api_compatibility.html). +[unsupported](https://docs.jax.dev/en/latest/api_compatibility.html). In other words, our internals are often used like a library, but are neither structured nor updated like one. @@ -50,12 +50,12 @@ removed altogether. To keep development overhead low, `jax.extend` would not follow the public -[API compatibility](https://jax.readthedocs.io/en/latest/api_compatibility.html) +[API compatibility](https://docs.jax.dev/en/latest/api_compatibility.html) policy. It would promise no deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the -[changelog](https://jax.readthedocs.io/en/latest/changelog.html) +[changelog](https://docs.jax.dev/en/latest/changelog.html) to call out such changes. Callers of `jax.extend` that need to upgrade their code regularly @@ -108,7 +108,7 @@ to process the Jaxpr IR (the output of At initialization, this module will contain many more symbols than what's needed to define primitives and rules, including various names used in setting up -["final-style transformations"](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), +["final-style transformations"](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), such as the current `jax._src.core.Trace` and `Tracer` classes. We can revisit whether `jex.core` should also support final-style extensions alongside initial style approaches, and whether it can do so by a more @@ -137,7 +137,7 @@ tracer types from `jex`. This module plus `jex.core` ought to suffice for replicating today's custom primitive tutorials (e.g. 
-[ours](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) +[ours](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html) and [dfm's](https://github.com/dfm/extending-jax)). For instance, defining a primitive and its behavior under `jax.jit` @@ -184,6 +184,6 @@ arrays. We have only one item in mind for now. The XLA compiler's array sharding format is more expressive than [those provided by -JAX](https://jax.readthedocs.io/en/latest/jax.sharding.html). We could +JAX](https://docs.jax.dev/en/latest/jax.sharding.html). We could provide this as `jex.sharding.XlaOpShardingProto`, corresponding to today's `jax._src.lib.xla_client.OpSharding` internally. diff --git a/docs/jep/17111-shmap-transpose.md b/docs/jep/17111-shmap-transpose.md index 2fdf5f822835..00d8a3f383fd 100644 --- a/docs/jep/17111-shmap-transpose.md +++ b/docs/jep/17111-shmap-transpose.md @@ -497,7 +497,7 @@ of every function instance along which the outputs are mapped, whereas for mesh axes over which the output is unmapped only one copy of the value is used. See [the `shmap` -JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples +JEP](https://docs.jax.dev/en/latest/jep/14273-shard-map.html) for examples of unmapped inputs and outputs. For comparison, in `vmap` unmapped inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather than an `int`). diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md index ce149fa6fb35..b09926425667 100644 --- a/docs/jep/2026-custom-derivatives.md +++ b/docs/jep/2026-custom-derivatives.md @@ -2,7 +2,7 @@ This is a design document, explaining some of the thinking behind the design and implementation of `jax.custom_jvp` and `jax.custom_vjp`. For user-oriented -documentation, see [the tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). +documentation, see [the tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md index 1e2270e052a6..c3f2be151ef7 100644 --- a/docs/jep/4008-custom-vjp-update.md +++ b/docs/jep/4008-custom-vjp-update.md @@ -4,7 +4,7 @@ _Oct 14 2020_ This doc assumes familiarity with `jax.custom_vjp`, as described in the [Custom derivative rules for JAX-transformable Python -functions](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +functions](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) notebook. ## What to update diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md index f95c15f404b6..5b4536864ac2 100644 --- a/docs/jep/4410-omnistaging.md +++ b/docs/jep/4410-omnistaging.md @@ -266,7 +266,7 @@ While tracing the function ex1 at ex1.py:4, this value became a tracer due to JA You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions. -See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. +See https://docs.jax.dev/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. 
Encountered tracer value: Tracedwith ``` diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index a1ede3177a3a..5f12877c97a9 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -12,7 +12,7 @@ "\n", "*Jake VanderPlas, December 2021*\n", "\n", - "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html)." + "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html)." ] }, { @@ -1335,7 +1335,7 @@ "However, these advantages comes with a few tradeoffs:\n", "\n", "- mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \\times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \\times 10^4$), meaning most representable values will become `inf`.\n", - "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", + "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", "\n", "Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`." ] @@ -1413,7 +1413,7 @@ "id": "o0-E2KWjYEXO" }, "source": [ - "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", + "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", "\n", "For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX." ] @@ -2883,7 +2883,7 @@ "source": [ "### JAX Type Promotion: `jax.numpy`\n", "\n", - "`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." 
+ "`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." ] }, { diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index ff67a8c21399..c047d76c1b18 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -20,7 +20,7 @@ kernelspec: *Jake VanderPlas, December 2021* -One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). +One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). +++ {"id": "Rod6OOyUVbQ8"} @@ -680,7 +680,7 @@ This is important because `f16` and `bf16` are not comparable because they utili However, these advantages comes with a few tradeoffs: - mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \times 10^4$), meaning most representable values will become `inf`. -- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. +- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`. @@ -730,7 +730,7 @@ nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos +++ {"id": "o0-E2KWjYEXO"} -The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. +The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX. 
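As a quick, hedged illustration of how the weak types (`i*`, `f*`) behave in practice under these rules (with the default X64-disabled settings; not part of the JEP appendix itself):

```python
import jax.numpy as jnp

x = jnp.arange(3, dtype='int8')
print((x + 1).dtype)             # int8: the Python int is weakly typed (i*) and defers to x
print((x + jnp.int32(1)).dtype)  # int32: a concrete int32 follows the standard promotion lattice
```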
@@ -900,7 +900,7 @@ display.HTML(table.to_html()) ### JAX Type Promotion: `jax.numpy` -`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays. +`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays. ```{code-cell} :cellView: form diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index b964aa2af45d..85b95257ebae 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -167,16 +167,16 @@ We maintain an additional version number (`_version`) in [`xla_client.py` in the XLA repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py). The idea is that this version number, is defined in `xla/python` together with the C++ parts of JAX, is also accessible to JAX Python as -`jax._src.lib.xla_extension_version`, and must +`jax._src.lib.jaxlib_extension_version`, and must be incremented every time that a change is made to the XLA/Python code that has backwards compatibility implications for `jax`. The JAX Python code can then use this version number to maintain backwards compatibility, e.g.: ``` -from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib_extension_version # 123 is the new version number for _version in xla_client.py -if xla_extension_version >= 123: +if jaxlib_extension_version >= 123: # Use new code path ... else: diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 5e5be308068a..093f5ec4ab72 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -55,7 +55,7 @@ The {ref}`jax-internals-jaxpr` section of the documentation provides more inform Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. -If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). +If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. Moreover, JAX often can't detect when side effects are present. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index a1435c4e557e..a9d4a9424f9f 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -307,7 +307,7 @@ "id": "go3L4x3w4-9p" }, "source": [ - "If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)" + "If we try to do in-place indexed updating on a `jax.Array`, however, we get an __error__!
(☉_☉)" ] }, { @@ -346,7 +346,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" + "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" ] } ], @@ -357,6 +357,45 @@ "jax_array[1, :] = 1.0" ] }, + { + "cell_type": "markdown", + "id": "8f520bec", + "metadata": {}, + "source": [ + "And if we try to do `__iadd__`-style in-place updating, we get __different behavior than NumPy__! (☉_☉) (☉_☉)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20fbed45", + "metadata": {}, + "outputs": [], + "source": [ + "jax_array = jnp.array([10, 20])\n", + "jax_array_new = jax_array\n", + "jax_array_new += 10\n", + "print(jax_array_new) # `jax_array_new` is rebound to a new value [20, 30], but...\n", + "print(jax_array) # the original value is unodified as [10, 20] !\n", + "\n", + "numpy_array = np.array([10, 20])\n", + "numpy_array_new = numpy_array\n", + "numpy_array_new += 10\n", + "print(numpy_array_new) # `numpy_array_new is numpy_array`, and it was updated\n", + "print(numpy_array) # in-place, so both are [20, 30] !" + ] + }, + { + "cell_type": "markdown", + "id": "2604e220", + "metadata": {}, + "source": [ + "That's because NumPy defines `__iadd__` to perform in-place mutation. In\n", + "contrast, `jax.Array` doesn't define an `__iadd__`, so Python treats\n", + "`jax_array_new += 10` as syntactic sugar for `jax_array_new = jax_array_new +\n", + "10`, rebinding the variable without mutating any arrays." + ] + }, { "cell_type": "markdown", "metadata": { @@ -365,7 +404,7 @@ "source": [ "Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.\n", "\n", - "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, { @@ -415,6 +454,7 @@ } ], "source": [ + "jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n", "updated_array = jax_array.at[1, :].set(1.0)\n", "print(\"updated array:\\n\", updated_array)" ] @@ -521,7 +561,7 @@ "id": "sTjJ3WuaDyqU" }, "source": [ - "For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." 
] }, { @@ -604,7 +644,7 @@ "id": "NAcXJNAcDi_v" }, "source": [ - "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" + "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" ] }, { @@ -971,7 +1011,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -1296,7 +1336,7 @@ "While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n", "Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n", "\n", - "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.\n", + "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details.\n", "- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).\n", "\n", " Here is an example of an unsafe cast with differing results between NumPy and JAX:\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 80ab69be1ed8..0857edc132fa 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -177,7 +177,7 @@ print(numpy_array) +++ {"id": "go3L4x3w4-9p"} -If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉) +If we try to do in-place indexed updating on a `jax.Array`, however, we get an __error__! (☉_☉) ```{code-cell} ipython3 :id: iOscaa_GecEK @@ -197,11 +197,32 @@ jax_array = jnp.zeros((3,3), dtype=jnp.float32) jax_array[1, :] = 1.0 ``` +And if we try to do `__iadd__`-style in-place updating, we get __different behavior than NumPy__! (☉_☉) (☉_☉) + +```{code-cell} ipython3 +jax_array = jnp.array([10, 20]) +jax_array_new = jax_array +jax_array_new += 10 +print(jax_array_new) # `jax_array_new` is rebound to a new value [20, 30], but... +print(jax_array) # the original value is unmodified as [10, 20] !
+ +numpy_array = np.array([10, 20]) +numpy_array_new = numpy_array +numpy_array_new += 10 +print(numpy_array_new) # `numpy_array_new is numpy_array`, and it was updated +print(numpy_array) # in-place, so both are [20, 30] ! +``` + +That's because NumPy defines `__iadd__` to perform in-place mutation. In +contrast, `jax.Array` doesn't define an `__iadd__`, so Python treats +`jax_array_new += 10` as syntactic sugar for `jax_array_new = jax_array_new + +10`, rebinding the variable without mutating any arrays. + +++ {"id": "7mo76sS25Wco"} Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions. -Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "hfloZ1QXCS_J"} @@ -219,6 +240,7 @@ For example, the update above can be written as: :id: PBGI-HIeCP_s :outputId: de13f19a-2066-4df1-d503-764c34585529 +jax_array = jnp.zeros((3,3), dtype=jnp.float32) updated_array = jax_array.at[1, :].set(1.0) print("updated array:\n", updated_array) ``` @@ -261,7 +283,7 @@ print(new_jax_array) +++ {"id": "sTjJ3WuaDyqU"} -For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "oZ_jE2WAypdL"} @@ -292,7 +314,7 @@ jnp.arange(10)[11] +++ {"id": "NAcXJNAcDi_v"} -If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: +If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: ```{code-cell} ipython3 :id: -0-MaFddO-xy @@ -664,7 +686,7 @@ x.dtype # --> dtype('float64') While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge. -- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details. +- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details. - When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. 
Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype). Here is an example of an unsafe cast with differing results between NumPy and JAX: diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index e550cbf36da3..e80c7ae94687 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -17,9 +17,9 @@ "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n", "\n", - "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).\n", + "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).\n", "\n", - "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." + "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." ] }, { @@ -2035,7 +2035,7 @@ "source": [ "### Working with `list` / `tuple` / `dict` containers (and other pytrees)\n", "\n", - "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", + "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", "\n", "Here's a contrived example with `jax.custom_jvp`:" ] diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 8a63f142693e..82b97e195bd9 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -24,9 +24,9 @@ There are two ways to define differentiation rules in JAX: 1. 
using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). -For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +++ {"id": "9Fg3NFNY-2RY"} @@ -1048,7 +1048,7 @@ Array(-0.91113025, dtype=float32) ### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with `jax.custom_jvp`: diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 8abee469d552..90d92c4ea241 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -1276,7 +1276,7 @@ "id": "3qfPjJdhgerc" }, "source": [ - "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." + "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." 
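A minimal sketch of the computation-follows-data-placement behavior described above, assuming eight devices (here simulated CPU devices) and an illustrative 4x2 mesh:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

jax.config.update('jax_num_cpu_devices', 8)  # assumption: simulate 8 CPU devices

mesh = jax.make_mesh((4, 2), ('x', 'y'))
a = jax.device_put(jnp.arange(32.).reshape(8, 4), NamedSharding(mesh, P('x', 'y')))

# No annotations here: the compiler parallelizes and chooses the output sharding.
b = jax.jit(jnp.sin)(a)
jax.debug.visualize_array_sharding(b)
```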
] }, { @@ -1382,7 +1382,7 @@ "id": "6ZYcK8eXrn0p" }, "source": [ - "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", + "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", "\n", "When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device.\n", "Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.\n", @@ -2339,7 +2339,7 @@ "source": [ "### Generating random numbers\n", "\n", - "JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`.\n", + "JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`.\n", "\n", "JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices.\n", "\n", diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index c207f0ae4a00..79990fefb95d 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -427,7 +427,7 @@ jax.debug.visualize_array_sharding(w_copy) +++ {"id": "3qfPjJdhgerc"} -So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +++ {"id": "QRB95LaWuT80"} @@ -484,7 +484,7 @@ except ValueError as e: print_exception(e) +++ {"id": "6ZYcK8eXrn0p"} -We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. 
See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. +We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices. @@ -854,7 +854,7 @@ outputId: 479c4d81-cb0b-40a5-89ba-394c10dc3297 ### Generating random numbers -JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`. +JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`. JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices. diff --git a/docs/notebooks/README.md b/docs/notebooks/README.md index 07be4441ade8..c945c197ad19 100644 --- a/docs/notebooks/README.md +++ b/docs/notebooks/README.md @@ -1,2 +1,2 @@ For instructions on how to change and test notebooks, see -[Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +[Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 00ba9186eeec..d22457c5d718 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -24,7 +24,7 @@ "\n", "Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free.\n", "\n", - "**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.**" + "**This example uses internal JAX APIs, which may break at any time. 
Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.**" ] }, { @@ -215,8 +215,8 @@ "# Importing Jax functions useful for tracing/interpreting.\n", "from functools import wraps\n", "\n", - "from jax import core\n", "from jax import lax\n", + "from jax.extend import core\n", "from jax._src.util import safe_map" ] }, diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 10c4e7cb6e3b..ad707a9746fc 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -27,7 +27,7 @@ etc.) that enable writing concise, accelerated code. Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free. -**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.** +**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.** ```{code-cell} ipython3 :id: s27RDKvKXFL8 @@ -147,8 +147,8 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr. # Importing Jax functions useful for tracing/interpreting. from functools import wraps -from jax import core from jax import lax +from jax.extend import core from jax._src.util import safe_map ``` diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index feb906546341..d8a74e4b15fd 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -348,7 +348,7 @@ "source": [ "### Let's think step by step\n", "\n", - "You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 8ba87dcfee18..12564bd91f30 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -156,7 +156,7 @@ print_fwd_bwd(f3, W1, W2, W3, x) ### Let's think step by step -You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). +You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). +++ {"id": "VMfwm_yinvoZ"} diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index d656e12d4068..dada1f0db507 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -28,7 +28,7 @@ "of work and it's also easy to make mistakes that way because there's no way to\n", "check that the shardings make sense together. More commonly, people add just\n", "enough sharding annotations to constrain the compiler. But this is a slow\n", - "iterative process. It's hard to know ahead of time what XLA's gSPMD pass will\n", + "iterative process. 
It's hard to know ahead of time what XLA's GSPMD pass will\n", "do (it's a whole-program optimization) so all you can do is add annotations,\n", "inspect XLA's sharding choices to see what happened, and repeat.\n", "\n", @@ -59,7 +59,7 @@ "import numpy as np\n", "import jax.numpy as jnp\n", "from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh\n", - "from jax.experimental.shard import reshard, auto_axes\n", + "from jax.experimental.shard import reshard, auto_axes, explicit_axes\n", "\n", "jax.config.update('jax_num_cpu_devices', 8)" ] @@ -652,7 +652,51 @@ "id": "_3sfJjRq8w9f" }, "source": [ - "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`." + "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`.\n", + "\n", + "\n", + "You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a102e9c7", + "metadata": {}, + "outputs": [], + "source": [ + "auto_mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Auto, AxisType.Auto))\n", + "\n", + "@functools.partial(explicit_axes, axes=('X', 'Y'))\n", + "def explicit_g(y):\n", + " print(f'mesh inside g: {get_abstract_mesh()}')\n", + " print(f'y.sharding inside g: {jax.typeof(y) = }')\n", + " z = y * 2\n", + " print(f'z.sharding inside g: {jax.typeof(z) = }', end='\\n\\n')\n", + " return z\n", + "\n", + "@jax.jit\n", + "def f(arr1):\n", + " print(f'mesh inside f: {get_abstract_mesh()}', end='\\n\\n')\n", + " x = jnp.sin(arr1)\n", + "\n", + " z = explicit_g(x, in_shardings=P(\"X\", \"Y\"))\n", + "\n", + " return z + 1\n", + "\n", + "with jax.sharding.use_mesh(auto_mesh):\n", + " some_x = jax.device_put(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n", + " f(some_x)" + ] + }, + { + "cell_type": "markdown", + "id": "e64d40de", + "metadata": {}, + "source": [ + "As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`.\n", + "Because of that, sharding is visible on the type of arrays inside `g`." ] }, { diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index 7c59a675d8ec..a091060393b6 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -31,7 +31,7 @@ constraints? You could put them on every single intermediate but that's a lot of work and it's also easy to make mistakes that way because there's no way to check that the shardings make sense together. More commonly, people add just enough sharding annotations to constrain the compiler. But this is a slow -iterative process. It's hard to know ahead of time what XLA's gSPMD pass will +iterative process. It's hard to know ahead of time what XLA's GSPMD pass will do (it's a whole-program optimization) so all you can do is add annotations, inspect XLA's sharding choices to see what happened, and repeat. 
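A minimal sketch of the annotate-and-inspect workflow described above, assuming eight simulated CPU devices and using `jax.lax.with_sharding_constraint` for the single targeted annotation:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

jax.config.update('jax_num_cpu_devices', 8)  # assumption: 8 simulated CPU devices
mesh = jax.make_mesh((8,), ('x',))

@jax.jit
def f(a):
    b = jnp.sin(a)
    # One targeted annotation constrains the compiler's choice for this intermediate;
    # everything else is left to GSPMD and inspected after the fact.
    return jax.lax.with_sharding_constraint(b, NamedSharding(mesh, P('x')))

out = f(jnp.arange(32.))
jax.debug.visualize_array_sharding(out)
```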
@@ -56,7 +56,7 @@ import jax import numpy as np import jax.numpy as jnp from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh -from jax.experimental.shard import reshard, auto_axes +from jax.experimental.shard import reshard, auto_axes, explicit_axes jax.config.update('jax_num_cpu_devices', 8) ``` @@ -403,6 +403,38 @@ f(some_x) As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`. + +You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes. + +```{code-cell} ipython3 +auto_mesh = jax.make_mesh((2, 4), ("X", "Y"), + axis_types=(AxisType.Auto, AxisType.Auto)) + +@functools.partial(explicit_axes, axes=('X', 'Y')) +def explicit_g(y): + print(f'mesh inside g: {get_abstract_mesh()}') + print(f'y.sharding inside g: {jax.typeof(y) = }') + z = y * 2 + print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n') + return z + +@jax.jit +def f(arr1): + print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n') + x = jnp.sin(arr1) + + z = explicit_g(x, in_shardings=P("X", "Y")) + + return z + 1 + +with jax.sharding.use_mesh(auto_mesh): + some_x = jax.device_put(np.arange(16).reshape(4, 4), P("X", "Y")) + f(some_x) +``` + +As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`. +Because of that, sharding is visible on the type of arrays inside `g`. + +++ {"id": "sJcWbfAh7UcO"} ## Concrete array shardings can mention `Auto` mesh axis diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index c31a99746866..a909d9329e24 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -46,7 +46,7 @@ "\n", "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", + "Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." 
] diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 53b7d47358c2..9c153d704763 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -44,7 +44,7 @@ _Forked from_ `neural_network_and_data_loading.ipynb` ![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). +Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index d73b0d4c0f3e..eb8f54f70d7a 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -13,9 +13,9 @@ "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", - "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", + "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", "\n", - "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. 
It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", + "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", "\n", "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies.\n", "\n", @@ -481,6 +481,346 @@ "`Array`s, or physically how to interpret the buffers across devices as the\n", "physical layout of a single logical `Array`.\n", "\n", + "#### Tracking how values vary over manual mesh axes, and `check_rep=True`\n", + "\n", + "Under a `shard_map`, values can vary across function instances, or they can be\n", + "the same. For example, when we use `in_specs` to split an argument over a mesh\n", + "axis, each function instance along that mesh axis gets a different value:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38668c79", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',))\n", + "\n", + "@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", + "def f(x):\n", + " print(x)\n", + " return 2 * x\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "00b66850", + "metadata": {}, + "source": [ + "If instead `in_specs` does not split the argument over a mesh axis, the value\n", + "is the same for each function instance along that axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d0dfa6d", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())\n", + "def f(x):\n", + " print(x)\n", + " return 2 * x\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "594b4574", + "metadata": {}, + "source": [ + "A collective's output may have a different variance than its input. For\n", + "example, applying a `psum` produces the same output on each function instance\n", + "along an axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df486b2f", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())\n", + "def f(x):\n", + " y = jax.lax.psum(x, 'i')\n", + " print(y)\n", + " return y\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "bf6a17ad", + "metadata": {}, + "source": [ + "In general, each intermediate value in a `shard_map` can be either unvarying or\n", + "possibly-varying over each manual mesh axis. 
That information can be tracked in\n", + "the JAX type system, enabled by the `check_rep=True` argument to `shard_map`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7f32190", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())\n", + "def f(x):\n", + " print(jax.typeof(x)) # f32[3]{i}\n", + " y = jax.lax.psum(x, 'i')\n", + " print(jax.typeof(y)) # f32[3]\n", + " return y\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "f76cc47f", + "metadata": {}, + "source": [ + "Here, the type `f32[3]{i}` means that the value of `x` is varying over mesh\n", + "axis `'i'`. The type of `y` printing as `f32[3]` indicates it is unvarying over\n", + "all mesh axes; that is, empty sets are not printed. We call this part of the\n", + "type the _varying manual axes_ (VMA), and it can be accessed via\n", + "`jax.typeof(x).vma`.\n", + "\n", + "In general, the VMA type of a value can include any subset of the manual mesh\n", + "axes over which the `shard_map` is acting:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e69a02d3", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", + "\n", + "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n", + "def f(x):\n", + " print(jax.typeof(x)) # f32[2,2]{i,j}\n", + " y = jax.lax.psum(x, 'j')\n", + " assert jax.typeof(y).vma == {'i'}\n", + " print(jax.typeof(y)) # f32[2,2]{i}\n", + " return y\n", + "\n", + "x = jnp.arange(8 * 4.).reshape(8, 4)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "a36f1654", + "metadata": {}, + "source": [ + "Tracking varying manual axes can be useful:\n", + "1. Your code can include prints, assertions, or conditionals about whether\n", + " values are varying over expected mesh axes;\n", + "2. It enables efficient reverse-mode autodiff that doesn't require defensive\n", + " `psum`s (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html));\n", + "3. The correctness of `out_specs` can be checked, ruling out the potential bug\n", + " example below.\n", + "\n", + "For example, this `out_specs` bug is caught with `check_rep=True`, but uncaught\n", + "without it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c92c1d4d", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',))\n", + "\n", + "x = jnp.arange(6.)\n", + "try:\n", + " y = shard_map(lambda x: x, mesh, in_specs=P('i'), out_specs=P())(x)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "68bc33af", + "metadata": {}, + "source": [ + "Here the `out_specs` incorrectly promise that each function instance along mesh\n", + "axis `'i'` produces the same value and thus we can choose just one of them.\n", + "With `check_rep=True` (the default) it raises an exception, while with\n", + "`check_rep=False` there is no exception and instead we get silent undefined\n", + "behavior.\n", + "\n", + "Sometimes we want to treat a value that is unvarying over a mesh axis as\n", + "varying over that mesh axis. 
That's what `jax.lax.pvary` does:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21276d78", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=None)\n", + "def f(x):\n", + " print(jax.typeof(x)) # f32[6]\n", + " y = jax.lax.pvary(x, 'i')\n", + " print(jax.typeof(y)) # f32[6]{i}\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "8f766c1a", + "metadata": {}, + "source": [ + "Think of `jax.lax.pvary` as applying a type cast: it's a no-op at runtime,\n", + "though under reverse-mode autodiff it transposes to a `jax.lax.psum` (see\n", + "[JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). That\n", + "makes sense because they do opposite things to the VMA: where `y: f32[3]{i} =\n", + "jax.lax.pvary(x: f32[3], 'i')`, we correspondingly have `x_grad: f32[3] =\n", + "jax.lax.psum(y_grad: f32[3]{i}, 'i')`.\n", + "\n", + "JAX implicitly inserts `jax.lax.pvary` calls in many cases, especially for\n", + "binary operations:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e22d52a4", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "def f(x, y):\n", + " return x * y\n", + "\n", + "x = jnp.arange(6.)\n", + "y = jnp.arange(3.)\n", + "print(jax.make_jaxpr(f)(x, y))" + ] + }, + { + "cell_type": "markdown", + "id": "1bd7f6a5", + "metadata": {}, + "source": [ + "In a jaxpr, the multiplication operation requires the VMA types of its\n", + "arguments to match, but for convenience the `jax.numpy` and `jax.lax` APIs\n", + "automatically apply `jax.lax.pvary` to make argument VMA types agree.\n", + "\n", + "\n", + "\n", + "In some cases, like with `jax.lax.scan`, you might need to apply\n", + "`jax.lax.pvary` yourself to ensure VMA types match as required. 
For example,\n", + "this code raises an error:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e33a5fb", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',))\n", + "\n", + "@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "def f(x, y):\n", + " def body(carry, _):\n", + " c1, c2 = carry\n", + " return (c2, c1), () # swap the carry\n", + " (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)\n", + " return x_, y_\n", + "\n", + "x = jnp.arange(6.)\n", + "y = jnp.arange(3.)\n", + "\n", + "try:\n", + " f(x, y)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "7b6fef36", + "metadata": {}, + "source": [ + "To make the types match, we need to apply `jax.lax.pvary` to some arguments to\n", + "the `scan`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c8dbd11", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',))\n", + "\n", + "@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "def f(x, y):\n", + " def body(carry, _):\n", + " c1, c2 = carry\n", + " return (c2, c1), () # swap the carry\n", + "\n", + " y = jax.lax.pvary(y, 'i') # apply pvary to fix the error\n", + " (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)\n", + " return x_, y_\n", + "\n", + "x = jnp.arange(6.)\n", + "y = jnp.arange(3.)\n", + "\n", + "f(x, y)" + ] + }, + { + "cell_type": "markdown", + "id": "10271c3c", + "metadata": {}, + "source": [ + "Here's a summary of collective primitives and how they affect varying manual axis types:\n", + "\n", + "| Name | Device variance type | Example | Lowers to HLO | Transpose |\n", + "| --- | --- | --- | --- | --- |\n", + "| `psum_invariant` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pvary` |\n", + "| `pvary` | `Invariant -> Varying` | `y:f32[3]{i} = pvary(x:f32[3], 'i')` | no-op (no communication) | `psum_invariant` |\n", + "| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` `AllToAll` (communication) | `all_to_all` |\n", + "| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a |\n", + "| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` |\n", + "| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` |\n", + "| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` |\n", + "| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` |\n", + "\n", + "A few notes on the table:\n", + "* The function `jax.lax.psum` is a convenience wrapper around `psum_invariant`.\n", + "* It's surprising that `all_gather` is `Varying -> Varying`, but that's because\n", + " it's really the transpose of `psum_scatter` which is `Varying -> Varying`.\n", + "* Neither `pscatter` nor `all_gather_invariant` have user APIs at the time of\n", + " writing, but they're described here for completeness.\n", + "\n", + "\n", "## API Specification\n", "\n", "```python\n", @@ -499,7 +839,7 @@ "* `mesh` encodes devices arranged in an array and with associated axis 
names, just like it does for `sharding.NamedSharding`;\n", "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n", - "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)).\n", + "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n", "\n", "The shapes of the arguments passed to `f` have the same ranks as the arguments\n", "passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed\n", @@ -1520,7 +1860,7 @@ "source": [ "Compare these examples with the purely [automatic partitioning examples in the\n", "\"Distributed arrays and automatic partitioning\"\n", - "doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", + "doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "While in those automatic partitioning examples we don't need to edit the model\n", "functions to use different parallelization strategies, with `shard_map` we\n", "often do.\n", @@ -1626,7 +1966,7 @@ "parameters from the forward pass for use on the backward pass. Instead, we want\n", "to gather them again on the backward pass. We can express that by using\n", "`jax.remat` with a [custom\n", - "policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", + "policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", "(or a `custom_vjp`), though XLA typically does that rematerialization\n", "automatically.\n", "\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index c52cf0e6d22b..ae9206059b1e 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -22,9 +22,9 @@ kernelspec: `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. -`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. 
The two approaches can be mixed, matched, and composed as needed. +`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. -If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) +If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies. @@ -328,6 +328,226 @@ Instead, `out_specs` just encodes how to assemble the block outputs into `Array`s, or physically how to interpret the buffers across devices as the physical layout of a single logical `Array`. +#### Tracking how values vary over manual mesh axes, and `check_rep=True` + +Under a `shard_map`, values can vary across function instances, or they can be +the same. For example, when we use `in_specs` to split an argument over a mesh +axis, each function instance along that mesh axis gets a different value: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',)) + +@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) +def f(x): + print(x) + return 2 * x + +x = jnp.arange(6.) +f(x) +``` + +If instead `in_specs` does not split the argument over a mesh axis, the value +is the same for each function instance along that axis: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) +def f(x): + print(x) + return 2 * x + +x = jnp.arange(6.) +f(x) +``` + +A collective's output may have a different variance than its input. For +example, applying a `psum` produces the same output on each function instance +along an axis: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) +def f(x): + y = jax.lax.psum(x, 'i') + print(y) + return y + +x = jnp.arange(6.) +f(x) +``` + +In general, each intermediate value in a `shard_map` can be either unvarying or +possibly-varying over each manual mesh axis. 
That information can be tracked in +the JAX type system, enabled by the `check_rep=True` argument to `shard_map`: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) +def f(x): + print(jax.typeof(x)) # f32[3]{i} + y = jax.lax.psum(x, 'i') + print(jax.typeof(y)) # f32[3] + return y + +x = jnp.arange(6.) +f(x) +``` + +Here, the type `f32[3]{i}` means that the value of `x` is varying over mesh +axis `'i'`. The type of `y` printing as `f32[3]` indicates it is unvarying over +all mesh axes; that is, empty sets are not printed. We call this part of the +type the _varying manual axes_ (VMA), and it can be accessed via +`jax.typeof(x).vma`. + +In general, the VMA type of a value can include any subset of the manual mesh +axes over which the `shard_map` is acting: + +```{code-cell} +mesh = jax.make_mesh((4, 2), ('i', 'j')) + +@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i')) +def f(x): + print(jax.typeof(x)) # f32[2,2]{i,j} + y = jax.lax.psum(x, 'j') + assert jax.typeof(y).vma == {'i'} + print(jax.typeof(y)) # f32[2,2]{i} + return y + +x = jnp.arange(8 * 4.).reshape(8, 4) +f(x) +``` + +Tracking varying manual axes can be useful: +1. Your code can include prints, assertions, or conditionals about whether + values are varying over expected mesh axes; +2. It enables efficient reverse-mode autodiff that doesn't require defensive + `psum`s (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)); +3. The correctness of `out_specs` can be checked, ruling out the potential bug + example below. + +For example, this `out_specs` bug is caught with `check_rep=True`, but uncaught +without it: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',)) + +x = jnp.arange(6.) +try: + y = shard_map(lambda x: x, mesh, in_specs=P('i'), out_specs=P())(x) +except Exception as e: + print(e) +``` + +Here the `out_specs` incorrectly promise that each function instance along mesh +axis `'i'` produces the same value and thus we can choose just one of them. +With `check_rep=True` (the default) it raises an exception, while with +`check_rep=False` there is no exception and instead we get silent undefined +behavior. + +Sometimes we want to treat a value that is unvarying over a mesh axis as +varying over that mesh axis. That's what `jax.lax.pvary` does: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=None) +def f(x): + print(jax.typeof(x)) # f32[6] + y = jax.lax.pvary(x, 'i') + print(jax.typeof(y)) # f32[6]{i} + +x = jnp.arange(6.) +f(x) +``` + +Think of `jax.lax.pvary` as applying a type cast: it's a no-op at runtime, +though under reverse-mode autodiff it transposes to a `jax.lax.psum` (see +[JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). That +makes sense because they do opposite things to the VMA: where `y: f32[3]{i} = +jax.lax.pvary(x: f32[3], 'i')`, we correspondingly have `x_grad: f32[3] = +jax.lax.psum(y_grad: f32[3]{i}, 'i')`. + +JAX implicitly inserts `jax.lax.pvary` calls in many cases, especially for +binary operations: + +```{code-cell} +@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +def f(x, y): + return x * y + +x = jnp.arange(6.) +y = jnp.arange(3.) +print(jax.make_jaxpr(f)(x, y)) +``` + +In a jaxpr, the multiplication operation requires the VMA types of its +arguments to match, but for convenience the `jax.numpy` and `jax.lax` APIs +automatically apply `jax.lax.pvary` to make argument VMA types agree. 
+ + + +In some cases, like with `jax.lax.scan`, you might need to apply +`jax.lax.pvary` yourself to ensure VMA types match as required. For example, +this code raises an error: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',)) + +@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +def f(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () # swap the carry + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + +x = jnp.arange(6.) +y = jnp.arange(3.) + +try: + f(x, y) +except Exception as e: + print(e) +``` + +To make the types match, we need to apply `jax.lax.pvary` to some arguments to +the `scan`: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',)) + +@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +def f(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () # swap the carry + + y = jax.lax.pvary(y, 'i') # apply pvary to fix the error + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + +x = jnp.arange(6.) +y = jnp.arange(3.) + +f(x, y) +``` + +Here's a summary of collective primitives and how they affect varying manual axis types: + +| Name | Device variance type | Example | Lowers to HLO | Transpose | +| --- | --- | --- | --- | --- | +| `psum_invariant` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pvary` | +| `pvary` | `Invariant -> Varying` | `y:f32[3]{i} = pvary(x:f32[3], 'i')` | no-op (no communication) | `psum_invariant` | +| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` `AllToAll` (communication) | `all_to_all` | +| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a | +| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` | +| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` | +| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` | +| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` | + +A few notes on the table: +* The function `jax.lax.psum` is a convenience wrapper around `psum_invariant`. +* It's surprising that `all_gather` is `Varying -> Varying`, but that's because + it's really the transpose of `psum_scatter` which is `Varying -> Varying`. +* Neither `pscatter` nor `all_gather_invariant` have user APIs at the time of + writing, but they're described here for completeness. 
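+
+As a small illustration of the `psum_scatter` and `all_gather` rows, here is a
+sketch in the spirit of the examples above (assuming a 2-device mesh and the
+`tiled=True` calling convention; the printed types are what the table predicts):
+
+```python
+mesh = jax.make_mesh((2,), ('i',))
+
+@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
+def f(x):
+  # psum_scatter sums over 'i' and hands each instance a distinct piece,
+  # so per the table the result is still varying over 'i'.
+  s = jax.lax.psum_scatter(x, 'i', tiled=True)
+  print(jax.typeof(s))  # f32[4]{i}
+  # all_gather is also Varying -> Varying, so the gathered value keeps {i}.
+  # (Computed here just to show its VMA type.)
+  g = jax.lax.all_gather(s, 'i', tiled=True)
+  print(jax.typeof(g))  # f32[8]{i}
+  return s
+
+x = jnp.arange(16.)
+f(x)
+```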
+ + ## API Specification ```python @@ -346,7 +566,7 @@ where: * `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; * `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; * `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually; -* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)). +* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). The shapes of the arguments passed to `f` have the same ranks as the arguments passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed @@ -1061,7 +1281,7 @@ params, batch = init(jax.random.key(0), layer_sizes, batch_size) Compare these examples with the purely [automatic partitioning examples in the "Distributed arrays and automatic partitioning" -doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). +doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with `shard_map` we often do. @@ -1137,7 +1357,7 @@ There's one other ingredient we need: we don't want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using `jax.remat` with a [custom -policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) +policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) (or a `custom_vjp`), though XLA typically does that rematerialization automatically. 
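+
+As a rough sketch of the remat-with-a-custom-policy idea (the layer function and
+the residual name `'w_gathered'` below are hypothetical, not taken from the
+example above):
+
+```python
+from jax.ad_checkpoint import checkpoint_name
+
+def layer(w_shard, x):
+  # Gather the full weights from their shards along mesh axis 'i', and give
+  # the gathered value a name the remat policy can refer to.
+  w = checkpoint_name(jax.lax.all_gather(w_shard, 'i', tiled=True), 'w_gathered')
+  return jnp.tanh(x @ w)
+
+# Ask autodiff not to save the gathered copy: it is re-gathered
+# (rematerialized) on the backward pass instead of being stored.
+layer = jax.remat(
+    layer,
+    policy=jax.checkpoint_policies.save_anything_except_these_names('w_gathered'))
+```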
diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 5ddcdd32e2b4..d6cbf6e02198 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -248,7 +248,7 @@ "id": "yRYF0YgO3F4H" }, "source": [ - "For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" + "For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" ] }, { @@ -423,7 +423,7 @@ "id": "0GPqgT7S0q8r" }, "source": [ - "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" + "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" ] }, { @@ -461,7 +461,7 @@ "id": "7mdo6ycczlbd" }, "source": [ - "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", + "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", "\n", "At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution).\n", "Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation." 
@@ -562,7 +562,7 @@ "id": "3GvisB-CA9M8" }, "source": [ - "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):" + "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)):" ] }, { @@ -650,7 +650,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -835,7 +835,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" + "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. 
This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" ] } ], diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 0693f6ba8579..7b0bb0d9b8ce 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -117,7 +117,7 @@ x[0] = 10 +++ {"id": "yRYF0YgO3F4H"} -For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: +For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: ```{code-cell} ipython3 :id: 8zqPEAeP3UK5 @@ -189,7 +189,7 @@ jnp.convolve(x, y) +++ {"id": "0GPqgT7S0q8r"} -Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html): +Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html): ```{code-cell} ipython3 :id: pi4f6ikjzc3l @@ -206,7 +206,7 @@ result[0, 0] +++ {"id": "7mdo6ycczlbd"} -This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). +This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution). Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation. 
@@ -261,7 +261,7 @@ np.allclose(norm(X), norm_compiled(X), atol=1E-6) +++ {"id": "3GvisB-CA9M8"} -But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)): +But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)): ```{code-cell} ipython3 :id: 6mUB6VdDAEIY diff --git a/docs/notes.rst b/docs/notes.rst index 08265638000e..502385142b16 100644 --- a/docs/notes.rst +++ b/docs/notes.rst @@ -9,9 +9,6 @@ Dependencies and version compatibility: - :doc:`api_compatibility` outlines JAX's policies with regard to API compatibility across releases. - :doc:`deprecation` outlines JAX's policies with regard to compatibility with Python and NumPy. -Migrations and deprecations: - - :doc:`jax_array_migration` summarizes the changes to the default array type in jax v 0.4.1 - Memory and computation usage: - :doc:`async_dispatch` describes JAX's asynchronous dispatch model. - :doc:`concurrency` describes how JAX interacts with other Python concurrency. @@ -20,6 +17,10 @@ Memory and computation usage: Programmer guardrails: - :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion. +Arrays and data types: + - :doc:`type_promotion` describes JAX's implicit type promotion for functions of two or more values. + - :doc:`default_dtypes` describes how JAX determines the default dtype for array creation functions. + .. toctree:: :hidden: @@ -27,8 +28,9 @@ Programmer guardrails: api_compatibility deprecation - jax_array_migration async_dispatch concurrency gpu_memory_allocation - rank_promotion_warning \ No newline at end of file + rank_promotion_warning + type_promotion + default_dtypes diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 2b1cad7c9a66..7533e6eda053 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -5,7 +5,7 @@ This is the list of changes specific to {class}`jax.experimental.pallas`. -For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/changelog.html). +For the overall JAX change log see [here](https://docs.jax.dev/en/latest/changelog.html). \n", + "![memory_hierarchy](../_static/pallas/pipelining_mem_hierarchy.svg)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WvW6Lo7d2jfb" + }, + "source": [ + "\n", + "In order to perform computation on values X and Y that live in HBM, we need to:\n", + "\n", + "1. Copy the values x and y into SRAM.\n", + "2. Load the values from SRAM into registers.\n", + "3. Execute the computation and store the result into registers.\n", + "4. Store the values in the output registers into SRAM.\n", + "5. Copy the output values in SRAM back to HBM.\n", + "\n", + "Let’s implement a Pallas function that does just that!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 108, + "status": "ok", + "timestamp": 1744764235906, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "IrPhDFnT3Nvw", + "outputId": "8bc03872-fd9f-4610-9d53-d4b46be560f4" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " ...,\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):\n", + " # Load x and y from SRAM into registers\n", + " x_regs = x_sram_ref[:, :]\n", + " y_regs = y_sram_ref[:, :]\n", + " # Execute a vectorized add\n", + " z_regs = x_regs + y_regs\n", + " # Store the output values in registers back into SRAM\n", + " z_sram_ref[:, :] = z_regs\n", + "\n", + "\n", + "def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n", + " # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.\n", + " # It will then copy `x` and `y` from HBM into SRAM.\n", + " z = pl.pallas_call(\n", + " add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + " )(x, y)\n", + " # pallas_call will also copy the output from SRAM back into HBM.\n", + " return z\n", + "\n", + "\n", + "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", + "add_matrices(x, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gGjtwv9u3UNK" + }, + "source": [ + "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n", + "\n", + "`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from a SRAM Ref produces a value that lives in registers. Values in registers behave like jax.Arrays in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`.\n", + "\n", + "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into pallas_call. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output SRAM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`.\n", + "\n", + "Pallas exposes access to lower level memory spaces like SRAM but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both:\n", + "\n", + "- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.\n", + "\n", + "- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. 
The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself.\n", + "\n", + "With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0Ebs2pCDgsEW" + }, + "source": [ + "## Pipelining Basics\n", + "\n", + "\n", + "How can we take advantage of the strengths of each form of type memory in the hierarchy, and be able to operate on large arrays stored in HBM while still utilizing fast SRAM for compute? Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel.\n", + "\n", + "The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation is can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages): \n", + "\n", + "1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`.\n", + "2. **compute**: Load `X` into registers, compute a result, and store in SRAM `Y`\n", + "3. **copy_out**: Copy result `Y` back into HBM `A[i]`.\n", + "\n", + "Note that there is a data-dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8vCtShhBjzTd" + }, + "source": [ + "\n", + "![pipelining_example](../_static/pallas/pipelining_example.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qs3F--kwiOJm" + }, + "source": [ + "The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally \"hide\" the cost of transferring between HBM/SRAM with computation and keep the processor busy with as much uptime as possible.\n", + "\n", + "The initial startup time and final teardown time known as \"bubbles\", where only a subset of the stages are being executed while the pipeline is being \"filled\" or \"drained\". The bulk of the time is spent in the \"steady-state\" phase of the pipeline, where each pipeline stage is being executed in parallel across different iterations of the subproblem. While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZcSzl4N6pPbG" + }, + "source": [ + "### Deriving a Double-Buffered Pipeline\n", + "\n", + "Now lets look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the result, and store the result back to HBM with `copy_out`:\n", + "\n", + "
\n",
+    "for i in range(N):\n",
+    "  copy_in(A[i], X)\n",
+    "  Y = X + 1\n",
+    "  copy_out(Y, A[i])\n",
+    "
\n", + "The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations. So we are forced to wait for the copies to finish while the GPU/TPU is idle, then perform compute while the memory is idle. What we would like to do is to \"pre-fetch\" the input value that is required on the next iteration of the loop asynchronously while performing the computation for the current loop, so that compute and memory communication are happening simultaneously.\n", + "\n", + "In order to reason about the code transformation we will make, lets unroll the loop for N=4, and decompose the copy instructions into separate `copy_start` and `copy_wait` operations to be able to express asynchrony:\n", + "
\n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[0], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[0])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 2\n",
+    "  copy_in_start(A[1], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[1])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 3\n",
+    "  copy_in_start(A[2], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[2])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 4\n",
+    "  copy_in_start(A[3], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[3])\n",
+    "  copy_out_wait(Y)\n",
+    "
\n", + "\n", + "Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible, and `copy_wait` values as late as possible (right before we need the value). However, in the current state of the loop there is a fake data dependency through X - we cannot simultaneously perform an async copy into X while using it for computation or else we may have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input X and each output Y. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers you can push 2 iterations, and so on) and we rewrite our loop as follows:\n", + "
\n",
+    "  # Prologue\n",
+    "  copy_in_start(A[0], X[0])\n",
+    "  \n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[1], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[0])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 2 - Steady state\n",
+    "  copy_in_start(A[2], X[0])\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[1])\n",
+    "  copy_out_wait(Y[1])\n",
+    "\n",
+    "  # Itr 3 - Steady state\n",
+    "  copy_in_start(A[3], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[2])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 4 - No copy-in\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[2])\n",
+    "  copy_out_wait(Y[1])\n",
+    "
\n", + "\n", + "Next, we can push the `copy_out_wait` as late as possible, right before we need to write into Y on the subsequent loop iteration.\n", + "\n", + "
\n",
+    "  # Prologue\n",
+    "  copy_in_start(A[0], X[0])\n",
+    "  \n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[1], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[0])\n",
+    "\n",
+    "  # Itr 2 - Steady state\n",
+    "  copy_in_start(A[2], X[0])\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[1])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 3 - Steady state\n",
+    "  copy_in_start(A[3], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[2])\n",
+    "  copy_out_wait(Y[1])\n",
+    "\n",
+    "  # Itr 4 - No copy-in\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[2])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Epilogue\n",
+    "  copy_out_wait(Y[1])\n",
+    "
\n", + "\n", + "Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop:\n", + "\n", + "```\n", + "# Prologue\n", + "copy_in_start(A[0], X[0])\n", + "\n", + "# Main loop\n", + "for i in range(N):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + "\n", + " if i < N:\n", + " copy_in_start(A[i+1], X[next_slot])\n", + " \n", + " copy_in_wait(X[cur_slot])\n", + " Y[cur_slot] = X[cur_slot] + 1\n", + " copy_out_start(Y[cur_slot], A[i])\n", + "\n", + " if i > 0:\n", + " copy_out_wait(Y[next_slot])\n", + "\n", + "# Epilogue\n", + "copy_out_wait(Y[1])\n", + "```\n", + "\n", + "If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline:\n", + "\n", + "- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`.\n", + "- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`.\n", + "- The **data_slices**, which map a subproblem to corresponding slices into the HBM buffer. In our example the data slice was the identity function `lambda i: i`.\n", + "\n", + "By allowing the user to specify these pieces of information we can write a wide variety of programs following this pattern:\n", + "```python\n", + "def double_buffered_pipeline(\n", + " grid: tuple[int, ...],\n", + " kernel: Callable,\n", + " in_slices: Callable,\n", + " out_slices: Callable):\n", + " # Prologue\n", + " copy_in_start(in_hbm[in_slices(0)], in_sram[0])\n", + "\n", + " # Main loop\n", + " grid_size = prod(grid)\n", + " for i in range(grid_size):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + " if i < grid_size:\n", + " copy_in_start(in_hbm[data_slices(i+1)], in_sram[next_slot])\n", + " copy_in_wait(in_sram[cur_slot])\n", + "\n", + " kernel(inputs, outputs)\n", + "\n", + " copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])\n", + " if i > 0:\n", + " copy_out_wait(out_sram[next_slot])\n", + "\n", + " # Epilogue\n", + " copy_out_wait(out_sram[1])\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ziBuvv8jDgxo" + }, + "source": [ + "Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "niMr39cPkJ2m" + }, + "source": [ + "## Pallas Pipelining API\n", + "\n", + "Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in the [quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html), so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.\n", + "\n", + "\n", + "### Grid\n", + "\n", + "The program **grid** is a tuple of integers specifying the number of subproblems as an array. The structure of the pipeline can be interpreted as a nested for-loop where the bounds of each loop.\n", + "\n", + "```\n", + "# For grid (N, M, K)\n", + "for n in range (N):\n", + " for m in range(M):\n", + " for k in range(K):\n", + " kernel()\n", + "```\n", + "\n", + "The kernel will be invoked a total of `prod(grid)` times. 
For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n", + "\n", + "### BlockSpecs\n", + "\n", + "A BlockSpec specifies the size and slice of data copied to the kernel on each subproblem. The basic constructor to `pl.BlockSpec` involves specifying the `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape as `block_shape`. The `memory_space` argument specifies what memory space to copy the inputs to - be default this will be SRAM.\n", + "\n", + "```python\n", + "pl.BlockSpec(\n", + " block_shape: tuple[int, ...],\n", + " index_map: Callable,\n", + " memory_space: pl.MemorySpace\n", + ")\n", + "```\n", + "There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n", + "\n", + "### Kernel\n", + "\n", + "The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs, and instead all outputs should be written into the output buffers that are passed into the kernel. All inputs and output buffers are SRAM buffers by default (unless the user has overridden the behavior by specifying a `memory_space` on the corresponding `BlockSpec`).\n", + "\n", + "```python\n", + "def kernel(*input_buffers, *output_buffers):\n", + " # ... perform compute\n", + " # ... store result into output buffers\n", + "```\n", + "\n", + "The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`.\n", + "\n", + "\n", + "### Pallas Call\n", + "\n", + "The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature:\n", + "```python\n", + "def pallas_call(\n", + " kernel,\n", + " grid: tuple[int, ...],\n", + " in_specs: Sequence[PyTree[BlockSpec]],\n", + " out_specs: PyTree[BlockSpec],\n", + " out_shape: PyTree[jax.ShapeDtypeStruct],\n", + ") -> Callable:\n", + "```\n", + "`pallas_call` will return a callable function that when invoked with input values, will return outputs of the same shape as `out_shape`.\n", + "\n", + "`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTrees for `in_specs` and the input buffers supplied to the kernel should match, and the PyTrees for `out_specs` and `out_shape` should also match.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0mHZ63eAq_8j" + }, + "source": [ + "### Example - Elementwise Kernel revisited\n", + "\n", + "Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, except using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iqr_qjONAHN9" + }, + "outputs": [], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "total_shape = (4096, 4096)\n", + "block_shape = (512, 512)\n", + "\n", + "def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):\n", + " o_ref[...] = x_ref[...] + y_ref[...]\n", + "\n", + "def add_matrices_pipelined(x: jax.Array, y: jax.Array):\n", + " return pl.pallas_call(\n", + " add_matrices_pipelined_kernel,\n", + " grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),\n", + " in_specs=[\n", + " pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n", + " pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))\n", + " ],\n", + " out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n", + " out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),\n", + " )(x, y)\n", + "\n", + "x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)\n", + "y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)\n", + "result = add_matrices_pipelined(x, y)\n", + "np.testing.assert_array_equal(\n", + " result, x + y\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UWHD0_qm6DL7" + }, + "source": [ + "It turns out that with this API, writing a pipelined kernel is not much more lines of code than writing our original naive addition kernel!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BZ-4U6Cv6cvU" + }, + "source": [ + "### Parameterizing a Kernel\n", + "\n", + "It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). 
Let's write a a function that does so:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RZTAiwrZ6srD" + }, + "outputs": [], + "source": [ + "def add_matrices_pipelined_param(\n", + " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", + ") -> jax.Array:\n", + " m, n = x.shape\n", + " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", + " return pl.pallas_call(\n", + " add_matrices_kernel,\n", + " out_shape=x,\n", + " in_specs=[block_spec, block_spec],\n", + " out_specs=block_spec,\n", + " grid=(m // bm, n // bn),\n", + " )(x, y)\n", + "\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y\n", + ")\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y\n", + ")\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vO8VkbYj_ral" + }, + "source": [ + "## Sharp edges\n", + "\n", + "While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges that arise from the use of intermediate buffers that are not fully hidden from the user and can result in subtle bugs.\n", + "\n", + "### Buffer Revisiting\n", + "\n", + "In general, a good rule-of-thumb to follow is that **the input buffers passed into the kernel function should be interpreted as read-only, and output buffers are write only**.\n", + "\n", + "Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is updated, it's updated value is never read into SRAM. This issue is analogous to staleness issues encountered when using caches in general.\n", + "\n", + "There are two cases where a buffer supports both reads and writes - accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing in the `input_output_aliases` argument to `pallas_call`.\n", + "\n", + "\n", + "### Reductions and accumulation\n", + "\n", + "**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.**\n", + "\n", + "Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.\n", + "The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the grid index changes, the final accumulated SRAM buffer will be written out to HBM. 
This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.\n", + "\n", + "As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axies into a `(1024, 1024)` array.\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 244, + "status": "ok", + "timestamp": 1744763773938, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "4qz1ET-_f9fJ", + "outputId": "e43067ef-933a-45a5-912a-e224151cfa60" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " ...,\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = jnp.ones((8, 1024, 1024))\n", + "jnp.sum(x, axis=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yX762DRrgCOG" + }, + "source": [ + "To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration i load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 79, + "status": "ok", + "timestamp": 1744763774254, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "ZEi1_vQVf-81", + "outputId": "581744b7-ddc1-4dc1-98ec-03c852772eda" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[65. 65. 65. ... 66. 66. 66.]\n", + " [65. 65. 65. ... 66. 66. 66.]\n", + " [65. 65. 65. ... 66. 66. 66.]\n", + " ...\n", + " [71. 71. 71. ... 72. 72. 72.]\n", + " [71. 71. 71. ... 72. 72. 72.]\n", + " [71. 71. 71. ... 72. 72. 72.]]\n" + ] + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "# Warning: this implementation is incorrect!\n", + "def incorrect_sum_kernel(x_ref, o_ref):\n", + " o_ref[...] += x_ref[...]\n", + "\n", + "def incorrect_sum(x: jax.Array,\n", + " block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n", + " reduction_size, *out_shape = x.shape\n", + " grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))\n", + " return pl.pallas_call(\n", + " incorrect_sum_kernel,\n", + " grid=grid,\n", + " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", + " in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],\n", + " out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),\n", + " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", + " )(x)\n", + "\n", + "result = incorrect_sum(x)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MglScPDD9618" + }, + "source": [ + "This result is completely wrong!\n", + "\n", + "There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. 
Second, `o_ref` is initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation.\n", + "\n", + "After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 104, + "status": "ok", + "timestamp": 1744763774523, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "XtgD4nMa9_Bd", + "outputId": "9ef07cdf-9e22-4dc8-c17f-c96172639801" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " ...\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]]\n" + ] + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "def correct_sum_kernel(x_ref, o_ref):\n", + " @pl.when(pl.program_id(2) == 0)\n", + " def _():\n", + " o_ref[...] = jnp.zeros_like(o_ref)\n", + " o_ref[...] += x_ref[...]\n", + "\n", + "def correct_sum(x: jax.Array,\n", + " block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n", + " reduction_size, *out_shape = x.shape\n", + " # We moved the reduction to the last axis of the grid.\n", + " grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)\n", + " return pl.pallas_call(\n", + " correct_sum_kernel,\n", + " grid=grid,\n", + " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", + " in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],\n", + " out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),\n", + " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", + " )(x)\n", + "\n", + "result = correct_sum(x)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BckuFg6qcnVw" + }, + "source": [ + "\n", + "## Analyzing the performance\n", + "\n", + "What is the performance of a pipelined kernel? This question can vary depending on where the bottleneck is the hardware is. We are typically interested in 3 quantities:\n", + "- **Memory latency** $α$, the minimum latency of a memory transfer.\n", + "- **Memory bandwidth** $β$, the rate in bytes/second that we can transfer from HBM to SRAM.\n", + "- **FLOP/s** $F$, or floating-point-operations per second, the number of calculations per second that the processor can perform.\n", + "\n", + "We refer to a program as **compute-bound** if the processing speed FLOPs/s is the bottleneck, and as **memory-bound** if the bandwidth or latency are the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware.\n", + "\n", + "Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, they will both scale equally. 
However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically.\n", + "\n", + "In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\\alpha + X/\\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady-state of the pipeline. Assuming that N is large and there is enough work to produce a long pipeline, the dominating term in the runtime is $F$, the processing speed of the accelerator.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NDY4mcae_nMO" + }, + "source": [ + "\n", + "![pipelining_compute](../_static/pallas/pipelining_compute_bound.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HFWcaAudW4z1" + }, + "source": [ + "In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\\alpha + X / \\beta$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\\beta$ is orders of magnitude slower than the processing speed $F$." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gqcCDsGg_sca" + }, + "source": [ + "\n", + "![pipelining_bandwidth](../_static/pallas/pipelining_bandwidth_bound.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V4YQCZf1W7X5" + }, + "source": [ + "If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or latency bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Sj5PFl0s_yc6" + }, + "source": [ + "\n", + "![pipelining_latency](../_static/pallas/pipelining_latency_multistage.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ar4NVxxFfKEb" + }, + "source": [ + "Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency. On GPU, the number of pipeline stages can be specified in both the Triton (via `TritonCompilerParams`) and Mosaic GPU backends (via argument to the pipeline emitter). See the platform-specific pipelining documentation for more details." 
+ ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab", + "kind": "private" + }, + "provenance": [] + }, + "jupytext": { + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md new file mode 100644 index 000000000000..42b91e368238 --- /dev/null +++ b/docs/pallas/pipelining.md @@ -0,0 +1,595 @@ +--- +jupyter: + jupytext: + main_language: python + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.16.4 + kernelspec: + display_name: Python 3 + name: python3 +--- + + +# Software Pipelining + +Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API. + +This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see the [TPU](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html), or GPU (coming soon!) specific pipelining references. + + + +```python id="YkOjspo5BKPD" +import jax +from jax import numpy as jnp +from jax.experimental import pallas as pl +import numpy as np +``` + + +## Memory Hierarchies + +The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that tradeoff capicity vs latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication: +- **Registers** are the the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them. +- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers. +SRAM on modern ML accelerators typically range in the 10-100MB range (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2). +It's reasonable to expect the latency to access SRAM to be on the order of 10x longer than accessing a register. +- **DRAM** (also known as HBM) has much higher capacity than SRAM, typically in the 10-100GB range for modern ML accelerators. However, the latency is roughly on the order of 10x longer to access compared to SRAM. +- **Network** communication becomes crucial for larger workloads when the size of DRAM on a single device becomes insufficient or when we'd like to take advantage of parallel computations. We do not cover distributed pipelining in this tutorial, but see the [distributed TPU kernels](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) guide for writing pipelines across multiple devices. 
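+
+To make those capacity numbers concrete, here is a quick back-of-the-envelope
+sketch (plain NumPy arithmetic; the SRAM budget implied in the comments is just
+a round number in line with the ranges above):
+
+```python
+import numpy as np
+
+def size_mib(shape, dtype=np.float32):
+  # Size in MiB of an array with the given shape and dtype.
+  return np.prod(shape) * np.dtype(dtype).itemsize / 2**20
+
+print(size_mib((2048, 2048)))    # 16.0   -- small enough to fit in SRAM
+print(size_mib((32768, 32768)))  # 4096.0 -- HBM-sized; must be processed in blocks
+```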
+ + + + +![memory_hierarchy](../_static/pallas/pipelining_mem_hierarchy.svg) + + + + + + +In order to perform computation on values X and Y that live in HBM, we need to: + +1. Copy the values x and y into SRAM. +2. Load the values from SRAM into registers. +3. Execute the computation and store the result into registers. +4. Store the values in the output registers into SRAM. +5. Copy the output values in SRAM back to HBM. + +Let’s implement a Pallas function that does just that! + + +```python id="IrPhDFnT3Nvw" executionInfo={"status": "ok", "timestamp": 1744764235906, "user_tz": 420, "elapsed": 108, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="8bc03872-fd9f-4610-9d53-d4b46be560f4" +# Note: This is a TPU example. + +def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref): + # Load x and y from SRAM into registers + x_regs = x_sram_ref[:, :] + y_regs = y_sram_ref[:, :] + # Execute a vectorized add + z_regs = x_regs + y_regs + # Store the output values in registers back into SRAM + z_sram_ref[:, :] = z_regs + + +def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array: + # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM. + # It will then copy `x` and `y` from HBM into SRAM. + z = pl.pallas_call( + add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) + )(x, y) + # pallas_call will also copy the output from SRAM back into HBM. + return z + + +x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) +add_matrices(x, y) +``` + + +We've written two functions: `add_matrices_kernel` and `add_matrices`. + +`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from a SRAM Ref produces a value that lives in registers. Values in registers behave like jax.Arrays in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`. + +The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into pallas_call. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output SRAM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. + +Pallas exposes access to lower level memory spaces like SRAM but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both: + +- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays. + +- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself. + +With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators. + + + + + +## Pipelining Basics + + +How can we take advantage of the strengths of each form of type memory in the hierarchy, and be able to operate on large arrays stored in HBM while still utilizing fast SRAM for compute? 
Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel. + +The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation is can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages): + +1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`. +2. **compute**: Load `X` into registers, compute a result, and store in SRAM `Y` +3. **copy_out**: Copy result `Y` back into HBM `A[i]`. + +Note that there is a data-dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`. + + + + + + + +![pipelining_example](../_static/pallas/pipelining_example.svg) + + + + +The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally "hide" the cost of transferring between HBM/SRAM with computation and keep the processor busy with as much uptime as possible. + +The initial startup time and final teardown time known as "bubbles", where only a subset of the stages are being executed while the pipeline is being "filled" or "drained". The bulk of the time is spent in the "steady-state" phase of the pipeline, where each pipeline stage is being executed in parallel across different iterations of the subproblem. While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor. + + + +### Deriving a Double-Buffered Pipeline + +Now lets look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the result, and store the result back to HBM with `copy_out`: + +
+for i in range(N):
+  copy_in(A[i], X)
+  Y = X + 1
+  copy_out(Y, A[i])
+```
+
+The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations. We are forced to wait for the copies to finish while the GPU/TPU sits idle, and then to perform the compute while the memory is idle. What we would like to do is to "pre-fetch" the input value that is required on the next iteration of the loop asynchronously while performing the computation for the current iteration, so that compute and memory communication happen simultaneously.
+
+In order to reason about the code transformation we will make, let's unroll the loop for `N=4`, and decompose the copy instructions into separate `copy_start` and `copy_wait` operations to be able to express asynchrony:
+
+```
+  # Itr 1
+  copy_in_start(A[0], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[0])
+  copy_out_wait(Y)
+
+  # Itr 2
+  copy_in_start(A[1], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[1])
+  copy_out_wait(Y)
+
+  # Itr 3
+  copy_in_start(A[2], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[2])
+  copy_out_wait(Y)
+
+  # Itr 4
+  copy_in_start(A[3], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[3])
+  copy_out_wait(Y)
+```
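+
+(As an aside, these `copy_start`/`copy_wait` pseudo-instructions have real counterparts: on the TPU backend, Pallas exposes asynchronous copies directly. The snippet below is only a rough sketch of that correspondence; it assumes a kernel that already has an HBM-resident input `Ref`, an SRAM scratch `Ref`, and a DMA semaphore available, none of which are part of the running example in this guide.)
+
+```python
+from jax.experimental.pallas import tpu as pltpu
+
+
+def copy_in_block(a_hbm_ref, x_sram_ref, sem, i):
+  # copy_in_start(A[i], X): kick off an asynchronous HBM -> SRAM copy.
+  copy = pltpu.make_async_copy(a_hbm_ref.at[i], x_sram_ref, sem)
+  copy.start()
+  # ... unrelated work could overlap with the copy here ...
+  # copy_in_wait(X): block until the data has landed in SRAM.
+  copy.wait()
+```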
+
+Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible and `copy_wait` instructions as late as possible (right before we need the value). However, in the current state of the loop there is a false data dependency through `X` - we cannot simultaneously perform an async copy into `X` while using it for computation, or else we may have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input `X` and each output `Y`. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers you can push 2 iterations ahead, and so on), and we rewrite our loop as follows:
+
+```
+  # Prologue
+  copy_in_start(A[0], X[0])
+  
+  # Itr 1
+  copy_in_start(A[1], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[0])
+  copy_out_wait(Y[0])
+
+  # Itr 2 - Steady state
+  copy_in_start(A[2], X[0])
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[1])
+  copy_out_wait(Y[1])
+
+  # Itr 3 - Steady state
+  copy_in_start(A[3], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[2])
+  copy_out_wait(Y[0])
+
+  # Itr 4 - No copy-in
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[3])
+  copy_out_wait(Y[1])
+```
+
+Next, we can push the `copy_out_wait` as late as possible, right before we need to write into `Y` again on the subsequent loop iteration.
+
+```
+  # Prologue
+  copy_in_start(A[0], X[0])
+  
+  # Itr 1
+  copy_in_start(A[1], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[0])
+
+  # Itr 2 - Steady state
+  copy_in_start(A[2], X[0])
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[1])
+  copy_out_wait(Y[0])
+
+  # Itr 3 - Steady state
+  copy_in_start(A[3], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[2])
+  copy_out_wait(Y[1])
+
+  # Itr 4 - No copy-in
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[3])
+  copy_out_wait(Y[0])
+
+  # Epilogue
+  copy_out_wait(Y[1])
+```
+
+
+Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop:
+
+```
+# Prologue
+copy_in_start(A[0], X[0])
+
+# Main loop
+for i in range(N):
+  cur_slot = i % 2
+  next_slot = (i + 1) % 2
+
+  if i < N - 1:
+    copy_in_start(A[i+1], X[next_slot])
+
+  copy_in_wait(X[cur_slot])
+  Y[cur_slot] = X[cur_slot] + 1
+  copy_out_start(Y[cur_slot], A[i])
+
+  if i > 0:
+    copy_out_wait(Y[next_slot])
+
+# Epilogue
+copy_out_wait(Y[(N - 1) % 2])
+```
+
+If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline:
+
+- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`.
+- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`.
+- The **data slices** (`in_slices` and `out_slices` below), which map a subproblem to corresponding slices into the HBM buffers. In our example the data slice was the identity function `lambda i: i`.
+
+By allowing the user to specify these pieces of information we can write a wide variety of programs following this pattern:
+```python
+def double_buffered_pipeline(
+    grid: tuple[int, ...],
+    kernel: Callable,
+    in_slices: Callable,
+    out_slices: Callable):
+  # Prologue
+  copy_in_start(in_hbm[in_slices(0)], in_sram[0])
+
+  # Main loop
+  grid_size = prod(grid)
+  for i in range(grid_size):
+    cur_slot = i % 2
+    next_slot = (i + 1) % 2
+    if i < grid_size - 1:
+      copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot])
+    copy_in_wait(in_sram[cur_slot])
+
+    kernel(in_sram[cur_slot], out_sram[cur_slot])
+
+    copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])
+    if i > 0:
+      copy_out_wait(out_sram[next_slot])
+
+  # Epilogue
+  copy_out_wait(out_sram[(grid_size - 1) % 2])
+```
+
+
+
+Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API.
+
+
+
+## Pallas Pipelining API
+
+Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in the [quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html), so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.
+
+
+### Grid
+
+The program **grid** is a tuple of integers specifying the number of subproblems along each grid dimension. The structure of the pipeline can be interpreted as a nested for-loop where the bounds of each loop are given by the corresponding entry of the grid.
+
+```
+# For grid (N, M, K)
+for n in range(N):
+  for m in range(M):
+    for k in range(K):
+      kernel()
+```
+
+The kernel will be invoked a total of `prod(grid)` times. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).
+
+### BlockSpecs
+
+A BlockSpec specifies the size and slice of data copied to the kernel on each subproblem. The basic constructor to `pl.BlockSpec` involves specifying `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape `block_shape`. The `memory_space` argument specifies which memory space to copy the inputs to; by default this will be SRAM.
+ +```python +pl.BlockSpec( + block_shape: tuple[int, ...], + index_map: Callable, + memory_space: pl.MemorySpace +) +``` +There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop). + +### Kernel + +The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs, and instead all outputs should be written into the output buffers that are passed into the kernel. All inputs and output buffers are SRAM buffers by default (unless the user has overridden the behavior by specifying a `memory_space` on the corresponding `BlockSpec`). + +```python +def kernel(*input_buffers, *output_buffers): + # ... perform compute + # ... store result into output buffers +``` + +The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`. + + +### Pallas Call + +The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature: +```python +def pallas_call( + kernel, + grid: tuple[int, ...], + in_specs: Sequence[PyTree[BlockSpec]], + out_specs: PyTree[BlockSpec], + out_shape: PyTree[jax.ShapeDtypeStruct], +) -> Callable: +``` +`pallas_call` will return a callable function that when invoked with input values, will return outputs of the same shape as `out_shape`. + +`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTrees for `in_specs` and the input buffers supplied to the kernel should match, and the PyTrees for `out_specs` and `out_shape` should also match. + + + + +### Example - Elementwise Kernel revisited + +Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, except using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration. + + +```python id="iqr_qjONAHN9" +# Note: This is a TPU example. + +total_shape = (4096, 4096) +block_shape = (512, 512) + +def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + +def add_matrices_pipelined(x: jax.Array, y: jax.Array): + return pl.pallas_call( + add_matrices_pipelined_kernel, + grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)), + in_specs=[ + pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)), + pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)) + ], + out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)), + out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32), + )(x, y) + +x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32) +y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32) +result = add_matrices_pipelined(x, y) +np.testing.assert_array_equal( + result, x + y +) +``` + + +It turns out that with this API, writing a pipelined kernel is not much more lines of code than writing our original naive addition kernel! + + + +### Parameterizing a Kernel + +It's common to parameterize the block shapes in our kernel. 
Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). Let's write a function that does so:
+
+
+```python id="RZTAiwrZ6srD"
+def add_matrices_pipelined_param(
+    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
+) -> jax.Array:
+  m, n = x.shape
+  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
+  return pl.pallas_call(
+      add_matrices_kernel,
+      out_shape=x,
+      in_specs=[block_spec, block_spec],
+      out_specs=block_spec,
+      grid=(m // bm, n // bn),
+  )(x, y)
+
+np.testing.assert_array_equal(
+    add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y
+)
+np.testing.assert_array_equal(
+    add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y
+)
+np.testing.assert_array_equal(
+    add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y
+)
+```
+
+
+## Sharp edges
+
+While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges that arise from the use of intermediate buffers which are not fully hidden from the user and can result in subtle bugs.
+
+### Buffer Revisiting
+
+In general, a good rule-of-thumb to follow is that **the input buffers passed into the kernel function should be interpreted as read-only, and output buffers as write-only**.
+
+Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is updated, its updated value is never read into SRAM. This issue is analogous to staleness issues encountered when using caches in general.
+
+There are two cases where a buffer supports both reads and writes - accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing in the `input_output_aliases` argument to `pallas_call`.
+
+
+### Reductions and accumulation
+
+**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.**
+
+Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.
+The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the grid index changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.
+
+As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axis into a `(1024, 1024)` array.
+
+
+
+
+
+
+
+
+```python id="4qz1ET-_f9fJ" executionInfo={"status": "ok", "timestamp": 1744763773938, "user_tz": 420, "elapsed": 244, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="e43067ef-933a-45a5-912a-e224151cfa60"
+x = jnp.ones((8, 1024, 1024))
+jnp.sum(x, axis=0)
+```
+
+
+To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration `i` load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first.
+
+
+```python id="ZEi1_vQVf-81" executionInfo={"status": "ok", "timestamp": 1744763774254, "user_tz": 420, "elapsed": 79, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="581744b7-ddc1-4dc1-98ec-03c852772eda"
+# Note: This is a TPU example.
+
+# Warning: this implementation is incorrect!
+def incorrect_sum_kernel(x_ref, o_ref):
+  o_ref[...] += x_ref[...]
+
+def incorrect_sum(x: jax.Array,
+                  block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
+  reduction_size, *out_shape = x.shape
+  grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))
+  return pl.pallas_call(
+      incorrect_sum_kernel,
+      grid=grid,
+      # None in `block_shape` means we pick a size of 1 and squeeze it away
+      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],
+      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),
+      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
+  )(x)
+
+result = incorrect_sum(x)
+print(result)
+```
+
+
+This result is completely wrong!
+
+There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation.
+
+After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`.
+
+
+```python id="XtgD4nMa9_Bd" executionInfo={"status": "ok", "timestamp": 1744763774523, "user_tz": 420, "elapsed": 104, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="9ef07cdf-9e22-4dc8-c17f-c96172639801"
+# Note: This is a TPU example.
+
+def correct_sum_kernel(x_ref, o_ref):
+  @pl.when(pl.program_id(2) == 0)
+  def _():
+    o_ref[...] = jnp.zeros_like(o_ref)
+  o_ref[...] += x_ref[...]
+
+def correct_sum(x: jax.Array,
+                block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
+  reduction_size, *out_shape = x.shape
+  # We moved the reduction to the last axis of the grid.
+  grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)
+  return pl.pallas_call(
+      correct_sum_kernel,
+      grid=grid,
+      # None in `block_shape` means we pick a size of 1 and squeeze it away
+      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],
+      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),
+      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
+  )(x)
+
+result = correct_sum(x)
+print(result)
+```
+
+
+
+## Analyzing the performance
+
+What is the performance of a pipelined kernel? The answer can vary depending on where the bottleneck in the hardware is. We are typically interested in 3 quantities:
+- **Memory latency** $\alpha$, the minimum latency of a memory transfer.
+- **Memory bandwidth** $\beta$, the rate in bytes/second that we can transfer from HBM to SRAM.
+- **FLOP/s** $F$, or floating-point operations per second, the number of calculations per second that the processor can perform.
+
+We refer to a program as **compute-bound** if the processing speed in FLOP/s is the bottleneck, and as **memory-bound** if the bandwidth or latency is the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware.
+
+Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, they will both scale equally. However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically.
+
+In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\alpha + X/\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady-state of the pipeline. Assuming that $N$ is large and there is enough work to produce a long pipeline, the runtime is dominated by the $N(Y/F)$ term, which is governed by $F$, the processing speed of the accelerator.
+
+
+
+
+
+
+![pipelining_compute](../_static/pallas/pipelining_compute_bound.svg)
+
+
+
+
+In a **memory-bound** regime it is useful to identify whether the problem is the latency or the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take roughly $\alpha + N X / \beta$ seconds. In contrast with a latency-bound regime, the memory copies effectively happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\beta$ is orders of magnitude slower than the processing speed $F$.
+
+
+
+
+![pipelining_bandwidth](../_static/pallas/pipelining_bandwidth_bound.svg)
+
+
+
+
+If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages, at the cost of the additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute-bound or bandwidth-bound again, depending on which bottleneck we hit first during the steady-state phase of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages, so it is important to make sure the pipeline is long enough that the bubble does not take up a substantial amount of the total runtime.
+
+
+
+
+
+![pipelining_latency](../_static/pallas/pipelining_latency_multistage.svg)
+
+
+
+
+Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency. On GPU, the number of pipeline stages can be specified in both the Triton (via `TritonCompilerParams`) and Mosaic GPU backends (via an argument to the pipeline emitter). See the platform-specific pipelining documentation for more details.
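+
+To make this analysis concrete, the back-of-the-envelope sketch below plugs the elementwise-addition example from earlier in this guide into the model above. The hardware numbers ($\alpha$, $\beta$, $F$) are made-up placeholders chosen purely for illustration, not the specifications of any particular accelerator.
+
+```python
+# Rough roofline-style estimate for one (512, 512) float32 addition block.
+# All hardware constants below are assumed values for illustration only.
+block_elems = 512 * 512
+bytes_per_elem = 4                     # float32
+X = 3 * block_elems * bytes_per_elem   # bytes moved per iteration (2 inputs + 1 output)
+Y = block_elems                        # FLOPs per iteration (one add per element)
+
+alpha = 1e-6   # assumed copy latency, seconds
+beta = 1e12    # assumed HBM <-> SRAM bandwidth, bytes/second
+F = 1e14       # assumed peak compute, FLOP/s
+N = 64         # number of pipeline iterations
+
+copy_per_iter = X / beta
+compute_per_iter = Y / F
+
+# In steady state, the slower of the two overlapped stages sets the pace.
+steady_state = N * max(copy_per_iter, compute_per_iter)
+total = alpha + steady_state
+
+print(f"copy/iter:    {copy_per_iter:.2e} s")
+print(f"compute/iter: {compute_per_iter:.2e} s")
+print(f"bound by:     {'memory' if copy_per_iter > compute_per_iter else 'compute'}")
+print(f"total:        {total:.2e} s")
+```
+
+With these (assumed) numbers the copy term dominates the compute term by a wide margin, which is why an elementwise kernel like `add_matrices` is expected to remain memory-bound no matter how well it is pipelined.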
+ diff --git a/docs/quickstart.md b/docs/quickstart.md index 77cbb9d46ab8..d2d9bf8cec41 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -58,7 +58,7 @@ print(selu(x)) ``` You'll find a few differences between JAX arrays and NumPy arrays once you begin digging-in; -these are explored in [🔪 JAX - The Sharp Bits 🔪](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +these are explored in [🔪 JAX - The Sharp Bits 🔪](https:docs.jax.devio/en/latest/notebooks/Common_Gotchas_in_JAX.html). ## Just-in-time compilation with {func}`jax.jit` JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the {func}`jax.jit` function to compile this sequence of operations together using XLA. diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 00f77e3473bb..134b690839e0 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -152,7 +152,7 @@ print(random.normal(key)) Re-using the same key, even with different {mod}`~jax.random` APIs, can result in correlated outputs, which is generally undesirable. -**The rule of thumb is: never reuse keys (unless you want identical outputs).** +**The rule of thumb is: never reuse keys (unless you want identical outputs). Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__.** JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation. In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index d3ddac4edbdb..568a0d4c6e3d 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -13,21 +13,44 @@ "\n", "The tutorial covers three modes of parallel computation:\n", "\n", - "- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n", - "- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`\n", - "- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", + "- _Automatic sharding via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n", + "- *Explicit Sharding* (\\*new\\*) is similar to automatic sharding in that\n", + " you're writing a global-view program. The difference is that the sharding\n", + " of each array is part of the array's JAX-level type making it an explicit\n", + " part of the programming model. These shardings are propagated at the JAX\n", + " level and queryable at trace time. 
It's still the compiler's responsibility\n", + " to turn the whole-array program into per-device programs (turning `jnp.sum`\n", + " into `psum` for example) but the compiler is heavily constrained by the\n", + " user-supplied shardings.\n", + "- _Fully manual sharding with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", "\n", - "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.\n", + "A summary table:\n", "\n", - "If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with)." + "| Mode | Explicit sharding? | Explicit Collectives? |\n", + "|---|---|---|\n", + "| Auto | No | No |\n", + "| Explicit (new) | Yes | No |\n", + "| Manual | Yes | Yes |\n", + "\n", + "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7efa1e66", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "\n", + "jax.config.update('jax_num_cpu_devices', 8)" ] }, { "cell_type": "code", "execution_count": 1, - "metadata": { - "outputId": "18905ae4-7b5e-4bb9-acb4-d8ab914cb456" - }, + "metadata": {}, "outputs": [ { "data": { @@ -48,7 +71,6 @@ } ], "source": [ - "import jax\n", "jax.devices()" ] }, @@ -84,7 +106,9 @@ } ], "source": [ + "import numpy as np\n", "import jax.numpy as jnp\n", + "\n", "arr = jnp.arange(32.0).reshape(4, 8)\n", "arr.devices()" ] @@ -264,51 +288,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UEObolTqw4pp" - }, - "source": [ - "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", - "\n", - "The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n", - "\n", - "To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aKNeOHTJnqmS", - "outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "pinned_host\n", - "device\n" - ] - } - ], - "source": [ - "s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n", - "s_dev = s_host.with_memory_kind('device')\n", - "arr_host = jax.device_put(arr, s_host)\n", - "arr_dev = jax.device_put(arr, s_dev)\n", - "print(arr_host.sharding.memory_kind)\n", - "print(arr_dev.sharding.memory_kind)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jDHYnVqHwaST" - }, + "metadata": {}, "source": [ "## 1. 
Automatic parallelism via `jit`\n", "\n", @@ -402,146 +382,157 @@ "source": [ "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n", "\n", - "### 1.1 Sharding transformation between memory types\n", - "\n", - "The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n", + "## 2. Explicit sharding\n", "\n", - "#### Example 1: Pinned host to device memory\n", - "\n", - "In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory." + "The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that\n", + "the JAX-level _type_ of a value includes a description of how the value is sharded.\n", + "We can query the JAX-level type of any JAX value (or Numpy array, or Python\n", + "scalar) using `jax.typeof`:" ] }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "PXu3MhafyRHo", - "outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b" - }, + "execution_count": 9, + "metadata": {}, "outputs": [ + { + "data": { + "text/html": [ + "
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  \n",
+       "                                                                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", "text": [ - "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", - " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", - " [16. 17. 18. 19. 20. 21. 22. 23.]\n", - " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", - "device\n" + "[48. 52. 56. 60. 64. 68. 72. 76.]\n" ] } ], "source": [ - "f = jax.jit(lambda x: x, out_shardings=s_dev)\n", - "out_dev = f(arr_host)\n", - "print(out_dev)\n", - "print(out_dev.sharding.memory_kind)" + "some_array = np.arange(8)\n", + "print(f\"JAX-level type of some_array: {jax.typeof(some_array)}\")" ] }, { "cell_type": "markdown", - "metadata": { - "id": "LuYFqpcBySiX" - }, + "metadata": {}, + "source": [ + "Importantly, we can query the type even while tracing under a `jit` (the JAX-level type\n", + "is almost _defined_ as \"the information about a value we have access to while\n", + "under a jit)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffe62839", + "metadata": {}, + "outputs": [], "source": [ - "#### Example 2: Device to pinned_host memory\n", + "@jax.jit\n", + "def foo(x):\n", + " print(f\"JAX-level type of x during tracing: {jax.typeof(x)}\")\n", + " return x + x\n", "\n", - "In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory." + "foo(some_array)" + ] + }, + { + "cell_type": "markdown", + "id": "74995421", + "metadata": {}, + "source": [ + "To start seeing shardings in the type we need to set up an explicit-sharding mesh." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "qLsgNlKfybRw", - "outputId": "a16448b9-7e39-408f-b200-505f65ad4464" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0. 1. 2. 3. 4. 5. 6. 
7.]\n", - " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", - " [16. 17. 18. 19. 20. 21. 22. 23.]\n", - " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", - "pinned_host\n" - ] - } - ], + "id": "e785a694", + "metadata": {}, + "outputs": [], "source": [ - "g = jax.jit(lambda x: x, out_shardings=s_host)\n", - "out_host = g(arr_dev)\n", - "print(out_host)\n", - "print(out_host.sharding.memory_kind)" + "from jax.sharding import AxisType\n", + "\n", + "mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Explicit, AxisType.Explicit))" ] }, { "cell_type": "markdown", - "metadata": { - "id": "7BGD31-owaSU" - }, + "id": "8d81409c", + "metadata": {}, + "source": [ + "Now we can create some sharded arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4969cabd", + "metadata": {}, + "outputs": [], "source": [ - "## 2. Semi-automated sharding with constraints\n", + "replicated_array = np.arange(8).reshape(4, 2)\n", + "sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P(\"X\", None)))\n", "\n", - "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", + "print(f\"replicated_array type: {jax.typeof(replicated_array)}\")\n", + "print(f\"sharded_array type: {jax.typeof(sharded_array)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c09acf7d", + "metadata": {}, + "source": [ + "We should read the type `f32[4@X, 2]` as \"a 4-by-2 array of 32-bit floats whose first dimension\n", + "is sharded along mesh axis 'X'. The array is replicated along all other mesh\n", + "axes\"\n", "\n", - "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:" + "These shardings associated with JAX-level types propagate through operations. For example:" ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "outputId": "8468f5c6-76ca-4367-c9f2-93c723687cfd" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  \n",
-       "                                                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[48. 52. 56. 60. 64. 68. 72. 76.]\n" - ] - } - ], + "execution_count": null, + "id": "ab2f9500", + "metadata": {}, + "outputs": [], "source": [ + "arg0 = jax.device_put(np.arange(4).reshape(4, 1),\n", + " jax.NamedSharding(mesh, P(\"X\", None)))\n", + "arg1 = jax.device_put(np.arange(8).reshape(1, 8),\n", + " jax.NamedSharding(mesh, P(None, \"Y\")))\n", + "\n", "@jax.jit\n", - "def f_contract_2(x):\n", - " out = x.sum(axis=0)\n", - " sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", - " return jax.lax.with_sharding_constraint(out, sharding)\n", + "def add_arrays(x, y):\n", + " ans = x + y\n", + " print(f\"x sharding: {jax.typeof(x)}\")\n", + " print(f\"y sharding: {jax.typeof(y)}\")\n", + " print(f\"ans sharding: {jax.typeof(ans)}\")\n", + " return ans\n", "\n", - "result = f_contract_2(arr_sharded)\n", - "jax.debug.visualize_array_sharding(result)\n", - "print(result)" + "with jax.sharding.use_mesh(mesh):\n", + " add_arrays(arg0, arg1)" ] }, { "cell_type": "markdown", + "id": "dda3d0c5", "metadata": {}, "source": [ - "This gives you a function with the particular output sharding you'd like.\n", + "That's the gist of it. Shardings propagate deterministically at trace time and\n", + "we can query them at trace time.\n", "\n", "## 3. 
Manual parallelism with `shard_map`\n", "\n", @@ -757,7 +748,8 @@ "source": [ "You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.\n", "\n", - "If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:" + "If you shard the leading axis of both `x` and make `weights` fully replicated,\n", + "then the matrix multiplication will automatically happen in parallel:" ] }, { @@ -780,10 +772,8 @@ ], "source": [ "mesh = jax.make_mesh((8,), ('x',))\n", - "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", - "\n", - "x_sharded = jax.device_put(x, sharding)\n", - "weights_sharded = jax.device_put(weights, sharding)\n", + "x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))\n", + "weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))\n", "\n", "layer(x_sharded, weights_sharded, bias)" ] @@ -792,15 +782,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:" + "Alternatively, you can use explicit sharding mode too:" ] }, { "cell_type": "code", "execution_count": 17, - "metadata": { - "outputId": "bb63e8da-ff4f-4e95-f083-10584882daf4" - }, + "metadata": {}, "outputs": [ { "data": { @@ -814,13 +802,22 @@ } ], "source": [ + "explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))\n", + "\n", + "x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))\n", + "weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))\n", + "\n", "@jax.jit\n", "def layer_auto(x, weights, bias):\n", - " x = jax.lax.with_sharding_constraint(x, sharding)\n", - " weights = jax.lax.with_sharding_constraint(weights, sharding)\n", - " return layer(x, weights, bias)\n", + " print(f\"x sharding: {jax.typeof(x)}\")\n", + " print(f\"weights sharding: {jax.typeof(weights)}\")\n", + " print(f\"bias sharding: {jax.typeof(bias)}\")\n", + " out = layer(x, weights, bias)\n", + " print(f\"out sharding: {jax.typeof(out)}\")\n", + " return out\n", "\n", - "layer_auto(x, weights, bias) # pass in unsharded inputs" + "with jax.sharding.use_mesh(explicit_mesh):\n", + " layer_auto(x_sharded, weights_sharded, bias)" ] }, { @@ -871,6 +868,7 @@ "\n", "To learn about each SPMD method in-depth, check out these docs:\n", "- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`\n", + "- {doc}`../notebooks/explicit-sharding`\n", "- {doc}`../notebooks/shard_map`" ] } diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index b05eb8d5f66e..ae9f44aba832 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -20,18 +20,34 @@ This tutorial serves as an introduction to device parallelism for Single-Program The tutorial covers three modes of parallel computation: -- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel"). -- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint` -- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives +- _Automatic sharding via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel"). 
+- *Explicit Sharding* (\*new\*) is similar to automatic sharding in that + you're writing a global-view program. The difference is that the sharding + of each array is part of the array's JAX-level type making it an explicit + part of the programming model. These shardings are propagated at the JAX + level and queryable at trace time. It's still the compiler's responsibility + to turn the whole-array program into per-device programs (turning `jnp.sum` + into `psum` for example) but the compiler is heavily constrained by the + user-supplied shardings. +- _Fully manual sharding with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives + +A summary table: + +| Mode | Explicit sharding? | Explicit Collectives? | +|---|---|---| +| Auto | No | No | +| Explicit (new) | Yes | No | +| Manual | Yes | Yes | Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices. -If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with). - ```{code-cell} -:outputId: 18905ae4-7b5e-4bb9-acb4-d8ab914cb456 - import jax + +jax.config.update('jax_num_cpu_devices', 8) +``` + +```{code-cell} jax.devices() ``` @@ -46,7 +62,9 @@ In the simplest cases, arrays are sharded on a single device, as demonstrated be ```{code-cell} :outputId: 39fdbb79-d5c0-4ea6-8b20-88b2c502a27a +import numpy as np import jax.numpy as jnp + arr = jnp.arange(32.0).reshape(4, 8) arr.devices() ``` @@ -90,31 +108,6 @@ print(arr_sharded) jax.debug.visualize_array_sharding(arr_sharded) ``` -+++ {"id": "UEObolTqw4pp"} - -The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device. - -The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host. - -To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: aKNeOHTJnqmS -outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2 ---- -s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') -s_dev = s_host.with_memory_kind('device') -arr_host = jax.device_put(arr, s_host) -arr_dev = jax.device_put(arr, s_dev) -print(arr_host.sharding.memory_kind) -print(arr_dev.sharding.memory_kind) -``` - -+++ {"id": "jDHYnVqHwaST"} - ## 1. Automatic parallelism via `jit` Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications. @@ -156,69 +149,76 @@ print(result) The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on. 
-### 1.1 Sharding transformation between memory types - -The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array. +## 2. Explicit sharding -#### Example 1: Pinned host to device memory - -In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory. +The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that +the JAX-level _type_ of a value includes a description of how the value is sharded. +We can query the JAX-level type of any JAX value (or Numpy array, or Python +scalar) using `jax.typeof`: ```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: PXu3MhafyRHo -outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b ---- -f = jax.jit(lambda x: x, out_shardings=s_dev) -out_dev = f(arr_host) -print(out_dev) -print(out_dev.sharding.memory_kind) +some_array = np.arange(8) +print(f"JAX-level type of some_array: {jax.typeof(some_array)}") ``` -+++ {"id": "LuYFqpcBySiX"} +Importantly, we can query the type even while tracing under a `jit` (the JAX-level type +is almost _defined_ as "the information about a value we have access to while +under a jit). + +```{code-cell} +@jax.jit +def foo(x): + print(f"JAX-level type of x during tracing: {jax.typeof(x)}") + return x + x -#### Example 2: Device to pinned_host memory +foo(some_array) +``` -In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory. +To start seeing shardings in the type we need to set up an explicit-sharding mesh. ```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: qLsgNlKfybRw -outputId: a16448b9-7e39-408f-b200-505f65ad4464 ---- -g = jax.jit(lambda x: x, out_shardings=s_host) -out_host = g(arr_dev) -print(out_host) -print(out_host.sharding.memory_kind) +from jax.sharding import AxisType + +mesh = jax.make_mesh((2, 4), ("X", "Y"), + axis_types=(AxisType.Explicit, AxisType.Explicit)) ``` -+++ {"id": "7BGD31-owaSU"} +Now we can create some sharded arrays: -## 2. Semi-automated sharding with constraints +```{code-cell} +replicated_array = np.arange(8).reshape(4, 2) +sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P("X", None))) + +print(f"replicated_array type: {jax.typeof(replicated_array)}") +print(f"sharded_array type: {jax.typeof(sharded_array)}") +``` -If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. +We should read the type `f32[4@X, 2]` as "a 4-by-2 array of 32-bit floats whose first dimension +is sharded along mesh axis 'X'. The array is replicated along all other mesh +axes" -For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices: +These shardings associated with JAX-level types propagate through operations. 
For example: ```{code-cell} -:outputId: 8468f5c6-76ca-4367-c9f2-93c723687cfd +arg0 = jax.device_put(np.arange(4).reshape(4, 1), + jax.NamedSharding(mesh, P("X", None))) +arg1 = jax.device_put(np.arange(8).reshape(1, 8), + jax.NamedSharding(mesh, P(None, "Y"))) @jax.jit -def f_contract_2(x): - out = x.sum(axis=0) - sharding = jax.sharding.NamedSharding(mesh, P('x')) - return jax.lax.with_sharding_constraint(out, sharding) - -result = f_contract_2(arr_sharded) -jax.debug.visualize_array_sharding(result) -print(result) +def add_arrays(x, y): + ans = x + y + print(f"x sharding: {jax.typeof(x)}") + print(f"y sharding: {jax.typeof(y)}") + print(f"ans sharding: {jax.typeof(ans)}") + return ans + +with jax.sharding.use_mesh(mesh): + add_arrays(arg0, arg1) ``` -This gives you a function with the particular output sharding you'd like. +That's the gist of it. Shardings propagate deterministically at trace time and +we can query them at trace time. ## 3. Manual parallelism with `shard_map` @@ -320,32 +320,38 @@ layer(x, weights, bias) You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data. -If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel: +If you shard the leading axis of both `x` and make `weights` fully replicated, +then the matrix multiplication will automatically happen in parallel: ```{code-cell} :outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5 mesh = jax.make_mesh((8,), ('x',)) -sharding = jax.sharding.NamedSharding(mesh, P('x')) - -x_sharded = jax.device_put(x, sharding) -weights_sharded = jax.device_put(weights, sharding) +x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x'))) +weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P())) layer(x_sharded, weights_sharded, bias) ``` -Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs: +Alternatively, you can use explicit sharding mode too: ```{code-cell} -:outputId: bb63e8da-ff4f-4e95-f083-10584882daf4 +explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,)) + +x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X'))) +weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P())) @jax.jit def layer_auto(x, weights, bias): - x = jax.lax.with_sharding_constraint(x, sharding) - weights = jax.lax.with_sharding_constraint(weights, sharding) - return layer(x, weights, bias) - -layer_auto(x, weights, bias) # pass in unsharded inputs + print(f"x sharding: {jax.typeof(x)}") + print(f"weights sharding: {jax.typeof(weights)}") + print(f"bias sharding: {jax.typeof(bias)}") + out = layer(x, weights, bias) + print(f"out sharding: {jax.typeof(out)}") + return out + +with jax.sharding.use_mesh(explicit_mesh): + layer_auto(x_sharded, weights_sharded, bias) ``` Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product: @@ -371,4 +377,5 @@ This tutorial serves as a brief introduction of sharded and parallel computation To learn about each SPMD method in-depth, check out these docs: - {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization` +- {doc}`../notebooks/explicit-sharding` - {doc}`../notebooks/shard_map` diff --git a/docs/sphinxext/jax_list_config_options.py b/docs/sphinxext/jax_list_config_options.py new file mode 100644 index 000000000000..54f7f6eebe85 
--- /dev/null +++ b/docs/sphinxext/jax_list_config_options.py @@ -0,0 +1,160 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from operator import itemgetter +from typing import Any, List + +from docutils import nodes +from sphinx.util import logging +from sphinx.util.docutils import SphinxDirective + +logger = logging.getLogger(__name__) + +_deprecations = ( + 'jax_default_dtype_bits', # an experiment that we never documented, but we can't remove it because Keras depends on its existing broken behavior + 'jax_serialization_version' +) + +def create_field_item(label, content): + """Create a field list item with a label and content side by side. + + Args: + label: The label text for the field name + content: The content to add (a node or text) + + Returns: + A field list item with the label and content side by side. + """ + # Create a field list item + field = nodes.field() + + # Create the field name (label) + field_name = nodes.field_name() + field_name += nodes.Text(label) + field += field_name + + # Create the field body (content) + field_body = nodes.field_body() + + if isinstance(content, str): + para = nodes.paragraph() + para += nodes.Text(content) + field_body += para + elif isinstance(content, nodes.Node): + field_body += content + + field += field_body + return field + +class ConfigOptionDirective(SphinxDirective): + required_arguments = 0 + optional_arguments = 0 + has_content = False + + def run(self) -> List[nodes.Node]: + from jax._src.config import config as jax_config + + config_options = sorted(jax_config.meta.items(), key=itemgetter(0)) + result = [] + + for name, (opt_type, meta_args, meta_kwargs) in config_options: + if name in _deprecations: + continue + + holder = jax_config._value_holders[name] + + # Create target for linking + target = nodes.target() + target['ids'].append(name) + result.append(target) + + # Create a section for this option + option_section = nodes.section() + option_section['ids'].append(name) + option_section['classes'].append('config-option-section') + + # Create a title with the option name (important for TOC) + title = nodes.title() + title['classes'] = ['h4'] + title += nodes.Text(name.replace("jax_", "").replace("_", " ").title()) + option_section += title + + # Create a field list for side-by-side display + field_list = nodes.field_list() + field_list['classes'].append('config-field-list') + + # Add type information as a field item + if opt_type == "enum": + type_para = nodes.paragraph() + emphasis_node = nodes.emphasis() + emphasis_node += nodes.Text("Enum values: ") + type_para += emphasis_node + + for i, value in enumerate(enum_values := meta_kwargs.get('enum_values', [])): + type_para += nodes.literal(text=repr(value)) + if i < len(enum_values) - 1: + type_para += nodes.Text(", ") + else: + type_para = nodes.paragraph() + type_para += nodes.literal(text=opt_type.__name__) + + field_list += create_field_item("Type", type_para) + + # Add default value information + default_para = nodes.paragraph() + 
default_para += nodes.literal(text=repr(holder.value)) + field_list += create_field_item("Default Value", default_para) + + # Add configuration string information + string_para = nodes.paragraph() + string_para += nodes.literal(text=repr(name)) + field_list += create_field_item("Configuration String", string_para) + + string_para = nodes.paragraph() + string_para += nodes.literal(text=name.upper()) + field_list += create_field_item("Environment Variable", string_para) + + # Add the field list to the section + option_section += field_list + + # Add help text in a description box + if (help_text := meta_kwargs.get('help')): + help_para = nodes.paragraph() + # logger.error(name) + # logger.warning(help_text) + + # If we get here, help text seems valid - proceed with normal parsing + # parsed = nodes.Text(help_text) + help_para += self.parse_text_to_nodes(help_text) + + option_section += help_para + + result.append(option_section) + # Add an extra paragraph to ensure proper separation + result.append(nodes.paragraph()) + result.append(nodes.paragraph()) # ensure new line + + return result + + def get_location(self) -> Any: + return (self.env.docname, self.lineno) + +def setup(app): + app.add_directive("list_config_options", ConfigOptionDirective) + + return { + "version": "0.1", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 30c626bec4e3..1bd719aa2df2 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -20,7 +20,7 @@ kernelspec: JAX transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, require the functions they wrap to be pure: that is, functions whose outputs depend *solely* on the inputs, and which have no side effects such as updating of global state. -You can find a discussion of this in [JAX sharp bits: Pure functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). +You can find a discussion of this in [JAX sharp bits: Pure functions](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). This constraint can pose some challenges in the context of machine learning, where state may exist in many forms. For example: diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst index d3724745fe08..8227aff384aa 100644 --- a/docs/type_promotion.rst +++ b/docs/type_promotion.rst @@ -4,7 +4,7 @@ Type promotion semantics ======================== This document describes JAX's type promotion rules–i.e., the result of :func:`jax.numpy.promote_types` for each pair of types. -For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX `_. +For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX `_. JAX's type promotion behavior is determined via the following type promotion lattice: diff --git a/docs/user_guides.rst b/docs/user_guides.rst index 6481da7a31dd..47984fc493f4 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -26,7 +26,6 @@ or deployed codebases. errors aot export/index - type_promotion transfer_guard .. 
toctree:: diff --git a/docs/xla_flags.md b/docs/xla_flags.md index 1e374abea005..24bb8a96c91c 100644 --- a/docs/xla_flags.md +++ b/docs/xla_flags.md @@ -85,4 +85,4 @@ XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py | `xla_gpu_enable_reduce_scatter_combine_by_dim` | Boolean (true/false) | Combine reduce-scatter ops with the same dimension or irrespective of their dimension. | **Additional reading:** -* [GPU performance tips](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) +* [GPU performance tips](https://docs.jax.dev/en/latest/gpu_performance_tips.html#xla-performance-flags) diff --git a/examples/ffi/README.md b/examples/ffi/README.md index bd45408e50d8..c490f014859b 100644 --- a/examples/ffi/README.md +++ b/examples/ffi/README.md @@ -2,7 +2,7 @@ This directory includes an example project demonstrating the use of JAX's foreign function interface (FFI). The JAX docs provide more information about -this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html), +this interface in [the FFI tutorial](https://docs.jax.dev/en/latest/ffi.html), but the example in this directory complements that document by demonstrating (and testing!) the full packaging workflow, and some more advanced use cases. Within the example project, there are several example calls: diff --git a/examples/ffi/src/jax_ffi_example/gpu_examples.cc b/examples/ffi/src/jax_ffi_example/gpu_examples.cc index 921039debe5d..79a4ee91e8c6 100644 --- a/examples/ffi/src/jax_ffi_example/gpu_examples.cc +++ b/examples/ffi/src/jax_ffi_example/gpu_examples.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "cuda_runtime_api.h" +#include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" namespace nb = nanobind; diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc index 819f3b9f868d..bcfc1eb67aa4 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.cc +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include #include -#include -#include #include #include diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py index 6dbfe5043ddf..5ba97f48ebad 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.py +++ b/examples/ffi/src/jax_ffi_example/rms_norm.py @@ -14,7 +14,7 @@ """An example demontrating the basic end-to-end use of the JAX FFI. This example is exactly the same as the one in the `FFI tutorial -`, so more details can be found +`, so more details can be found on that page. But, the high level summary is that we implement our custom extension in ``rms_norm.cc``, then call it usin ``jax.ffi.ffi_call`` in this module. 
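# ---------------------------------------------------------------------------
# Hedged sketch of the pattern this docstring describes: once the compiled
# extension has registered an FFI target (e.g. via jax.ffi.register_ffi_target),
# the module wraps it with jax.ffi.ffi_call. The target name "rms_norm" and the
# ``eps`` attribute are illustrative.
import numpy as np
import jax

def rms_norm(x, eps=1e-5):
  out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
  # ffi_call returns a callable; keyword arguments other than the operands are
  # forwarded to the C++ handler as FFI attributes.
  call = jax.ffi.ffi_call("rms_norm", out_type, vmap_method="broadcast_all")
  return call(x, eps=np.float32(eps))
# ---------------------------------------------------------------------------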
The behavior under autodiff is implemented using diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index b3cb995aae21..86f3129c9876 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -21,6 +21,7 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", @@ -33,6 +34,7 @@ cc_binary( "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", "@xla//xla/tools:hlo_module_loader", ], ) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 0a1d3a63acfd..8deea5448fec 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -41,7 +41,8 @@ limitations under the License. #include #include -#include "third_party/absl/status/statusor.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" @@ -50,6 +51,7 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" diff --git a/examples/k8s/example.yaml b/examples/k8s/example.yaml new file mode 100644 index 000000000000..deee1950683a --- /dev/null +++ b/examples/k8s/example.yaml @@ -0,0 +1,40 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: example +spec: + replicatedJobs: + - name: workers + template: + spec: + parallelism: 2 + completions: 2 + backoffLimit: 0 + template: + spec: + serviceAccountName: training-job-sa + restartPolicy: Never + imagePullSecrets: + - name: null + containers: + - name: main + image: PLACEHOLDER + imagePullPolicy: IfNotPresent + resources: + requests: + cpu: 900m + nvidia.com/gpu: null + limits: + cpu: 1 + nvidia.com/gpu: null + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/examples/k8s/svc-acct.yaml b/examples/k8s/svc-acct.yaml new file mode 100644 index 000000000000..d05fb9b0cd2a --- /dev/null +++ b/examples/k8s/svc-acct.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: training-job-sa + namespace: default +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: pod-reader +rules: + - apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list", "watch"] + - apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get", "list", "watch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: pod-reader-binding + namespace: default +subjects: + - kind: ServiceAccount + name: training-job-sa + namespace: default +roleRef: + kind: Role + name: pod-reader + apiGroup: rbac.authorization.k8s.io diff --git a/jax/BUILD b/jax/BUILD index 12eae4afdcf7..862679681c39 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -93,7 +93,7 @@ package_group( includes = [":internal"], packages = [ # Intentionally avoid jax dependencies on jax.extend. 
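# ---------------------------------------------------------------------------
# Hedged sketch: the JobSet example above relies on jax.distributed.initialize()
# auto-detecting the coordinator from its environment. Outside such an
# environment the same values can be passed explicitly; the address and counts
# below are placeholders.
import jax

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # host:port of process 0
    num_processes=2,
    process_id=0,                         # unique per process, in [0, num_processes)
)
print(jax.process_count(), len(jax.devices()), len(jax.local_devices()))
# ---------------------------------------------------------------------------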
- # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html + # See https://docs.jax.dev/en/latest/jep/15856-jex.html "//tests/...", ] + jax_extend_internal_users, ) @@ -142,6 +142,7 @@ py_library( # these are available in jax.test_util via the standard :jax target. name = "test_util", srcs = [ + "_src/test_loader.py", "_src/test_util.py", "_src/test_warning_util.py", ], @@ -167,22 +168,30 @@ py_library( ], ), visibility = [":internal"], - deps = [ - ":jax", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":jax", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ) + py_deps("numpy"), ) py_library( name = "internal_test_harnesses", srcs = ["_src/internal_test_util/test_harnesses.py"], visibility = [":internal"] + jax_internal_test_harnesses_visibility, - deps = [ - ":ad_util", - ":config", - ":jax", - ":test_util", - "//jax/_src/lib", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":ad_util", + ":config", + ":jax", + ":test_util", + "//jax/_src/lib", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ) + py_deps("numpy"), ) py_library( @@ -191,15 +200,18 @@ py_library( visibility = [ ":internal", ] + jax_internal_export_back_compat_test_util_visibility, - deps = [ - ":jax", - ":test_util", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":jax", + ":test_util", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ) + py_deps("numpy"), ) py_library( name = "internal_export_back_compat_test_data", - testonly = 1, srcs = glob([ "_src/internal_test_util/export_back_compat_test_data/*.py", "_src/internal_test_util/export_back_compat_test_data/pallas/*.py", @@ -389,6 +401,7 @@ pytype_strict_library( deps = [ ":partition_spec", ":sharding", + ":util", "//jax/_src/lib", ] + py_deps("numpy"), ) @@ -596,6 +609,7 @@ pytype_strict_library( ":dtypes", ":effects", ":layout", + ":mesh", ":op_shardings", ":partial_eval", ":partition_spec", @@ -708,7 +722,7 @@ pytype_strict_library( ":pallas", # build_cleaner: keep "//jax/_src/pallas/fuser:block_spec", "//jax/_src/pallas/fuser:custom_evaluate", - "//jax/_src/pallas/fuser:fusable", + "//jax/_src/pallas/fuser:fusible", "//jax/_src/pallas/fuser:fusion", "//jax/_src/pallas/fuser:jaxpr_fusion", ], @@ -1061,6 +1075,7 @@ pytype_strict_library( srcs = ["_src/tpu_custom_call.py"], visibility = [":internal"], deps = [ + ":cloud_tpu_init", ":config", ":core", ":jax", diff --git a/jax/__init__.py b/jax/__init__.py index ae3bac4ad3fa..32ae955ae5b8 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -141,16 +141,6 @@ make_array_from_process_local_data as make_array_from_process_local_data, ) -from jax._src.tree_util import ( - tree_map as _deprecated_tree_map, - treedef_is_leaf as _deprecated_treedef_is_leaf, - tree_flatten as _deprecated_tree_flatten, - tree_leaves as _deprecated_tree_leaves, - tree_structure as _deprecated_tree_structure, - tree_transpose as _deprecated_tree_transpose, - tree_unflatten as _deprecated_tree_unflatten, -) - # These submodules are separate because they are in an import cycle with # jax and rely on the names imported above. 
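# ---------------------------------------------------------------------------
# Migration sketch for the aliases removed above, following the replacement
# suggestions in the deprecation messages below.
import jax

tree = {"a": 1, "b": (2, 3)}
leaves, treedef = jax.tree.flatten(tree)         # was: jax.tree_flatten(tree)
doubled = jax.tree.map(lambda x: 2 * x, tree)    # was: jax.tree_map(...)
restored = jax.tree.unflatten(treedef, leaves)   # was: jax.tree_unflatten(...)
# jax.tree_util.tree_flatten / tree_map / tree_unflatten work on any JAX version.
# ---------------------------------------------------------------------------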
from jax import custom_derivatives as custom_derivatives @@ -184,59 +174,46 @@ del _ccache _deprecations = { - # Added July 2022 + # Finalized 2025-03-25; remove after 2025-06-25 "treedef_is_leaf": ( - "jax.treedef_is_leaf is deprecated: use jax.tree_util.treedef_is_leaf.", - _deprecated_treedef_is_leaf + "jax.treedef_is_leaf was removed in JAX v0.6.0: use jax.tree_util.treedef_is_leaf.", + None ), "tree_flatten": ( - "jax.tree_flatten is deprecated: use jax.tree.flatten (jax v0.4.25 or newer) " + "jax.tree_flatten was removed in JAX v0.6.0: use jax.tree.flatten (jax v0.4.25 or newer) " "or jax.tree_util.tree_flatten (any JAX version).", - _deprecated_tree_flatten + None ), "tree_leaves": ( - "jax.tree_leaves is deprecated: use jax.tree.leaves (jax v0.4.25 or newer) " + "jax.tree_leaves was removed in JAX v0.6.0: use jax.tree.leaves (jax v0.4.25 or newer) " "or jax.tree_util.tree_leaves (any JAX version).", - _deprecated_tree_leaves + None ), "tree_structure": ( - "jax.tree_structure is deprecated: use jax.tree.structure (jax v0.4.25 or newer) " + "jax.tree_structure was removed in JAX v0.6.0: use jax.tree.structure (jax v0.4.25 or newer) " "or jax.tree_util.tree_structure (any JAX version).", - _deprecated_tree_structure + None ), "tree_transpose": ( - "jax.tree_transpose is deprecated: use jax.tree.transpose (jax v0.4.25 or newer) " + "jax.tree_transpose was removed in JAX v0.6.0: use jax.tree.transpose (jax v0.4.25 or newer) " "or jax.tree_util.tree_transpose (any JAX version).", - _deprecated_tree_transpose + None ), "tree_unflatten": ( - "jax.tree_unflatten is deprecated: use jax.tree.unflatten (jax v0.4.25 or newer) " + "jax.tree_unflatten was removed in JAX v0.6.0: use jax.tree.unflatten (jax v0.4.25 or newer) " "or jax.tree_util.tree_unflatten (any JAX version).", - _deprecated_tree_unflatten + None ), - # Added Feb 28, 2024 "tree_map": ( - "jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) " + "jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) " "or jax.tree_util.tree_map (any JAX version).", - _deprecated_tree_map - ), - # Finalized Nov 12 2024; remove after Feb 12 2025 - "clear_backends": ( - "jax.clear_backends was removed in JAX v0.4.36", None ), } import typing as _typing if _typing.TYPE_CHECKING: - from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf - from jax._src.tree_util import tree_flatten as tree_flatten - from jax._src.tree_util import tree_leaves as tree_leaves - from jax._src.tree_util import tree_map as tree_map - from jax._src.tree_util import tree_structure as tree_structure - from jax._src.tree_util import tree_transpose as tree_transpose - from jax._src.tree_util import tree_unflatten as tree_unflatten - + pass else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index c2868cf7c078..2d743bf06c6b 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -430,7 +430,7 @@ def _trace_to_jaxpr(fun: Callable, "Consider using the `static_argnums` parameter for `jax.remat` or " "`jax.checkpoint`. 
See the `jax.checkpoint` docstring and its example " "involving `static_argnums`:\n" - "https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.checkpoint.html" "\n") e.args = msg, raise @@ -757,89 +757,34 @@ def _has_effects(effects) -> bool: return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) -def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, is_gpu_platform: bool = False, - **_): +def remat_expansion( + *args, jaxpr: core.Jaxpr, prevent_cse: bool, differentiated: bool, **_ +): assert not jaxpr.constvars if differentiated and prevent_cse: - if config.remat_opt_barrier.value: - translation_rule = _remat_translation_using_opt_barrier - elif is_gpu_platform: - translation_rule = _remat_translation_using_while - else: - translation_rule = _remat_translation_using_cond + translation_rule = _remat_translation_using_opt_barrier else: translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr) + def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): args = lax_internal.optimization_barrier(args) return core.eval_jaxpr(jaxpr, (), *args) -# TODO(mattjj): add core utility for 'create dummy value for this type'? -def _dummy_like(aval: core.AbstractValue) -> Any: - if aval is core.abstract_token: - return lax_internal.create_token() - elif isinstance(aval, (core.ShapedArray, core.DShapedArray)): - return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape) # type: ignore - else: - raise ValueError(aval) - -def _remat_translation_using_while(*args, jaxpr: core.Jaxpr): - # Implements: - # for(counter=0, result=0; counter < rng(1, 2); counter ++) { - # result = eval_jaxpr(*args) - # } - # The loop carry is a tuple: (counter, result, args) - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args) - def cond(carry): - counter, _, _ = carry - unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=()) - return counter < unif - - def body(carry): - counter, _, args = carry - results = core.eval_jaxpr(jaxpr, (), *args) - return (counter + 1, tuple(results), args) - - carry_res = lax_control_flow.while_loop(cond, body, carry_init) - return carry_res[1] - -def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr): - # Implements: - # if(rng(0, 1) < 2) - # return eval_jaxpr(*args) - # else: - # return 0 - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - - def remat_comp(*args): - return tuple(core.eval_jaxpr(jaxpr, (), *args)) - def dummy_comp(*args): - return tuple(map(_dummy_like, avals_out)) - - unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=()) - return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args) - -def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, policy, is_gpu_platform=False): + +def _remat_lowering( + ctx, + *args, + jaxpr: core.Jaxpr, + prevent_cse: bool, + differentiated: bool, + policy, +): jaxpr_args: Sequence[mlir.IrValues] if differentiated and prevent_cse: - # If we're using the loop or cond lowerings, use the slower lower_fun - # based path. 
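# ---------------------------------------------------------------------------
# For reference, a hedged sketch of the user-facing API whose lowering is being
# simplified here: jax.checkpoint (alias jax.remat) recomputes intermediates of
# the wrapped function on the backward pass.
import jax
import jax.numpy as jnp

@jax.checkpoint
def g(x):
  return jnp.sin(jnp.sin(x))

print(jax.grad(g)(2.0))  # inner values of g are rematerialized during the VJP
# ---------------------------------------------------------------------------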
- if not config.remat_opt_barrier.value: - return mlir.lower_fun(remat_expansion, multiple_results=True)( - ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse, - differentiated=differentiated, policy=policy, - is_gpu_platform=is_gpu_platform) - arg_types = map(mlir.aval_to_ir_type, ctx.avals_in) flat_args = mlir.flatten_ir_values(args) barrier_op = hlo.OptimizationBarrierOp(flat_args) @@ -853,9 +798,8 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, ctx.set_tokens_out(tokens_out) return outs + mlir.register_lowering(remat_p, _remat_lowering) -mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True), - platform="gpu") def checkpoint_name(x, name): @@ -931,7 +875,7 @@ def checkpoint_wrapper( " else:\n" " return g(x)\n" "\n" - "See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n") + "See https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html\n") raise NotImplementedError(msg) return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index c729a57cfb11..8cfd7b214338 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -31,6 +31,7 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: + x, y = core.standard_insert_pvary(x, y) return add_jaxvals_p.bind(x, y) add_jaxvals_p = Primitive('add_any') diff --git a/jax/_src/api.py b/jax/_src/api.py index cdcc3e534e74..43ab7729a348 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -25,6 +25,7 @@ import atexit import collections from collections.abc import Callable, Hashable, Iterable, Sequence +import dataclasses from functools import partial, lru_cache import inspect import math @@ -36,12 +37,14 @@ import numpy as np from contextlib import contextmanager +from jax._src import deprecations from jax._src import linear_util as lu from jax._src import stages from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix, - prefix_errors, generate_key_paths, tree_flatten_with_path) + prefix_errors, generate_key_paths, tree_flatten_with_path, + equality_errors_pytreedef) from jax._src import config from jax._src import core from jax._src import dispatch @@ -80,7 +83,6 @@ from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla -from jax._src.interpreters import xla traceback_util.register_exclusion(__file__) @@ -146,8 +148,39 @@ def _update_debug_special_thread_local(_): float0 = dtypes.float0 +# TODO(jakevdp): remove this for v0.7.0 (~July 2025) +def _allow_deprecated_jit_signature(f: F) -> F: + """Temporary decorator for the jit signature deprecation.""" + @wraps(f) + def wrapped(*args, **kwargs): + if len(args) == 1 or deprecations.is_accelerated('jax-jit-positional-args'): + # Fast path for typical usage. + return f(*args, **kwargs) + if 'fun' in kwargs: + deprecations.warn( + 'jax-jit-positional-args', + ('jax.jit: passing fun by keyword is deprecated.' + ' Pass it by position to silence this warning.'), + stacklevel=2 + ) + return f(kwargs.pop('fun'), **kwargs) + if len(args) > 1: + deprecations.warn( + 'jax-jit-positional-args', + ('jax.jit: passing optional arguments by position is deprecated. 
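# ---------------------------------------------------------------------------
# Hedged sketch of the calling convention enforced by the shim continuing
# below: ``fun`` is positional-only and every option is keyword-only. Passing
# options positionally, or ``fun`` by keyword, now emits the
# 'jax-jit-positional-args' DeprecationWarning instead of being accepted
# silently.
import jax

def f(x, n):
  return x * n

jitted = jax.jit(f, static_argnums=(1,))  # supported: options by keyword only
print(jitted(2.0, 3))
# ---------------------------------------------------------------------------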
' + ' Pass them by keyword to silence this warning.'), + stacklevel=2 + ) + sig = inspect.signature(f) + kwds = dict(unsafe_zip((p.name for p in sig.parameters.values()), args)) + return f(kwds.pop('fun'), **kwds, **kwargs) + return f(*args, **kwargs) + return cast(F, wrapped) + + +@_allow_deprecated_jit_signature def jit( - fun: Callable, + fun: Callable, /, *, in_shardings: Any = sharding_impls.UNSPECIFIED, out_shardings: Any = sharding_impls.UNSPECIFIED, static_argnums: int | Sequence[int] | None = None, @@ -191,7 +224,7 @@ def jit( constant). Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Otherwise they can be arbitrary + ``__eq__`` are implemented, and immutable. Otherwise, they can be arbitrary Python objects. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not array-like or containers thereof must be marked as static. @@ -231,7 +264,7 @@ def jit( be donated. For more details on buffer donation see the - `FAQ `_. + `FAQ `_. donate_argnames: optional, a string or collection of strings specifying which named arguments are donated to the computation. See the comment on ``donate_argnums`` for details. If not @@ -287,16 +320,19 @@ def jit( Array([ 0, 1, 256, 6561], dtype=int32) """ return pjit.make_jit( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env=False) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + abstracted_axes=abstracted_axes, compiler_options=compiler_options, + use_resource_env=False) @contextmanager def disable_jit(disable: bool = True): """Context manager that disables :py:func:`jit` behavior under its dynamic context. - For debugging it is useful to have a mechanism that disables :py:func:`jit` + For debugging, it is useful to have a mechanism that disables :py:func:`jit` everywhere in a dynamic context. Note that this not only disables explicit uses of :func:`jit` by the user, but will also remove any implicit JIT compilation used by the JAX library: this includes implicit JIT computation of `body` and @@ -855,7 +891,7 @@ def vmap(fun: F, be a container with a matching pytree structure specifying the mapping of its container elements. In other words, ``in_axes`` must be a container tree prefix of the positional argument tuple passed to ``fun``. See this link for more detail: - https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees + https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees Either ``axis_size`` must be provided explicitly, or at least one positional argument must have ``in_axes`` not None. The sizes of the @@ -1241,7 +1277,7 @@ def pmap( arguments will not be donated. For more details on buffer donation see the - `FAQ `_. + `FAQ `_. Returns: A parallelized version of ``fun`` with arguments that correspond to those of @@ -1488,7 +1524,7 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple, "Instead, each argument passed by keyword is mapped over its " "leading axis. 
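# ---------------------------------------------------------------------------
# Small sketch of the "in_axes is a tree prefix of the positional arguments"
# rule referenced in the docstrings above (it applies to vmap and pmap alike).
import jax
import jax.numpy as jnp

def f(x, y):
  return x + y

xs = jnp.ones((4, 3))
y = jnp.ones((3,))
out = jax.vmap(f, in_axes=(0, None))(xs, y)  # map over axis 0 of x, broadcast y
assert out.shape == (4, 3)
# ---------------------------------------------------------------------------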
See the description of `in_axes` in the `pmap` " "docstring: " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html#jax.pmap") + "https://docs.jax.dev/en/latest/_autosummary/jax.pmap.html#jax.pmap") msg += ("\n\nCheck that the value of the `in_axes` argument to `pmap` " "is a tree prefix of the tuple of arguments passed positionally to " "the pmapped function.") @@ -1568,11 +1604,13 @@ def _cpp_pmap( out_axes) del static_broadcasted_argnums, donate_argnums + prepare_pmap_fn = partial(_prepare_pmap, + fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, + devices, backend, axis_size) + @api_boundary def cache_miss(*args, **kwargs): - p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, - donate_tuple, devices, backend, - axis_size, args, kwargs) + p = prepare_pmap_fn(args, kwargs) for arg in p.flat_args: dispatch.check_arg(arg) @@ -1649,48 +1687,56 @@ def cache_miss(*args, **kwargs): _pmap_cache_clears.add(cpp_mapped_f) pmap_f = wraps(fun)(cpp_mapped_f) + # Store some data for the `lower` and `trace` methods pmap_f._fun = fun + pmap_f._prepare_pmap = prepare_pmap_fn + pmap_f._backend = backend + pmap_f._axis_name = axis_name + pmap_f._donate_tuple = donate_tuple + + # TODO(necula): move these to top-level; we don't need to do this for + # every pmap + cpp_mapped_f_class = type(pmap_f) + cpp_mapped_f_class.lower = _cpp_mapped_lower + cpp_mapped_f_class.trace = _cpp_mapped_trace + # We return directly the function produced by pmap_lib.pmap, because we do not + # want to have Python in the dispatch path. + return pmap_f - @api_boundary - def lower(*args, **kwargs): - return trace(*args, **kwargs).lower() +@api_boundary +def _cpp_mapped_trace(pmap_f, *args, **kwargs): + p = pmap_f._prepare_pmap(args, kwargs) + abstract_args = list(map(shaped_abstractify, p.flat_args)) + closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( + p.flat_fun, pmap_f._backend, pmap_f._axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, + out_axes_thunk=p.out_axes_thunk, + avals=abstract_args) + lower_callable = partial( + pxla.lower_parallel_callable, p.flat_fun, pmap_f._axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, + donated_invars=p.donated_invars, + is_explicit_global_axis_size=p.is_explicit_global_axis_size, + avals=abstract_args, + closed_jaxpr=closed_jaxpr, + backend=xc_backend, + replicas=replicas, + shards=shards, + pci=pci) + args_info = stages.make_args_info(p.in_tree, abstract_args, pmap_f._donate_tuple) + return stages.Traced(closed_jaxpr, args_info, p.flat_fun.__name__, + p.out_tree(), lower_callable) - @api_boundary - def trace(*args, **kwargs): - p = _prepare_pmap( - fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, - devices, backend, axis_size, args, kwargs) - abstract_args = list(map(shaped_abstractify, p.flat_args)) - closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( - p.flat_fun, backend, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - in_axes=p.in_axes_flat, - out_axes_thunk=p.out_axes_thunk, - avals=abstract_args) - lower_callable = partial( - pxla.lower_parallel_callable, p.flat_fun, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - 
in_axes=p.in_axes_flat, - donated_invars=p.donated_invars, - is_explicit_global_axis_size=p.is_explicit_global_axis_size, - avals=abstract_args, - closed_jaxpr=closed_jaxpr, - backend=xc_backend, - replicas=replicas, - shards=shards, - pci=pci) - args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple) - return stages.Traced(closed_jaxpr, args_info, p.flat_fun.__name__, - p.out_tree(), lower_callable) - - pmap_f.lower = lower - pmap_f.trace = trace +@api_boundary +def _cpp_mapped_lower(pmap_f, *args, **kwargs): + return _cpp_mapped_trace(pmap_f, *args, **kwargs).lower() - return pmap_f _pmap_cache_clears = weakref.WeakSet() # type: ignore @@ -2033,6 +2079,82 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) +def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, + allow_unused: bool = True, allow_opaque: bool = True): + if len(which) != len(primals): + raise ValueError( + "length of 'which' argument must equal the number of primal input values, " + f"but got {len(which)=} and {len(primals)=}") + + dbg = debug_info("saved_input_vjp", f, primals, {}) + fun = lu.wrap_init(f, debug_info=dbg) + primals_flat, in_tree = tree_flatten(primals) + fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + out_primals_flat, _, jaxpr, residuals = ad.linearize(fun, *primals_flat) + primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w)) + id_map = {id(x): i for i, x in enumerate(primals_filt)} + opaque_residuals = [] + res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else + RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore + for r in residuals] + f_vjp = Partial(partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, + out_tree(), jaxpr), opaque_residuals) + + if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}): + unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which)) + if w and id(x) not in res_ids] + assert unused + if len(unused) == 1: + (i, a), = unused + start, was = "an input value", "was" + msg = f" {dbg.arg_names[i]} of type {a.str_short()}" + else: + start, was = "multiple input values", "were" + msg = "\n" + "\n".join(f" * {dbg.arg_names[i]} of type {a.str_short()}" + for i, a in unused) + raise Exception(f"with {allow_unused=}, {start} marked to be saved {was} " + f"not used by the backward pass:{msg}") + + if not allow_opaque and opaque_residuals: + msg = ", ".join(core.get_aval(x).str_short() for x in opaque_residuals) + raise Exception(f"with {allow_opaque=}, the backward pass requires opaque " + f"(non-input) residuals: {msg}") + + out_primals = tree_unflatten(out_tree(), out_primals_flat) + return out_primals, f_vjp + +def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, jaxpr, + opaque_residuals, ct, *saved_primals): + primals_filtered, filtered_tree_ = tree_flatten(saved_primals) + if filtered_tree != filtered_tree_: + raise ValueError( + "inputs passed to f_vjp must be a tuple of (pytrees of) " + "arrays with the same structure as\n" + " tuple(x for x, w in zip(inputs, which) if w)\n" + "given the original call\n" + " _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n" + "but the structures differ:\n" + + "\n".join(f" * inputs{keystr(path)} was a {thing1} in the original " + f"call, but a {thing2} here, so {explanation}" + for path, thing1, thing2, explanation + in equality_errors_pytreedef(filtered_tree, filtered_tree_))) + + residuals = 
[primals_filtered[i.idx] if i.primal else opaque_residuals[i.idx] + for i in res_spec] + dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] + cts_flat, out_tree_ = tree_flatten(ct) + assert out_tree_ == out_tree + arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat) + return tree_unflatten(in_tree, arg_cts) + +@dataclasses.dataclass(frozen=True) +class RSpec: + idx: int + primal: bool + +si_vjp = saved_input_vjp + + def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable: """Transpose a function that is promised to be linear. @@ -2147,7 +2269,7 @@ def make_jaxpr( return_shape: bool = False, abstracted_axes: Any | None = None, ) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]: - """Creates a function that produces its jaxpr given example args. + """Create a function that returns the jaxpr of ``fun`` given example args. Args: fun: The function whose ``jaxpr`` is to be computed. Its positional @@ -2510,7 +2632,6 @@ def _device_put_replicated(x): sharding = PmapSharding(np.array(devices), sharding_spec) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices) - assert len(xla.aval_to_xla_shapes(aval)) == 1 return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices) with config.explicit_device_put_scope(): diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index a42141b96fbd..163bade2065c 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -28,12 +28,12 @@ from jax._src.tree_util import ( PyTreeDef, tree_flatten, tree_unflatten, tree_map, treedef_children, generate_key_paths, broadcast_prefix, - prefix_errors) -from jax._src.tree_util import _replace_nones + prefix_errors, _replace_nones) from jax._src import linear_util as lu from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, - Unhashable, safe_zip) + Unhashable, safe_zip as zip) from jax._src import traceback_util + traceback_util.register_exclusion(__file__) map = safe_map @@ -201,9 +201,11 @@ def _validate_argnames( f"in {argnames_name}. Function does not take these args.") -def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True): +def argnums_partial(f: lu.WrappedFun, dyn_argnums: int | Sequence[int], + args: Sequence, require_static_args_hashable=True): dyn_argnums = _ensure_index_tuple(dyn_argnums) dyn_argnums = _ensure_inbounds(False, len(args), dyn_argnums) + fixed_args: list if require_static_args_hashable: fixed_args = [] for i, arg in enumerate(args): @@ -257,7 +259,7 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], dyn_args = tuple(args[i] for i in dyn_argnums) fixed_args = [] - for i in static_argnums: + for i in sorted(static_argnums): # TODO(shoyer): set allow_invalid=True permanently after static_argnames. 
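# ---------------------------------------------------------------------------
# Hedged usage sketch for the saved_input_vjp / si_vjp helper defined above
# (it appears to live in jax._src.api and may not be public, so the import path
# is an assumption): inputs marked True in ``which`` are not stored in the VJP
# closure and must be re-supplied when calling ``f_vjp``.
import jax.numpy as jnp
from jax._src.api import saved_input_vjp

def f(x, y):
  return x * jnp.sin(y)

x, y = jnp.float32(2.0), jnp.float32(3.0)
out, f_vjp = saved_input_vjp(f, (True, False), x, y)  # caller keeps x around
ct = jnp.float32(1.0)                                  # cotangent of the output
x_bar, y_bar = f_vjp(ct, x)                            # pass the saved input back
# ---------------------------------------------------------------------------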
if allow_invalid and i >= len(args): continue @@ -273,7 +275,9 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args @lu.transformation2 -def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs): +def _argnums_partial(_fun: Callable, + _dyn_argnums: Sequence[int], + _fixed_args: Sequence, *dyn_args, **kwargs): sentinel = object() args = [sentinel] * (len(_fixed_args) + len(dyn_args)) for i, arg in zip(_dyn_argnums, dyn_args): @@ -334,7 +338,7 @@ def donation_vector(donate_argnums, donate_argnames, in_tree, donate = bool(i in donate_argnums) res.extend((donate,) * arg.num_leaves) if kwargs_tree is not None: - for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore + for key, val in zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore donate = key in donate_argnames res.extend((donate,) * val.num_leaves) return tuple(res) @@ -673,28 +677,45 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None, top-level arguments. In other cases, including when the `args` and `kwargs` do not match the signature, we use names like `args[0[]`, `args[1]`, etc. """ + # Use the same argument parsing as jit: positional followed by kwargs + # sorted by keys. static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] - kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} + kwargs_ = {k: static if k in static_argnames_ else x for k, x in kwargs.items()} + ordered_args: Sequence[tuple[str, Any]] | None = None if fn_signature is not None: try: ba = fn_signature.bind(*args_, **kwargs_) except (ValueError, TypeError): pass else: - return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' - for name, x in ba.arguments.items() - for path, l in generate_key_paths(x) if l is not static) - args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}' - for path, l in generate_key_paths(args_) - if l is not static) - kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}' - for path, l in generate_key_paths(kwargs_) - if l is not static) - arg_names = args_arg_names + kwargs_arg_names - return arg_names + # Do we have a **kwargs + kwargs_name = next((name for name, p in fn_signature.parameters.items() + if p.kind == inspect.Parameter.VAR_KEYWORD), None) + # Positional argument are those not passed by keyword and not passed + # by **kwargs. 
+ positional = [(name, x) for name, x in ba.arguments.items() + if name not in kwargs and name != kwargs_name] + # Keyword arguments are passed sorted by actual kwarg keyword + sorted_kwargs = sorted(((name, x) for name, x in kwargs_.items()), + key=lambda name_x: name_x[0]) + sorted_kwargs = [(name if name in ba.arguments else f"{kwargs_name}['{name}']", + x) + for name, x in sorted_kwargs] + ordered_args = positional + sorted_kwargs + + if ordered_args is None: + positional = [("args", args_)] + keyword = sorted([(f"kwargs['{name}']", x) for name, x in kwargs_.items() if x is not static], + key=lambda name_x: name_x[0]) + ordered_args = positional + keyword + + return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' + for name, x in ordered_args + for path, l in generate_key_paths(x) if l is not static) + def hoist_obj_attrs(f, flat_args): idxs, objs, flat_args_ = [], [], [] diff --git a/jax/_src/array.py b/jax/_src/array.py index b0793d2c3330..760593da9fa9 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -39,11 +39,12 @@ from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension as xe +from jax._src.lib import jaxlib_extension_version from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, device_replica_id_map, hashed_index, num_addressable_indices, - local_to_global_shape, use_concrete_mesh) # pyformat: disable + local_to_global_shape, _internal_use_concrete_mesh) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache import numpy as np @@ -1024,7 +1025,7 @@ def make_array_from_single_device_arrays( shape : Shape of the output ``jax.Array``. This conveys information already included with ``sharding`` and ``arrays`` and serves as a double check. sharding: Sharding: A global Sharding instance which describes how the output jax.Array is laid out across devices. - arrays: Sequence of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` + arrays: `list` or `tuple` of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` must equal ``len(sharding.addressable_devices)`` and the shape of each array must be the same. For multiprocess code, each process will call with a different ``arrays`` argument that corresponds to that processes' data. These arrays are commonly created via ``jax.device_put``. @@ -1071,14 +1072,15 @@ def make_array_from_single_device_arrays( if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) + arrays = list(arrays) if isinstance(arrays, tuple) else arrays # TODO(phawkins): ideally the cast() could be checked. try: return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), committed=True) except TypeError: - if not isinstance(arrays, Sequence): + if not isinstance(arrays, list): raise TypeError("jax.make_array_from_single_device_arrays `arrays` " - "argument must be a Sequence (list or tuple), but got " + "argument must be a list or tuple, but got " f"{type(arrays)}.") if any(isinstance(arr, core.Tracer) for arr in arrays): raise ValueError( @@ -1092,9 +1094,8 @@ def _get_aval_array(self): return core.update_aval_with_sharding(self.aval, self.sharding) core.pytype_aval_mappings[ArrayImpl] = _get_aval_array -# TODO(jakevdp) replace this with true inheritance at the C++ level. 
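# ---------------------------------------------------------------------------
# Sketch of the calling convention documented above: ``arrays`` must be a list
# or tuple of single-device arrays, one per addressable device. The mesh axis
# name 'x' is illustrative.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
global_shape = (2 * n,)
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))
data = np.arange(2 * n, dtype=np.float32)
arrays = [
    jax.device_put(data[idx], d)  # one single-device shard per addressable device
    for d, idx in sharding.addressable_devices_indices_map(global_shape).items()
]
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
# ---------------------------------------------------------------------------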
-basearray.Array.register(ArrayImpl) - +if jaxlib_extension_version < 325: + basearray.Array.register(ArrayImpl) def _array_mlir_constant_handler(val): try: @@ -1149,7 +1150,7 @@ def shard_device_array(x, devices, indices, sharding): else: # TODO(yashkatariya): Maybe this should be set when we call the handler in # InputsHandler.__call__? - with use_concrete_mesh(None): + with _internal_use_concrete_mesh(None): shards = x._multi_slice(start_indices, limit_indices, removed_dims) aval = core.shaped_abstractify(x) return pxla.batched_device_put(aval, sharding, shards, devices) diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index a89d4a2949be..6cd60deda3b0 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -17,9 +17,15 @@ from __future__ import annotations import abc -import numpy as np -from typing import Any, Union from collections.abc import Sequence +import sys +from typing import Any, Union + +from jax._src.lib import jaxlib_extension_version +from jax._src.lib import xla_client as xc +from jax._src.util import use_cpp_class +import numpy as np + # TODO(jakevdp): fix import cycles and define these. Device = Any @@ -29,7 +35,9 @@ # Array is a type annotation for standard JAX arrays and tracers produced by # core functions in jax.lax and jax.numpy; it is not meant to include # future non-standard array types like KeyArray and BInt. -class Array(abc.ABC): + + +class Array: """Array base class for JAX ``jax.Array`` is the public interface for instance checks and type annotation @@ -47,8 +55,6 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace :func:`jax.numpy.array`, :func:`jax.numpy.zeros`, :func:`jax.numpy.ones`, :func:`jax.numpy.full`, :func:`jax.numpy.arange`, etc. """ - # Note: abstract methods for this class are defined dynamically in - # lax_numpy.py # For the sake of static type analysis, these definitions are mirrored in the # associated basearray.pyi file. @@ -56,42 +62,41 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace __hash__ = None @property - @abc.abstractmethod def dtype(self) -> np.dtype: """The data type (:class:`numpy.dtype`) of the array.""" + raise NotImplementedError @property - @abc.abstractmethod def ndim(self) -> int: """The number of dimensions in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def size(self) -> int: """The total number of elements in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def shape(self) -> tuple[int, ...]: """The shape of the array.""" + raise NotImplementedError # Documentation for sharding-related methods and properties defined on ArrayImpl: - @abc.abstractmethod def addressable_data(self, index: int) -> Array: """Return an array of the addressable data at a particular index.""" + raise NotImplementedError @property - @abc.abstractmethod def addressable_shards(self) -> Sequence[Shard]: """List of addressable shards.""" + raise NotImplementedError @property - @abc.abstractmethod def global_shards(self) -> Sequence[Shard]: """List of global shards.""" + raise NotImplementedError @property - @abc.abstractmethod def is_fully_addressable(self) -> bool: """Is this Array fully addressable? @@ -103,19 +108,19 @@ def is_fully_addressable(self) -> bool: a jax.Array which is fully replicated can span across multiple hosts and is not fully addressable. 
""" + raise NotImplementedError @property - @abc.abstractmethod def is_fully_replicated(self) -> bool: """Is this Array fully replicated?""" + raise NotImplementedError @property - @abc.abstractmethod def sharding(self) -> Sharding: """The sharding for the array.""" + raise NotImplementedError @property - @abc.abstractmethod def committed(self) -> bool: """Whether the array is committed or not. @@ -137,20 +142,20 @@ def committed(self) -> bool: a + b # Raises an error ``` - See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices + See https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices for more information. """ + raise NotImplementedError @property - @abc.abstractmethod def device(self) -> Device | Sharding: """Array API-compatible device attribute. For single-device arrays, this returns a Device. For sharded arrays, this returns a Sharding. """ + raise NotImplementedError - @abc.abstractmethod def copy_to_host_async(self): """Copies an ``Array`` to the host asynchronously. @@ -165,17 +170,30 @@ def copy_to_host_async(self): array, but does not wait for the copy to complete. This may speed up a future on-host access to the array's contents. """ + raise NotImplementedError + + +if jaxlib_extension_version >= 325: + Array = use_cpp_class(xc.Array)(Array) +else: + class Array(Array, metaclass=abc.ABCMeta): + ... Array.__module__ = "jax" + # StaticScalar is the Union of all scalar types that can be converted to # JAX arrays, and are possible to mark as static arguments. StaticScalar = Union[ np.bool_, np.number, # NumPy scalar types bool, int, float, complex, # Python scalar types ] -StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." + +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." # ArrayLike is a Union of all objects that can be implicitly converted to a @@ -187,4 +205,8 @@ def copy_to_host_async(self): np.ndarray, # NumPy array type StaticScalar, # valid scalars ] -ArrayLike.__doc__ = "Type annotation for JAX array-like objects." + +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + ArrayLike.__doc__ = "Type annotation for JAX array-like objects." diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index a368b593332d..8bf68f622051 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -14,11 +14,12 @@ import abc from collections.abc import Callable, Sequence from types import ModuleType -from typing import Any, Protocol, Union, runtime_checkable +from typing import Any, Protocol, runtime_checkable, Union import numpy as np -from jax._src.sharding import Sharding from jax._src.partition_spec import PartitionSpec +from jax._src.sharding import Sharding + # TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py. # We redefine these here to prevent circular imports. @@ -39,7 +40,8 @@ Traceback = Any PrecisionLike = Any -class Array(abc.ABC): +# TODO(slebedev): Remove the metaclass once ``jax_extension_version >= 325``. 
+class Array(metaclass=abc.ABCMeta): aval: Any @property diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index e4b6e7a2669c..6fe3d8819d3c 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -110,6 +110,10 @@ def get( bytes(jaxlib_version_str.encode("utf-8")) ), ), + ( + "backend version", + lambda hash_obj: _hash_platform(hash_obj, backend) + ), ( "XLA flags", lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()), @@ -126,7 +130,7 @@ def get( ), ( "accelerator_config", - lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend), + lambda hash_obj: _hash_accelerator_config(hash_obj, devices), ), ( "compression", @@ -220,7 +224,7 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None: _hash_string(hash_obj, device.device_kind) -def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): +def _hash_accelerator_config(hash_obj, accelerators: np.ndarray): accelerator_devices = [] for accelerator in accelerators.flat: accelerator_devices.append(accelerator) @@ -233,9 +237,8 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): # PjRtTopologyDescription as yet. logger.info("get (_hash_accelerator_config): unable to hash " "accelerator config, falling back to hashing " - "devices + platform: %s (type %s)", ex, type(ex)) + "devices %s (type %s)", ex, type(ex)) _hash_devices(hash_obj, accelerators) - _hash_platform(hash_obj, backend) # LINT.IfChange(xla_flags) xla_flags_to_exclude_from_cache_key = [ diff --git a/jax/_src/callback.py b/jax/_src/callback.py index bdceb98d92b7..25bdb801edce 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -23,7 +23,6 @@ import jax from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -33,6 +32,7 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb +from jax._src.lib import jaxlib_extension_version from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -47,10 +47,6 @@ logger = logging.getLogger(__name__) -# TODO(dfm): Remove after 6 months. -# Added Oct 1, 2024 -deprecations.register("jax-callback-vectorized") - # `pure_callback_p` is the main primitive for staging out Python pure callbacks. 
pure_callback_p = core.Primitive("pure_callback") pure_callback_p.multiple_results = True @@ -82,10 +78,9 @@ def pure_callback_impl( result_avals, callback: _FlatCallback, sharding: SingleDeviceSharding | None, - vectorized: bool | DeprecatedArg, vmap_method: str | None, ): - del sharding, vectorized, vmap_method, result_avals + del sharding, vmap_method, result_avals try: cpu_device, *_ = jax.local_devices(backend="cpu") except RuntimeError as e: @@ -113,10 +108,9 @@ def pure_callback_abstract_eval( callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, - vectorized: bool | DeprecatedArg, vmap_method: str | None, ): - del avals, callback, sharding, vectorized, vmap_method + del avals, callback, sharding, vmap_method return result_avals @@ -166,7 +160,7 @@ def _callback_op_sharding( sharding_impls.SdyArraySharding( mesh_shape=(), dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=True) + sharding_impls.SdyDimSharding(axes=[], is_open=False) ] * avals_out[0].ndim, logical_device_ids=())]) else: @@ -200,7 +194,11 @@ def _callback_op_sharding( # program has bulk array semantics, so we run the callback with a MAXIMAL # sharding and hence execute it only once on the full logical value). if config.use_shardy_partitioner.value: - op_sharding = sharding_impls.SdyArrayShardingList([ + # For shardy, we need to have the same number of shardy annotations as the + # number of result ops. If there are no result ops, we need 1 shardy + # annotation. + num_sdy_shardings = max(1, len(avals_out)) + op_sharding = sharding_impls.SdyArrayShardingList(num_sdy_shardings * [ sharding_impls.SdyArraySharding( mesh_shape=(), dimension_shardings=[], @@ -287,7 +285,7 @@ def pure_callback( When `vmap`-ed the behavior will depend on the value of the ``vmap_method``. * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` - is deprecated and it will eventually raise ``NotImplementedError``. + raises a ``NotImplementedError``. * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over the batched arguments, calling ``callback`` once for each batch element. * ``vmap_method="sequential_unrolled"`` is like ``sequential``, but the loop @@ -297,9 +295,8 @@ def pure_callback( * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the inputs are tiled to the expected batched shape. - If necessary, the legacy behavior provided by the deprecated - ``vectorized=True`` argument can be recovered using - ``vmap_method="legacy_vectorized"``. + If necessary, the legacy behavior provided by the removed ``vectorized=True`` + argument can be recovered using ``vmap_method="legacy_vectorized"``. The current default behavior is to use ``vmap_method="sequential"`` when not specified, but this behavior is deprecated, and in the future, the @@ -366,20 +363,13 @@ def pure_callback( (4,) (4,) Array([1., 2., 3., 4.], dtype=float32) - .. _External Callbacks: https://jax.readthedocs.io/en/latest/external-callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/external-callbacks.html """ - if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: - deprecations.warn( - "jax-callback-vectorized", - "The vectorized argument of jax.pure_callback is deprecated and setting " - "it will soon raise an error. 
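# ---------------------------------------------------------------------------
# Hedged sketch of the replacement the check below points users to: pass
# ``vmap_method`` instead of the removed ``vectorized`` argument.
import numpy as np
import jax
import jax.numpy as jnp

def host_sin(x):
  return np.sin(x)  # runs on the host, with NumPy semantics

@jax.jit
def f(x):
  out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
  # was: jax.pure_callback(host_sin, out_type, x, vectorized=True)
  return jax.pure_callback(host_sin, out_type, x, vmap_method="legacy_vectorized")

jax.vmap(f)(jnp.linspace(0.0, 1.0, 8))
# ---------------------------------------------------------------------------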
To avoid an error in the future, and to " - "suppress this warning, please use the vmap_method argument instead.", - stacklevel=2) - if vmap_method is not None: - raise ValueError( - "the vectorized and vmap_method arguments of jax.pure_callback cannot " - "be used together. Please use the vmap_method argument.") - vmap_method = "legacy_vectorized" if vectorized else "sequential" + # TODO(danfm): Remove this check 3 months after v0.6.0 is released. + if not isinstance(vectorized, DeprecatedArg): + raise ValueError( + "The 'vectorized' argument of jax.pure_callback was removed in JAX " + "v0.6.0. Use 'vmap_method' instead.") allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims", "broadcast_all", "legacy_vectorized", None] if vmap_method not in allowed_vmap_methods: @@ -397,7 +387,6 @@ def pure_callback( callback=_FlatCallback(callback, in_tree), result_avals=tuple(flat_result_avals), sharding=sharding, - vectorized=vectorized, vmap_method=vmap_method, ) return tree_util.tree_unflatten(out_tree, out_flat) @@ -575,7 +564,7 @@ def io_callback( - :func:`jax.debug.callback`: callback designed for general-purpose debugging. - :func:`jax.debug.print`: callback designed for printing. - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html """ flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) tree_util.tree_map(_check_shape_dtype, result_shape_dtypes) @@ -592,7 +581,6 @@ def io_callback( return tree_util.tree_unflatten(out_tree, out_flat) - def is_empty_shape(s: core.Shape) -> bool: return any(d == 0 for d in s) @@ -666,6 +654,25 @@ def receive_from_host( return token, result + +def _aval_to_xla_shape(aval: core.AbstractValue) -> xc.Shape: + try: + return _xla_shape_handlers[type(aval)](aval) + except KeyError as err: + raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err + +_xla_shape_handlers: dict[type[core.AbstractValue], + Callable[[Any], xc.Shape]] = {} + +def _make_array_shape(aval: core.ShapedArray) -> xc.Shape: + aval = core.physical_aval(aval) + dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype + return xc.Shape.array_shape(dtype, aval.shape) +_xla_shape_handlers[core.ShapedArray] = _make_array_shape + +_xla_shape_handlers[core.AbstractToken] = lambda _: xc.Shape.token_shape() + + def _emit_tpu_python_callback( backend: xb.XlaBackend, ctx: mlir.LoweringRuleContext, @@ -695,8 +702,7 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined send_channel = ctx.module_context.new_channel() dummy_send_aval = core.ShapedArray((1,), np.float32) dummy_send_val = mlir.ir_constant(np.zeros(1, np.float32)) - operand_shapes = [*operand_shapes, - xla.aval_to_xla_shapes(dummy_send_aval)[0]] + operand_shapes = [*operand_shapes, _aval_to_xla_shape(dummy_send_aval)] token = send_to_host(send_channel, token, dummy_send_val, callback.__name__, sharding=sharding) send_channels.append(send_channel) @@ -747,22 +753,49 @@ def emit_python_callback( result_avals: Sequence[core.ShapedArray], *, has_side_effect: bool, + partitioned: bool = False, sharding: SdyArrayShardingList | xc.OpSharding | None = None, operand_layouts: Sequence[Sequence[int] | None] | None = None, result_layouts: Sequence[Sequence[int] | None] | None = None, ) -> tuple[Sequence[mlir.IrValues], Any, Any]: - """Emits MLIR that calls back to a provided Python function.""" + """Emits MLIR that calls back to a provided Python function. 
+ + Args: + ctx: The lowering context. + callback: The Python callback function. + token: The token to use for the callback. + operands: The operands to the callback. + operand_avals: The abstract values of the operands. + result_avals: The abstract values of the results. + has_side_effect: Whether the callback has side effects. + partitioned: If True, then `callback` is called on local shards only. If + False, then `callback` is called on all shards. + sharding: The sharding of the callback. + operand_layouts: The layouts of the operands. + result_layouts: The layouts of the results. + + Returns: + A tuple of MLIR result values, a new token (if any), and the host callback + object. + """ if len(ctx.module_context.platforms) > 1: raise NotImplementedError("multi-platform lowering for python_callback") platform = ctx.module_context.platforms[0] if platform not in {"cpu", "cuda", "rocm", "tpu"}: raise ValueError( f"`EmitPythonCallback` not supported on {platform} backend.") + if partitioned: + if platform not in {"cpu", "cuda", "rocm"}: + raise ValueError( + f"Partitioned callback not supported on {platform} backend.") + if jaxlib_extension_version < 329: + raise ValueError( + "Partitioned callback not supported on jaxlib version < 329.") + if result_avals: + raise ValueError("Partitioned callback not supported with return values.") backend = ctx.module_context.get_backend() - result_shapes = util.flatten( - [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) - operand_shapes = util.flatten( - [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals]) + result_shapes = [_aval_to_xla_shape(aval) for aval in result_avals] + operand_shapes = [_aval_to_xla_shape(aval) for aval in operand_avals] # Handling layouts if operand_layouts is None: operand_layouts = util.concatenate( @@ -822,55 +855,98 @@ def _wrapped_callback(*args): for result_aval in result_avals] return outputs, token, None - result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) - if token: + # TODO(dsuo): Remove this once we bump minimum_jaxlib_version to "0.5.4". 
+ if jaxlib_extension_version <= 320: + result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) + if token: + + callback_without_token = _wrapped_callback + def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined + return (token, *callback_without_token(*args)) + + operand_shapes = [ + _aval_to_xla_shape(core.abstract_token), *operand_shapes + ] + result_shapes = [ + _aval_to_xla_shape(core.abstract_token), *result_shapes + ] + operands = [token, *operands] + result_types = [mlir.token_type(), *result_types] + operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] + result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] + callback_descriptor, ifrt_callback = ( + backend.get_emit_python_callback_descriptor(_wrapped_callback, + operand_shapes, + result_shapes)) + ctx.module_context.add_host_callback(ifrt_callback) + descriptor_operand = mlir.ir_constant(callback_descriptor) + callback_operands = [descriptor_operand, *operands] + if operand_mlir_layouts is not None: + operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] + result_type = ir.TupleType.get_tuple(result_types) + call_target_name = ("xla_python_gpu_callback" + if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") + result = hlo.CustomCallOp( + [result_type], + callback_operands, + call_target_name=ir.StringAttr.get(call_target_name), + has_side_effect=ir.BoolAttr.get(has_side_effect), + api_version=mlir.i32_attr(2), + called_computations=ir.ArrayAttr.get([]), + backend_config=ir.StringAttr.get(str(callback_descriptor)), + operand_layouts=( + None if operand_mlir_layouts is None + else ir.ArrayAttr.get(operand_mlir_layouts)), + result_layouts=( + None if result_mlir_layouts is None + else ir.ArrayAttr.get(result_mlir_layouts))) + if sharding is not None: + mlir.set_sharding(result, sharding) + results = [ + hlo.get_tuple_element(result, mlir.i32_attr(i)) + for i in range(len(result_types)) + ] + else: + device = "gpu" if platform in {"cuda", "rocm"} else "cpu" + partition = "_partitioned" if partitioned else "" + call_target_name = f"xla_ffi{partition}_python_{device}_callback" + if token: + callback_without_token = _wrapped_callback + def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined + return (token, *callback_without_token(*args)) + operands = [token, *operands] + if ( + config.use_shardy_partitioner.value + and sharding is not None + and len(ctx.avals_out) > 0 + and isinstance(sharding, sharding_impls.SdyArrayShardingList) + ): + # Add a sharding annotation for the token if we have at least one + # output. Otherwise, the single shardy annotation required of all ops + # (even those without any results) can annotate the token. + sharding = sharding_impls.SdyArrayShardingList( + [*sharding.shardings, sharding.shardings[-1]] + ) + ctx = dataclasses.replace( + ctx, + avals_in=[core.abstract_token, *ctx.avals_in], + avals_out=[core.abstract_token, *ctx.avals_out], + ) - callback_without_token = _wrapped_callback - def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined - return (token, *callback_without_token(*args)) + # TODO(dsuo): Remove this line once we deprecate the XLA custom call + # handler. 
+ ifrt_callback = _wrapped_callback + ctx.module_context.add_host_callback(ifrt_callback) + index = np.uint64(len(ctx.module_context.host_callbacks) - 1) + result = ffi.build_ffi_lowering_function( # type: ignore + call_target_name, + has_side_effect=has_side_effect, + )(ctx, *operands, index=np.uint64(index)) - operand_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes - ] - result_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes - ] - operands = [token, *operands] - result_types = [mlir.token_type(), *result_types] - operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] - result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] - callback_descriptor, ifrt_callback = ( - backend.get_emit_python_callback_descriptor(_wrapped_callback, - operand_shapes, - result_shapes)) - ctx.module_context.add_host_callback(ifrt_callback) - descriptor_operand = mlir.ir_constant(callback_descriptor) - callback_operands = [descriptor_operand, *operands] - if operand_mlir_layouts is not None: - operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] - result_type = ir.TupleType.get_tuple(result_types) - call_target_name = ("xla_python_gpu_callback" - if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") - result = hlo.CustomCallOp( - [result_type], - callback_operands, - call_target_name=ir.StringAttr.get(call_target_name), - has_side_effect=ir.BoolAttr.get(has_side_effect), - api_version=mlir.i32_attr(2), - called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(str(callback_descriptor)), - operand_layouts=( - None if operand_mlir_layouts is None - else ir.ArrayAttr.get(operand_mlir_layouts)), - result_layouts=( - None if result_mlir_layouts is None - else ir.ArrayAttr.get(result_mlir_layouts))) - if sharding is not None: - mlir.set_sharding(result, sharding) - results = [ - hlo.get_tuple_element(result, mlir.i32_attr(i)) - for i in range(len(result_types)) - ] + if sharding is not None: + mlir.set_sharding(result, sharding) + + results = result.results # type: ignore if token: token, *results = results return results, token, ifrt_callback diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 1ec8ad50b456..959ff36881e1 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -600,7 +600,7 @@ def isnan(x): lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, - lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, + lax.reduce_p, lax.reduce_prod_p, lax.reduce_sum_p, lax.reduce_window_p, lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, @@ -913,14 +913,14 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, # Update pjit params to account for extra error values. 
  num_error_vals = len(err_vals)
   num_out_error_vals = out_tree.num_leaves - len(out_shardings)
-  sharding = sharding_impls.UNSPECIFIED
   new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
-  new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
   new_in_layouts = (*[None] * num_error_vals, *in_layouts)
-  new_out_layouts = (*[None] * num_out_error_vals, *out_layouts)
   new_donated_invars = (*[False] * num_error_vals, *donated_invars)
+  new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
+  new_out_layouts = (*[None] * num_out_error_vals, *out_layouts)
+
   err_and_out = pjit.pjit_p.bind(
       *new_vals_in,
       jaxpr=checked_jaxpr,
@@ -966,13 +966,15 @@ def shard_map_error_check(
   new_vals_in = [*err_vals, *vals_in]
   in_avals = list(map(core.get_aval, new_vals_in))
   auto = kwargs.get('auto')
+  check_rep = kwargs.get('check_rep')
   for i, v in enumerate(in_avals):
     if not (sharder := core.shard_aval_handlers.get(type(v))):
       raise ValueError(f'Unsupported aval type: {type(v)}')
-    in_avals[i] = sharder(mesh, auto, new_in_names[i], v)
+    in_avals[i] = sharder(mesh, auto, check_rep, new_in_names[i], v)
   with (shard_map._extend_axis_env(mesh, auto),
-        mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))):
+        mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto)),
+        config._check_rep(check_rep)):
     # jaxpr to checked_jaxpr
     checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
         pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
@@ -985,7 +987,7 @@ def expand_errors_leading_dim(*xs):
     errs = [lax.expand_dims(e, [0]) for e in errs]
     return *errs, *outs
-  with core.extend_axis_env_nd(mesh.shape.items()):
+  with core.extend_axis_env_nd(mesh.shape.items()), config._check_rep(check_rep):
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(expand_errors_leading_dim,
                      debug_info=checked_jaxpr.jaxpr.debug_info),
diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py
index 1274724b8ebd..23b2d68e11a5 100644
--- a/jax/_src/clusters/k8s_cluster.py
+++ b/jax/_src/clusters/k8s_cluster.py
@@ -34,16 +34,18 @@ def is_env_present(cls) -> bool:
     if 'KUBERNETES_SERVICE_HOST' in os.environ:
       try:
         import kubernetes as k8s  # pytype: disable=import-error
-      except ImportError as e:
-        warnings.warn(textwrap.fill(
-          "Kubernetes environment detected, but the `kubernetes` package is "
-          "not installed to enable automatic bootstrapping in this "
-          "environment. To enable automatic boostrapping, please install "
-          "jax with the [k8s] extra. For example:"
-          " pip install jax[k8s]"
-          " OR"
-          " pip install jax[k8s,]"
-        ))
+      except (ImportError, ModuleNotFoundError):
+        warnings.warn(
+          '\n'.join([
+            textwrap.fill(
+              "Kubernetes environment detected, but the `kubernetes` package "
+              "is not installed to enable automatic bootstrapping in this "
+              "environment. To enable automatic bootstrapping, please install "
+              "jax with the [k8s] extra. For example:"),
+            " pip install jax[k8s]",
+            " pip install jax[k8s,]",
+          ])
+        )
         return False
       k8s.config.load_incluster_config()
diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py
index f1b56adf3359..e8f7c9f7509c 100644
--- a/jax/_src/compilation_cache.py
+++ b/jax/_src/compilation_cache.py
@@ -275,7 +275,7 @@ def put_executable_and_time(
         f"PERSISTENT CACHE WRITE with key {cache_key}, this is unexpected because "
         "JAX_COMPILATION_CACHE_EXPECT_PGLE is set.
The execution that populated the " "cache may lack coverage, " - "https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html may " + "https://docs.jax.dev/en/latest/persistent_compilation_cache.html may " "help debug why this has happened") cache.put(cache_key, executable_and_time) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index dea532d13031..9ac47aa4f0ea 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -17,6 +17,8 @@ from __future__ import annotations from collections.abc import Sequence +import copy +from functools import partial import logging import time from typing import Any, Callable @@ -197,15 +199,6 @@ def get_compile_options( config.memory_fitting_level.value ).value - # This is a temporary workaround to simplify the AutoPGLE usage. - # TODO(b/376647494): Remove once the bug is fixed. - if ((config.enable_pgle.value and config.pgle_profiling_runs.value > 0) - or config.compilation_cache_expect_pgle.value): - logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.") - if env_options_overrides is None: - env_options_overrides = {} - env_options_overrides['xla_gpu_enable_command_buffer'] = '' - if env_options_overrides is not None: # Some overrides are passed directly on build_options. overrides_on_build_options = [ @@ -298,6 +291,8 @@ def backend_compile( options: xc.CompileOptions, host_callbacks: Sequence[Any], ) -> xc.LoadedExecutable: + sym_name = module.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value # Convert ir.Module to a string representation, unless the backend # explicitly flags the ability to handle a module directly (avoiding the # overhead of back and forth conversions). @@ -308,6 +303,14 @@ def backend_compile( else: built_c = module + if (options.executable_build_options.fdo_profile is not None + and len(options.executable_build_options.fdo_profile)): + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(options.executable_build_options.fdo_profile), + ) + try: # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results @@ -362,72 +365,31 @@ def compile_or_get_cached( if dumped_to := mlir.dump_module_to_file(computation, "compile"): logging.info("Dumped the module to %s.", dumped_to) - use_compilation_cache = compilation_cache.is_cache_used(backend) - is_multi_process = ( len({device.process_index for device in devices.flatten()}) > 1 ) min_device_process_id = min( devices.flatten(), key=lambda device: device.id ).process_index - is_auto_pgle_used = ( - config.enable_pgle.value and config.pgle_profiling_runs.value > 0 - ) - if not use_compilation_cache: - if ( - is_multi_process - and is_auto_pgle_used - and distributed.global_state.client is not None - ): - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) + # cache_key: may be None if compilation caching is disabled + cache_key, compile_options = _resolve_compilation_strategy( + computation, + devices, + compile_options, + backend, + pgle_profiler, + is_multi_process, + module_name, + min_device_process_id, + ) + if cache_key is None: return backend_compile(backend, computation, compile_options, host_callbacks) monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache') - try: - if config.remove_custom_partitioning_ptr_from_cache_key.value: - ignore_callbacks = 
cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING - else: - ignore_callbacks = cache_key_type.IgnoreCallbacks.NO - - cache_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - ignore_callbacks=ignore_callbacks, - ) - except xc._xla.XlaRuntimeError as ex: - logger.error("compile_or_get_cached: unable to generate cache key, " - "skipping the cache: %s", ex) - return backend_compile(backend, computation, compile_options, - host_callbacks) - - if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: - cache_key = _resolve_pgle_module_cache_key( - computation, - devices, - compile_options, - backend, - pgle_profiler, - is_multi_process, - cache_key, - module_name, - min_device_process_id, - ) - cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( module_name, cache_key, compile_options, backend) @@ -481,85 +443,130 @@ def compile_or_get_cached( # 1. PGLE optimized module (the one which was recompiled with FDO profile) is # in the persistent cache. In this case the module should be returned from # cache and PGLE should be disabled for this module. Is module is stored in -# the persistent cache under the "pgle_profiled_module_key" which calculated -# with replacing FDO profile with flag which identify that module were PGLE -# profiled. +# the persistent cache under the "pgle_optimized_cache_key", which is +# calculated by replacing the FDO profile with a sentinel value that identifies +# that the module was optimized with PGLE. # 2. PGLE profiled module is not in the persistent cache and the module is -# getting built with an FDO profile. In this case we need to share FDO profile -# with other processes and store the result under the -# "pgle_profiled_module_key" so later in case 1 we will be able to find the +# getting built with an FDO profile. In this case we need to share the FDO +# profile with any other processes and store the result under the +# "pgle_optimized_cache_key" so later in case 1 we will be able to find the # module. # 3. PGLE profiled module is not in the persistent cache and the module is # getting compiled to be PGLEd (FDO profile is empty). In this case we need to -# simply return the non-PGLE profiled module from the persistent cache. +# simply return the non-PGLE profiled module from the persistent cache if it +# exists, and otherwise compile it. # # If the compilation_cache_expect_pgle option is set then in case 1 the PGLE # optimized module will be loaded even if PGLE is not enabled in the current # process. This is useful if we want to combine the use of PGLE with other # profiling tools (e.g. Nsight Systems) that cannot co-exist with PGLE due to # contention for CUPTI resources. 
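# For orientation, a minimal sketch of the user-facing configuration that the
# caching strategy described above serves. The flag names are assumptions
# inferred from the config attributes referenced in this file (enable_pgle,
# pgle_profiling_runs, compilation_cache_expect_pgle); they are not defined by
# this patch.
import jax

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")  # persistent cache
jax.config.update("jax_enable_pgle", True)       # profile the first runs, then recompile with FDO
jax.config.update("jax_pgle_profiling_runs", 3)  # executions profiled before recompilation
# A later job that should reuse PGLE-optimized executables without re-profiling
# (e.g. when running under Nsight Systems) would instead set:
# jax.config.update("jax_compilation_cache_expect_pgle", True)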
-def _resolve_pgle_module_cache_key( +def _resolve_compilation_strategy( computation: ir.Module, devices: np.ndarray, compile_options: xc.CompileOptions, backend: xc.Client, pgle_profiler: profiler.PGLEProfiler | None, is_multi_process: bool, - cache_key: str, module_name: str, min_device_process_id: int, -) -> str: - fdo_profile = compile_options.executable_build_options.fdo_profile - compile_options.executable_build_options.fdo_profile = b"pgle profiled" - - pgle_profiled_module_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - cache_key_type.IgnoreCallbacks.ALL, +) -> tuple[str | None, xc.CompileOptions]: + is_auto_pgle_used = ( + config.enable_pgle.value and config.pgle_profiling_runs.value > 0 ) - compile_options.executable_build_options.fdo_profile = fdo_profile - - result_key = cache_key - if _is_executable_in_cache(backend, pgle_profiled_module_key): - # Load PGLE profiled module from the persistent cache. - result_key = pgle_profiled_module_key - if config.compilation_cache_expect_pgle.value: - logging.info(f"PGLE-optimized {module_name} loaded from compilation cache") - if pgle_profiler is not None: - pgle_profiler.disable() + + get_cache_key = partial(_get_cache_key, backend=backend, + computation=computation, devices=devices) + + if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: + # This can be None if cache key generation fails. + pgle_optimized_cache_key = get_cache_key(compile_options, + override_fdo_profile=b"pgle profiled") + # TODO(b/376647494): remove the workaround when the bug is fixed; the JAX + # profiler cannot collect sufficiently detailed profile data for PGLE if + # command buffers / CUDA graphs are enabled. Therefore disable command + # buffers when compiling for PGLE data collection, but not if AutoPGLE is + # not enabled, and not when re-compiling using PGLE data. This condition + # includes `compilation_cache_expect_pgle` so that slow-to-compile modules + # that are not executed often enough to trigger re-compilation will still + # be cached between an "enable_pgle" run and an "expect_pgle" run. + first_pass_compile_options = copy.deepcopy(compile_options) + first_pass_compile_options.env_option_overrides += [ + ("xla_gpu_enable_command_buffer", ""), + ] else: - # No PGLE-optimised module found in the persistent cache. 
- if (config.compilation_cache_expect_pgle.value - and _is_executable_in_cache(backend, cache_key)): - # The user asserted this miss was unexpected; emit a warning + pgle_optimized_cache_key = None + first_pass_compile_options = compile_options + + # This can be None if cache key generation fails or caching is disabled + cache_key = get_cache_key(first_pass_compile_options) + + if cache_key is not None and pgle_optimized_cache_key is not None: + # The compilation cache is enabled and AutoPGLE is enabled/expected + if _is_executable_in_cache(backend, pgle_optimized_cache_key): + if config.compilation_cache_expect_pgle.value: + logging.info(f"PGLE-optimized {module_name} loaded from compilation cache") + # No need to record N profiles in this case + if pgle_profiler is not None: + pgle_profiler.disable() + return pgle_optimized_cache_key, compile_options + elif (config.compilation_cache_expect_pgle.value + and _is_executable_in_cache(backend, cache_key)): + # No PGLE-optimized module found in the persistent cache, and the user + # asserted (expect_pgle) that this miss was unexpected warnings.warn(f"PERSISTENT CACHE MISS for PGLE-optimized {module_name} " "despite non-PGLE hit; it may not have been executed " "enough times when the cache was populated") - if fdo_profile is not None and len(fdo_profile) > 0: - # Store module under PGLE profiled module cache key. - result_key = pgle_profiled_module_key - if is_multi_process and distributed.global_state.client is not None: - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) - else: - compile_options.executable_build_options.fdo_profile = fdo_profile - logger.debug( - "Compiling module %s with FDO profile of length %d", - module_name, - len(compile_options.executable_build_options.fdo_profile), + + if (is_auto_pgle_used + and compile_options.executable_build_options.fdo_profile is not None + and len(compile_options.executable_build_options.fdo_profile)): + # Profile data are available to trigger a PGLE-optimized recompilation; + # store under `pgle_optimized_cache_key` if the cache is enabled + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = ( + _share_fdo_profiles( + computation, + devices, + compile_options, + backend, + distributed.global_state.client, + min_device_process_id, ) - return result_key + ) + return pgle_optimized_cache_key, compile_options + else: + # Compile for PGLE collection, store under `cache_key` if the cache is + # enabled. This is also the AutoPGLE-disabled path. 
+ return cache_key, first_pass_compile_options +def _get_cache_key( + options: xc.CompileOptions, + backend: xc.Client, + computation: ir.Module, + devices: np.ndarray, + override_fdo_profile: bytes | None = None) -> str | None: + if not compilation_cache.is_cache_used(backend): + return None + if config.remove_custom_partitioning_ptr_from_cache_key.value: + ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING + else: + ignore_callbacks = cache_key_type.IgnoreCallbacks.NO + if override_fdo_profile is not None: + options = copy.deepcopy(options) + options.executable_build_options.fdo_profile = override_fdo_profile + try: + return compilation_cache.get_cache_key( + computation, + devices, + options, + backend, + ignore_callbacks, + ) + except xc._xla.XlaRuntimeError as ex: + logger.error("compile_or_get_cached: unable to generate cache key, " + "skipping the cache: %s", ex) + return None # The process that has the lowest device ID should share FDO profile before # compilation with other processes. diff --git a/jax/_src/config.py b/jax/_src/config.py index cf6a07834a10..aec8b1450fd0 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -235,17 +235,20 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, use_direct_linearize.value, - varying_axes_in_types.value, softmax_custom_jvp.value, disable_jit.value, debug_key_reuse.value, jax_xla_profile_version.value, + _check_rep.value, # Technically this affects jaxpr->stablehlo lowering, not tracing. hlo_source_file_canonicalization_regex.value, pgle_profiling_runs.value, enable_pgle.value, use_shardy_partitioner.value, - use_high_dynamic_range_gumbel.value) + use_high_dynamic_range_gumbel.value, + error_checking_behavior_nan.value, + error_checking_behavior_divide.value, + error_checking_behavior_oob.value) config = Config() @@ -356,7 +359,7 @@ def __exit__(self, exc_type, exc_value, traceback): " This will be enabled by default in future versions of JAX, at which " "point all uses of the flag will be considered deprecated (following " "the `API compatibility policy " - "`_).") + "`_).") UPGRADE_BOOL_EXTRA_DESC = " (transient)" @@ -908,7 +911,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'The calling convention version number to use for exporting. This must be ' 'within the range of versions supported by the tf.XlaCallModule ' 'used in your deployment environment. ' - 'See https://jax.readthedocs.io/en/latest/export/shape_poly.html#calling-convention-versions.' + 'See https://docs.jax.dev/en/latest/export/shape_poly.html#calling-convention-versions.' ) ) @@ -917,7 +920,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: default=bool_env('JAX_EXPORT_IGNORE_FORWARD_COMPATIBILIY', False), help=( 'Whether to ignore the forward compatibility lowering rules. ' - 'See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.' + 'See https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.' ) ) @@ -995,7 +998,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: name='jax_explain_cache_misses', default=False, help=('Each time there is a miss on one of the main caches (e.g. the ' - 'tracing cache), log an explanation.. Logging is performed with ' + 'tracing cache), log an explanation. Logging is performed with ' '`logging`. 
When this option is set, the log level is WARNING; ' 'otherwise the level is DEBUG.')) @@ -1017,13 +1020,13 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: name='jax_spmd_mode', enum_values=['allow_all', 'allow_jit'], default='allow_jit', - help=("Decides whether Math on `jax.Array`'s that are not fully addressable " - "(i.e. spans across multiple processes) is allowed. The options are: " - "* allow_jit: Default, `pjit` and `jax.jit` computations are allowed " - " to execute on non-fully addressable `jax.Array`s\n" - "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, " - " `jax.jit` and all other operations are allowed to " - " execute on non-fully addressable `jax.Array`s.")) + help=("Decides whether Math on ``jax.Array`` objects that are not fully addressable " + "(i.e. spans across multiple processes) is allowed. The options are:\n\n" + "* ``allow_jit``: Default, ``pjit`` and ``jax.jit`` computations are allowed " + " to execute on non-fully addressable ``jax.Array`` objects\n" + "* ``allow_all``: ``jnp``, normal math (like ``a + b``, etc), ``pjit``, " + " ``jax.jit`` and all other operations are allowed to " + " execute on non-fully addressable ``jax.Array`` objects.")) distributed_debug = bool_state( @@ -1088,20 +1091,13 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: help=('Use direct linearization instead JVP followed by partial eval'), include_in_jit_key=True) -varying_axes_in_types = bool_state( - name='jax_varying_axes_in_types', +# TODO make it so people don't use this, this is internal... +_check_rep = bool_state( + name='check_rep', default=False, - help=('Adds varying manual axes to ShapedArray to track which mesh axes the' - ' array is varying over. This will help to remove the efficient' - ' transpose rewrite machinery in shard_map'), + help='internal implementation detail of shard_map, DO NOT USE', include_in_jit_key=True) -data_dependent_tracing_fallback = bool_state( - name='jax_data_dependent_tracing_fallback', - default=False, - help=('When True, falls back to trace dispatch based on data dependence ' - 'instead of throwing an escaped tracer error.')) - softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', default=False, @@ -1317,6 +1313,41 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: ), ) +# TODO(ayx): Move these 3 flags out of config once we have a user-level +# extension mechanism for adding contexts to which the jit cache is sensitive. +error_checking_behavior_nan = enum_state( + name='jax_error_checking_behavior_nan', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a NaN is encountered. Options are "ignore"' + ' or "raise".' + ), + include_in_jit_key=True, +) + +error_checking_behavior_divide = enum_state( + name='jax_error_checking_behavior_divide', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a divide by zero is encountered. Options are' + ' "ignore" or "raise".' + ), + include_in_jit_key=True, +) + +error_checking_behavior_oob = enum_state( + name='jax_error_checking_behavior_oob', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when an out of bounds access is encountered.' + ' Options are "ignore" or "raise".' 
+ ), + include_in_jit_key=True, +) + def _update_x64_global(val): jax_jit.global_state().enable_x64 = val @@ -1437,18 +1468,17 @@ def _update_disable_jit_thread_local(val): enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames", "auto"], default="auto", - help="Controls how JAX filters internal frames out of tracebacks.\n\n" - "Valid values are:\n" - " * \"off\": disables traceback filtering.\n" - " * \"auto\": use \"tracebackhide\" if running under a sufficiently" - " new IPython, or \"remove_frames\" otherwise.\n" - " * \"tracebackhide\": adds \"__tracebackhide__\" annotations to" - " hidden stack frames, which some traceback printers support.\n" - " * \"remove_frames\": removes hidden frames from tracebacks, and adds" - " the unfiltered traceback as a __cause__ of the exception.\n" - " * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds" - " a brief message (to the __cause__ of the exception) describing that this has" - " happened.\n") + help="Controls how JAX filters internal frames out of tracebacks. Valid values are:\n" + "- ``off``: disables traceback filtering.\n" + "- ``auto``: use ``tracebackhide`` if running under a sufficiently " + "new IPython, or ``remove_frames`` otherwise.\n" + "- ``tracebackhide``: adds ``__tracebackhide__`` annotations to " + "hidden stack frames, which some traceback printers support.\n" + "- ``remove_frames``: removes hidden frames from tracebacks, and adds " + "the unfiltered traceback as a ``__cause__`` of the exception.\n" + "- ``quiet_remove_frames``: removes hidden frames from tracebacks, and adds " + "a brief message (to the ``__cause__`` of the exception) describing that this has " + "happened.\n\n") # This flag is for internal use. # TODO(tianjianlu): Removes once we always enable cusparse lowering. @@ -1474,13 +1504,6 @@ def _update_disable_jit_thread_local(val): help=('Attempt constant folding during staging.'), include_in_jit_key=True) -# This flag is temporary during rollout of the remat barrier. -# TODO(parkers): Remove if there are no complaints. -remat_opt_barrier = bool_state( - name='jax_remat_opt_barrier', - default=True, - help=('Enables using optimization-barrier op for lowering remat.')) - enable_remat_opt_pass = bool_state( name='jax_compiler_enable_remat_pass', default=True, @@ -1489,13 +1512,6 @@ def _update_disable_jit_thread_local(val): 'compute when encountering OOM errors. However, you are ' 'likely to get better results manually with jax.checkpoint')) -# TODO(sharadmv,mattjj): set default to True, then remove -eager_pmap = bool_state( - name='jax_eager_pmap', - default=True, - upgrade=True, - help='Enable eager-mode pmap when jax_disable_jit is activated.') - no_tracing = bool_state( name='jax_no_tracing', default=False, @@ -1650,7 +1666,7 @@ def transfer_guard(new_val: str) -> Iterator[None]: """A contextmanager to control the transfer guard level for all transfers. For more information, see - https://jax.readthedocs.io/en/latest/transfer_guard.html + https://docs.jax.dev/en/latest/transfer_guard.html Args: new_val: The new thread-local transfer guard level for all transfers. @@ -1685,16 +1701,17 @@ def _update_garbage_collection_guard(state, key, val): # The default is applied by guard_lib. default=None, help=( - 'Select garbage collection guard level for "jax.Array" objects.\nThis' - ' option can be used to control what happens when a "jax.Array"' - ' object is garbage collected. 
It is desirable for "jax.Array"' - ' objects to be freed by Python reference couting rather than garbage' + 'Select garbage collection guard level for ``jax.Array`` objects.\n\n' + 'This option can be used to control what happens when a ``jax.Array``' + ' object is garbage collected. It is desirable for ``jax.Array``' + ' objects to be freed by Python reference counting rather than garbage' ' collection in order to avoid device memory being held by the arrays' - ' until garbage collection occurs.\n\nValid values are:\n * "allow":' - ' do not log garbage collection of "jax.Array" objects.\n * "log":' - ' log an error when a "jax.Array" is garbage collected.\n * "fatal":' - ' fatal error if a "jax.Array" is garbage collected.\nDefault is' - ' "allow". Note that not all cycles may be detected.' + ' until garbage collection occurs.\n\n' + 'Valid values are:\n\n' + '* ``allow``: do not log garbage collection of ``jax.Array`` objects.\n' + '* ``log``: log an error when a ``jax.Array`` is garbage collected.\n' + '* ``fatal``: fatal error if a ``jax.Array`` is garbage collected.\n\n' + 'Default is ``allow``. Note that not all cycles may be detected.' ), update_global_hook=lambda val: _update_garbage_collection_guard( guard_lib.global_state(), 'garbage_collect_array', val diff --git a/jax/_src/core.py b/jax/_src/core.py index 36ce2f004ed4..013219021e05 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -45,7 +45,7 @@ ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, TracerIntegerConversionError, UnexpectedTracerError) from jax._src import linear_util as lu - +from jax._src.tree_util import tree_flatten, tree_unflatten from jax._src import source_info_util from jax._src.util import (safe_zip, safe_map, curry, tuple_insert, tuple_delete, cache, @@ -320,7 +320,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): config.compute_on_context_manager.set_local(self.prev_compute_type) config.threefry_partitionable.set_local(self.prev_threefry_partitionable) - if self.context.xla_metadata is not None: + if self.context.xla_metadata: config.xla_metadata_context_manager.set_local(self.prev_xla_metadata) config.abstract_mesh_context_manager.set_local(self.prev_abstract_mesh) @@ -412,7 +412,6 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, _var_counter = it.count() -@total_ordering class Var: __slots__ = ["count", "suffix", "aval"] @@ -425,11 +424,6 @@ def __init__(self, suffix: str, aval: AbstractValue): self.suffix = suffix self.aval = aval - # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not - # care about variable ordering, but the downstream package kfac_jax does. 
- def __lt__(self, other): - return self.count < other.count - def __repr__(self): return f'Var(id={id(self)}){self.suffix}:{self.aval.str_short()}' @@ -503,9 +497,7 @@ def bind(self, *args, **params): def _true_bind(self, *args, **params): for arg in args: - if (isinstance(arg, Tracer) - and not arg._trace.is_valid() - and not config.data_dependent_tracing_fallback.value): + if isinstance(arg, Tracer) and not arg._trace.is_valid(): raise escaped_tracer_error(arg) # TODO: figure out how to handle function arguments # assert (not config.enable_checks.value or @@ -574,7 +566,7 @@ def read(v: Atom) -> Any: def write(v: Var, val: Any) -> None: if config.enable_checks.value and not config.dynamic_shapes.value: - assert typecheck(v.aval, val), (v.aval, val) + assert typecheck(v.aval, val), (v.aval, get_aval(val)) env[v] = val env: dict[Var, Any] = {} @@ -1021,10 +1013,6 @@ def process_primitive(self, primitive, args, params): else: # TODO(dougalm): delete. this shouldn't be necessary args = map(full_lower, args) - if config.data_dependent_tracing_fallback.value: - for arg in args: - if isinstance(arg, Tracer): - return primitive.bind_with_trace(arg._trace, args, params) check_eval_args(args) return primitive.impl(*args, **params) @@ -1502,11 +1490,6 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: for v in jaxpr.invars] return tuple(out) -# TODO(dougalm): Deprecate. This is here for backwards compat. -def lattice_join(x, y): - assert typematch(x, y) - return x - # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any @@ -1528,7 +1511,7 @@ def check_valid_jaxtype(x): def update_aval_with_sharding(aval, sharding): if isinstance(sharding, NamedSharding): - aval = aval.update(sharding=NamedSharding( + return aval.update(sharding=NamedSharding( sharding.mesh.abstract_mesh, sharding.spec._normalized_spec_for_aval(aval.ndim))) return aval @@ -1660,7 +1643,8 @@ def physical_aval(aval): if isinstance(aval, ShapedArray): from jax._src.sharding_impls import physical_sharding # type: ignore return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype, - sharding=physical_sharding(aval, aval.sharding)) + sharding=physical_sharding(aval, aval.sharding), + vma=aval.vma) return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) return aval @@ -1786,6 +1770,10 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) +class ShardingTypeError(Exception): + pass + + # TODO(dougalm): Cast scalar, numpy arrays, etc to jax arrays so that values # passed to primitives are always have avals, etc i.e. they are canonical. def canonicalize_value(val): @@ -1894,28 +1882,40 @@ def get_sharding(sharding, shape): raise ValueError("Mesh of an aval must be an AbstractMesh. 
" f"Got {out_s.mesh} of type {type(out_s.mesh)}") _check_divisibility(out_s, shape) + assert out_s.memory_kind is None return out_s -def str_short_aval(shape, dtype, mesh, spec, short_dtypes=False, - mesh_axis_types=False) -> str: +def str_short_aval(shape, dtype, mesh, spec, vma, + short_dtypes=False, mesh_axis_types=False) -> str: dt_str = dtypes.short_dtype_name(dtype) if short_dtypes else dtype.name dt_str = dt_str.replace('void', 'float0') shapestr = _get_shape_sharding_str(shape, spec) mesh_axes = f'({mesh._axis_types_dict})' if mesh_axis_types else '' - return f'{dt_str}[{shapestr}]{mesh_axes}' + vma = f"{{{','.join(i for i in vma)}}}" if vma else '' + return f'{dt_str}[{shapestr}]{vma}{mesh_axes}' + +def get_vma(vma, mesh): + if mesh.empty: + return vma + for i in vma: + if mesh._name_to_type[i] != AxisType.Manual: + raise ValueError( + "Axes mentioned in `vma` field of ShapedArray should" + f" be of type `Manual`. Got axis: {i} of type {mesh._name_to_type[i]}") + assert isinstance(vma, frozenset) + return vma class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent + __slots__ = ['shape', 'sharding', 'vma'] # inherits slots from parent array_abstraction_level = 2 def __init__(self, shape, dtype, weak_type=False, *, sharding=None, - varying_manual_axes: frozenset[AxisName] = frozenset()): + vma: frozenset[AxisName] = frozenset()): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type self.sharding = get_sharding(sharding, self.shape) - if config.varying_axes_in_types.value: - self.varying_manual_axes = varying_manual_axes + self.vma = get_vma(vma, self.sharding.mesh) def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1926,9 +1926,8 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs): weak_type = self.weak_type if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding - if 'varying_manual_axes' not in kwargs: - kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes', - frozenset()) + if 'vma' not in kwargs: + kwargs['vma'] = self.vma return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -1946,26 +1945,24 @@ def __eq__(self, other): and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type and self.sharding == other.sharding - and (getattr(self, 'varying_manual_axes', frozenset()) == - getattr(other, 'varying_manual_axes', frozenset()))) + and self.vma == other.vma) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. 
`np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) return hash((self.shape, self.dtype, self.weak_type, self.sharding, - getattr(self, 'varying_manual_axes', frozenset()))) + self.vma)) def to_tangent_aval(self): return ShapedArray( self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, sharding=self.sharding, - varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset())) + self.weak_type, sharding=self.sharding, vma=self.vma) def str_short(self, short_dtypes=False, mesh_axis_types=False): return str_short_aval( self.shape, self.dtype, self.sharding.mesh, self.sharding.spec, - short_dtypes, mesh_axis_types) + self.vma, short_dtypes, mesh_axis_types) def _len(self, ignored_tracer): try: @@ -1996,6 +1993,63 @@ def primal_dtype_to_tangent_dtype(primal_dtype): return primal_dtype +def pvary(x, axis_name): + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if not axis_name: + return x + xs, treedef = tree_flatten(x) + ys = pvary_p.bind(*xs, axes=axes, axis_index_groups=None) + return tree_unflatten(treedef, ys) + +pvary_p = Primitive('pvary') +pvary_p.multiple_results = True +pvary_p.def_impl(lambda *args, axes, axis_index_groups: args) + +def _pvary_abstract_eval(*args, axes, axis_index_groups): + if not config._check_rep.value: + return args + assert isinstance(axes, tuple) + arg_vma = [a.vma for a in args] + # If there is intersection between arg_vma and axes, error + if any(set(axes) & a for a in arg_vma): + raise ValueError( + "Collective pvary must be applied to a " + f"non-device-varying type, but got {arg_vma} for collective acting " + f"over axis name {axes}. Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + sharding = NamedSharding(mesh_lib.get_abstract_mesh(), P()) + return [a.update(sharding=sharding, vma=a.vma.union(frozenset(axes))) + for a in args] +pvary_p.def_abstract_eval(_pvary_abstract_eval) + + +def standard_insert_pvary(*args): + if not config._check_rep.value: + return args + if not args: + return args + in_vma = [frozenset() if (aval := get_aval(a)) is abstract_token + else aval.vma for a in args] + out_vma = frozenset.union(*in_vma) + return [pvary(arg, tuple(n for n in out_vma if n not in src)) + if out_vma - src else arg for arg, src in zip(args, in_vma)] + +def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: + if not config._check_rep.value: + return frozenset() + avals = tuple(a for a in avals if a is not abstract_token) + if not avals: + return frozenset() + vma, *vmas = [a.vma for a in avals] + if not all(vma == vma_ for vma_ in vmas): + raise ValueError( + f'Primitive {prim_name} requires varying manual axes ' + f'to match, but got {[vma, *vmas]}. Please open an issue at ' + 'https://github.com/jax-ml/jax/issues and as a temporary ' + 'workaround pass the check_rep=False argument to shard_map') + return vma + # Dynamic shape stuff below here! We keep the abstract values distinct just so # as not to interfere with any static shape machinery. 
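# The pvary/vma errors added above direct users to shard_map's check_rep escape
# hatch. A hedged sketch of that workaround follows; the shard_map signature is
# assumed from the public jax.experimental.shard_map API rather than from this
# patch.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ("x",))
f = shard_map(lambda a: a * 2, mesh=mesh, in_specs=P("x"), out_specs=P("x"),
              check_rep=False)  # disables the varying-manual-axes (vma) checks
f(jnp.arange(2.0 * len(jax.devices())))  # one shard per device along "x"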
@@ -2042,6 +2096,10 @@ def update(self, shape=None, dtype=None, weak_type=None): def sharding(self): return NamedSharding(mesh_lib.empty_abstract_mesh, P()) + @property + def vma(self): + return frozenset() + def _len(self, tracer): return self.shape[0] @@ -2146,6 +2204,7 @@ def __init__(self, aval, buf): def __getitem__(self, idx): return self._aval._getitem(self, idx) def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) + def __len__(self) -> int: return self._aval._len(self) pytype_aval_mappings[MutableArray] = lambda x: x._aval def mutable_array(init_val): @@ -2203,16 +2262,10 @@ def block_until_ready(self): pytype_aval_mappings[Token] = lambda _: abstract_token -# TODO(dougalm): Deprecate these. They're just here for backwards compat. -def raise_to_shaped(aval): - return aval -raise_to_shaped_mappings: dict[type, Callable] = {} - ### Operations on shapes and dimension sizes. class InconclusiveDimensionOperation(Exception): """Raised when we cannot conclusively compute with symbolic dimensions.""" - pass def is_symbolic_dim(v: Any) -> bool: """Checks if a value is a symbolic dimension used for shape polymorphism. @@ -2547,12 +2600,12 @@ def unmapped_aval(size: AxisSize, axis: int | None, def _map_shaped_array( size: int, axis: int | None, aval: ShapedArray) -> ShapedArray: - assert axis is None or aval.shape[axis] == size - # TODO: Extend the named shape - if axis is None: return aval + # assert axis is None or aval.shape[axis] == size + if axis is None: + return aval sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) + weak_type=aval.weak_type, sharding=sharding, vma=aval.vma) def _unmap_shaped_array( size: int, axis: int | None, explicit_mesh_axis, aval: ShapedArray @@ -2562,7 +2615,8 @@ def _unmap_shaped_array( sharding = aval.sharding.with_spec(tuple_insert( aval.sharding.spec, axis, explicit_mesh_axis)) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) + weak_type=aval.weak_type, sharding=sharding, + vma=aval.vma) else: raise TypeError(axis) def _map_dshaped_array( @@ -2674,7 +2728,8 @@ def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: # could try normalizing first and then doing simple equality. # TODO(yashkatariya): Also check `sharding` here. 
# See https://github.com/jax-ml/jax/issues/26474 - return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + return (t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + and t1.vma == t2.vma) # type: ignore else: return False diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index c7e7c83f30f8..463665f6fdca 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -24,12 +24,11 @@ from jax._src import dispatch from jax._src.custom_partitioning import custom_partitioning from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.lib import cuda_versions from jax._src import xla_bridge -from jax.interpreters import mlir -from jax.interpreters import xla -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec @@ -368,8 +367,9 @@ def check_is_flash_attention( f"Unsupported sequence length Q {T}, KV {S}." ) - if is_packed and cudnn_version < 90600: - raise NotImplementedError("Packed layout requires cudnn version >= 9.6.") + if is_packed and (cudnn_version < 90600 or not check_compute_capability("9.0")): + raise NotImplementedError( + "Packed layout requires cudnn version >= 9.6 and at least hopper arch.") def check_cudnn_version(): # check if cuDNN is installed @@ -396,7 +396,7 @@ def is_cuda_compute_capability_equal(capability): def _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, cudnn_version): + sliding_window_length, cudnn_version, return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, False, @@ -405,14 +405,16 @@ def _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, - sliding_window_length=sliding_window_length, is_training=False) - output = outputs[0] - return output + sliding_window_length=sliding_window_length, is_training=False or return_residual) + if return_residual: + return outputs + else: + return outputs[0] def _dot_product_attention_fwd_rule( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, cudnn_version): + sliding_window_length, cudnn_version, return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, True, @@ -424,11 +426,14 @@ def _dot_product_attention_fwd_rule( sliding_window_length=sliding_window_length, is_training=True) res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, outputs[1], outputs[0]) - return outputs[0], res + if return_residual: + return outputs, res + else: + return outputs[0], res def _dot_product_attention_bwd_rule( scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training, res, grad_output): + sliding_window_length, is_training, return_residual, res, grad_output): (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, 
activation, fwd_output) = res grads = _dot_product_attention_bwd_p_wrapper.bind( @@ -1018,7 +1023,7 @@ def sharded_impl(*args): _dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd") _dot_product_attention_fwd_p.multiple_results = True _dot_product_attention_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fwd_p) ) _dot_product_attention_fwd_p.def_abstract_eval( _dot_product_attention_fwd_abstract @@ -1043,7 +1048,7 @@ def sharded_impl(*args): _dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd") _dot_product_attention_bwd_p.multiple_results = True _dot_product_attention_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_bwd_p) ) _dot_product_attention_bwd_p.def_abstract_eval( _dot_product_attention_bwd_abstract @@ -1098,7 +1103,7 @@ def sharded_impl(*args): _dot_product_attention_bwd_p_wrapper ) -@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15)) +@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16)) def _dot_product_attention(query: Array, key: Array, value: Array, @@ -1114,13 +1119,14 @@ def _dot_product_attention(query: Array, mask_type: bool, layout: int, sliding_window_length: int | None, - cudnn_version: int): + cudnn_version: int, + return_residual: bool): output = _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, - cudnn_version=cudnn_version) + cudnn_version=cudnn_version, return_residual=return_residual) return output _dot_product_attention.defvjp( @@ -1604,7 +1610,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_fwd_p = core.Primitive("dot_product_attention_fp8_fwd") _dot_product_attention_fp8_fwd_p.multiple_results = True _dot_product_attention_fp8_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_fwd_p) ) _dot_product_attention_fp8_fwd_p.def_abstract_eval( _dot_product_attention_fp8_fwd_abstract @@ -1629,7 +1635,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_bwd_p = core.Primitive("dot_product_attention_fp8_bwd") _dot_product_attention_fp8_bwd_p.multiple_results = True _dot_product_attention_fp8_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_bwd_p) ) _dot_product_attention_fp8_bwd_p.def_abstract_eval( _dot_product_attention_fp8_bwd_abstract @@ -1720,7 +1726,8 @@ def dot_product_attention( dropout_rate: float = 0., qkv_layout: str = "BTNH", sliding_window_length: int | None = None, - use_fp8: bool = False + use_fp8: bool = False, + return_residual: bool = False ): """Computes dot-product attention given query (Q), key (K), and value (V). @@ -1776,8 +1783,12 @@ def dot_product_attention( is the index of each token. E.g., if sliding_window_length == 3 and the sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. use_fp8: Whether to use FP8 attention mechanism. + return_residual: Whether to return the logsumexp tensor of shape BTN + or BNT to users. 
See section 3.1.1 in the FlashAttention-2 paper: + https://arxiv.org/pdf/2307.08691 to find the definition of logsumexp. Returns: - Output of the same shape as the query. + output: the same shape as the query. + residual: the logsumexp tensor if return_residual=True. (non fp8) amax_s: amax of state. (fp8 only) amax_o: amax of output. (fp8 only) """ @@ -1851,5 +1862,5 @@ def dot_product_attention( output = _dot_product_attention( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type, layout.value, - sliding_window_length, cudnn_version) + sliding_window_length, cudnn_version, return_residual) return output diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py index f320672463cb..355b33e1509c 100644 --- a/jax/_src/cudnn/fusion.py +++ b/jax/_src/cudnn/fusion.py @@ -16,8 +16,8 @@ import jax from jax._src import core as jax_core from jax.interpreters import mlir -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 1a8dee293082..6766e3992202 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -492,13 +492,14 @@ def quantize(x, config): elif config.mode == "nvfp4": assert config.scale_type == jnp.float8_e4m3fn assert config.global_scale.dtype == jnp.float32 + SCALE_MAX = jnp.finfo(config.scale_type).max.astype(x.dtype) - scales = scales / config.global_scale - scales_q = jax.lax.optimization_barrier(scales.astype(jnp.float8_e4m3fn)) - scaled_x = x / (scales_q.astype(jnp.float32) * - config.global_scale).astype(x.dtype) + scales_q = jnp.clip(scales / config.global_scale, 0, SCALE_MAX) + scales_q = jax.lax.optimization_barrier(scales_q.astype(config.scale_type)) + scaled_x = x / scales_q.astype(jnp.float32) else: raise ValueError(f"Unrecognized mode: {config.mode}.") + clipped_x = jnp.clip(scaled_x, -MAX, MAX) x_q = clipped_x.astype(config.data_type) @@ -639,6 +640,17 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): } grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) + + # We apply a Straight-Through Estimator (STE) with zero-out behavior: if + # inputs are clipped during quantization in fprop, their corresponding gradients + # are zeroed out; otherwise, they pass through unchanged. + if configs[2].mode == "nvfp4": + assert rhs.dtype == lhs.dtype + MAX = jnp.finfo(configs[0].data_type).max.astype(lhs.dtype) + SCALE_MAX = jnp.finfo(configs[0].scale_type).max.astype(lhs.dtype) + grad_lhs = jnp.where(jnp.abs(lhs) <= configs[0].global_scale * MAX * SCALE_MAX, grad_lhs, 0) + grad_rhs = jnp.where(jnp.abs(rhs) <= configs[1].global_scale * MAX * SCALE_MAX, grad_rhs, 0) + return (grad_lhs, grad_rhs) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 32856106ad8f..a8f136477bd9 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -130,7 +130,7 @@ def f_jvp(primals, tangents): For a more detailed introduction, see the tutorial_. - .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html + .. 
_tutorial: https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html """ fun: Callable[..., ReturnValue] nondiff_argnums: Sequence[int] @@ -521,7 +521,7 @@ def f_bwd(res, g): For a more detailed introduction, see the tutorial_. - .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html + .. _tutorial: https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html """ def __init__(self, @@ -1424,53 +1424,65 @@ def linear_call(fun: Callable, (residual_args, linear_args), {})), t_in_tree) - t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals)) - t_jaxpr_closed = _close_jaxpr(t_jaxpr) - - if t_out_tree() != lin_tree: - raise TypeError( - 'transpose output pytree structure must match that of linear inputs, ' - f'got output structure {t_out_tree()} ' - f'and input structure {lin_tree}.') + @pe._memoize + def transpose_thunk(): + t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals)) + if t_out_tree() != lin_tree: + raise TypeError( + 'transpose output pytree structure must match that of linear inputs, ' + f'got output structure {t_out_tree()} ' + f'and input structure {lin_tree}.') + return _close_jaxpr(t_jaxpr), t_consts - out = linear_call_p.bind(*f_consts, *t_consts, *operands_res, *operands_lin, + out = linear_call_p.bind(*f_consts, *operands_res, *operands_lin, callee=f_jaxpr_closed, - transpose=t_jaxpr_closed, + transpose_thunk=transpose_thunk, num_callee_consts=len(f_consts), - num_transpose_consts=len(t_consts), num_res=len(operands_res)) return tree_unflatten(out_tree(), out) -def _linear_call_impl(*args, callee, transpose, num_callee_consts, - num_transpose_consts, num_res): - del transpose - consts, _, operands_res, operands_lin = split_list( - args, [num_callee_consts, num_transpose_consts, num_res]) - return core.eval_jaxpr(callee.jaxpr, (), *consts, *operands_res, *operands_lin) - -def _linear_call_transpose_rule(cts, *args, callee, transpose, - num_callee_consts, - num_transpose_consts, num_res): - f_consts, t_consts, operands_res, operands_lin = split_list( - args, [num_callee_consts, num_transpose_consts, num_res]) +def _linear_call_impl(*args, callee, transpose_thunk, num_callee_consts, + num_res): + del transpose_thunk, num_callee_consts, num_res + return core.eval_jaxpr(callee.jaxpr, (), *args) + +def _linear_call_jvp_rule(primals, tangents, callee, transpose_thunk, + num_callee_consts, num_res): + consts_and_res, primals = split_list(primals, [num_callee_consts + num_res]) + const_tangents, tangents = split_list(tangents, [num_callee_consts + num_res]) + assert all(type(t) is Zero for t in const_tangents) + primals_out = linear_call_p.bind( + *consts_and_res, *primals, callee=callee, transpose_thunk=transpose_thunk, + num_callee_consts=num_callee_consts, num_res=num_res) + tangents_out = linear_call_p.bind( + *consts_and_res, *tangents, callee=callee, transpose_thunk=transpose_thunk, + num_callee_consts=num_callee_consts, num_res=num_res) + return primals_out, tangents_out + +def _linear_call_transpose_rule(cts, *args, callee, transpose_thunk, + num_callee_consts, num_res): + transpose, t_consts = transpose_thunk() + f_consts, operands_res, operands_lin = split_list( + args, [num_callee_consts, num_res]) _, _, cts_avals = split_list( - transpose.in_avals, [num_transpose_consts, num_res]) + transpose.in_avals, [len(t_consts), num_res]) assert all(ad.is_undefined_primal(x) for x in operands_lin) assert all(not ad.is_undefined_primal(x) for x in 
operands_res) + def new_transpose_thunk(): + return callee, f_consts + cts = [zeros_like_aval(a) if type(ct) is Zero else ct for ct, a in zip(cts, cts_avals)] - - cts_out = linear_call_p.bind(*t_consts, *f_consts, *operands_res, *cts, + cts_out = linear_call_p.bind(*t_consts, *operands_res, *cts, callee=transpose, - transpose=callee, + transpose_thunk=new_transpose_thunk, num_callee_consts=len(t_consts), - num_transpose_consts=len(f_consts), num_res=len(operands_res)) - return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out + return [None] * (num_callee_consts + num_res) + cts_out def _linear_call_abstract_eval(*args, **kwargs): return kwargs['callee'].out_avals @@ -1479,6 +1491,7 @@ def _linear_call_abstract_eval(*args, **kwargs): linear_call_p.multiple_results = True linear_call_p.def_impl(_linear_call_impl) linear_call_p.def_abstract_eval(_linear_call_abstract_eval) +ad.primitive_jvps[linear_call_p] = _linear_call_jvp_rule ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule xla.register_initial_style_primitive(linear_call_p) mlir.register_lowering(linear_call_p, mlir.lower_fun( diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 5374071517f1..feb1e0c39cc6 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -500,6 +500,12 @@ def __call__(self, *args, **kwargs): infer_sharding_from_operands = None sharding_rule = None if config.use_shardy_partitioner.value: + if (self.sharding_rule is None and + (self.propagate_user_sharding is not None or + self.infer_sharding_from_operands is not None)): + raise ValueError("Shardy is used, but sharding propagation callbacks " + "instead of sharding_rule are provided. Need to " + "provide sharding_rule to migrate to Shardy.") sharding_rule = self.sharding_rule else: propagate_user_sharding = self.propagate_user_sharding diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 5e87fdb203c9..21e607b5bff2 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -177,15 +177,19 @@ def bind_with_trace(self, trace, call_args, params): # TODO(frostig,mattjj): consider keeping `call` as a named parameter # instead of following this "call primitive" convention. 
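# --- Editor's aside (illustrative sketch, not part of the patch) -----------
# Usage sketch for the `linear_call` changes above (deferred transpose via
# `transpose_thunk`, plus the new JVP rule). Assumes the semi-public
# `jax.custom_derivatives.linear_call` export; the printed value is what the
# rules above should produce.
import jax
import jax.numpy as jnp
from jax.custom_derivatives import linear_call

def scale(c, x):      # linear in x; c is a residual (non-linear) operand
  return c * x

def scale_t(c, ct):   # hand-written transpose w.r.t. the linear operand
  return c * ct

def f(c, x):
  return linear_call(scale, scale_t, c, x)

x = jnp.arange(4.0)
# grad goes through the new JVP rule, then the custom transpose rule.
print(jax.grad(lambda x: f(3.0, x).sum())(x))  # [3. 3. 3. 3.]
# ---------------------------------------------------------------------------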
def get_bind_params(self, params): - assert 'call_jaxpr' in params - assert 'transpose_jaxpr_thunk' in params - new_params: dict[str, Any] = dict(params) - new_params['transpose'] = make_transpose_from_thunk( - new_params.pop('transpose_jaxpr_thunk'), - new_params['lin_tree']) - call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') - call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), - debug_info=call_jaxpr.jaxpr.debug_info) + if 'call_jaxpr' in params: + assert 'transpose_jaxpr_thunk' in params + new_params: dict[str, Any] = dict(params) + new_params['transpose'] = make_transpose_from_thunk( + new_params.pop('transpose_jaxpr_thunk'), + new_params['lin_tree']) + call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') + call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + else: + assert 'transpose' in params + new_params: dict[str, Any] = dict(params) + call = new_params.pop("call") return [call], new_params diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index b61b28e12f43..dc140c22650d 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -78,8 +78,8 @@ class OrderedDebugEffect(effects.Effect): @debug_callback_p.def_impl def debug_callback_impl(*args, callback: Callable[..., Any], - effect: DebugEffect): - del effect + effect: DebugEffect, partitioned: bool): + del effect, partitioned try: cpu_device, *_ = jax.local_devices(backend="cpu") except RuntimeError as e: @@ -99,8 +99,8 @@ def debug_callback_impl(*args, callback: Callable[..., Any], @debug_callback_p.def_effectful_abstract_eval def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any], - effect: DebugEffect): - del flat_avals, callback + effect: DebugEffect, partitioned: bool): + del flat_avals, callback, partitioned return [], {effect} def debug_callback_batching_rule(args, dims, **params): @@ -144,7 +144,7 @@ def f(): lambda: []) return shard_map(f, axis_context.mesh, in_specs=(), out_specs=[])() -def debug_callback_lowering(ctx, *args, effect, callback, **params): +def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params): axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.SPMDAxisContext): # We're a shard_map, which might be partial-manual or full-manual. @@ -152,8 +152,14 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): if partial_auto: # If we have partial manual / partial auto sharding, we gather and # conditionally run the callback. 
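# --- Editor's aside (illustrative sketch, not part of the patch) -----------
# The `partitioned` flag threaded through this file is user-visible via
# jax.debug.print / jax.debug.callback. A minimal sketch, assuming this patch
# is applied and the local shard sizes divide evenly across devices:
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

n = jax.device_count()
mesh = jax.make_mesh((n,), ("x",))
x = jax.device_put(jnp.arange(2.0 * n), NamedSharding(mesh, P("x")))

@jax.jit
def f(x):
  # partitioned=True prints each local shard and skips the all-gather.
  jax.debug.print("local shard: {}", x, partitioned=True)
  return x * 2

f(x)
# ---------------------------------------------------------------------------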
- lower = partial(_debug_callback_partial_auto, axis_context, - effect=effect, callback=callback, **params) + lower = partial( + _debug_callback_partial_auto, + axis_context, + effect=effect, + partitioned=partitioned, + callback=callback, + **params, + ) return mlir.lower_fun(lower)(ctx, *args) elif set(axis_context.manual_axes) == set(axis_context.mesh.axis_names): # If we have fully manual sharding during lowering, that means the JAX @@ -164,7 +170,7 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): sharding_impls.SdyArraySharding( mesh_shape=(), dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=True) + sharding_impls.SdyDimSharding(axes=[], is_open=False) ] * ctx.avals_out[0].ndim, logical_device_ids=())]) else: @@ -191,18 +197,23 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): def _callback(*flat_args): debug_callback_p.impl( - *flat_args, effect=effect, callback=callback, **params) + *flat_args, + effect=effect, + partitioned=partitioned, + callback=callback, + **params, + ) return () if effects.ordered_effects.contains(effect): token = ctx.tokens_in.get(effect) result, token, _ = cb.emit_python_callback( ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, - has_side_effect=True) + has_side_effect=True, partitioned=partitioned) ctx.set_tokens_out(mlir.TokenSet({effect: token})) else: result, _, _ = cb.emit_python_callback( ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, - has_side_effect=True, sharding=sharding) + has_side_effect=True, partitioned=partitioned, sharding=sharding) return result mlir.register_lowering(debug_callback_p, debug_callback_lowering, platform="cpu") @@ -244,14 +255,22 @@ def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn): @state_discharge.register_discharge_rule(debug_callback_p) def _debug_callback_state_discharge_rule( - in_avals, out_avals, *args, effect, callback, **params + in_avals, out_avals, *args, effect, partitioned, callback, **params ): del in_avals, out_avals # Unused. - out = debug_callback_p.bind(*args, effect=effect, callback=callback, **params) + out = debug_callback_p.bind( + *args, effect=effect, partitioned=partitioned, callback=callback, **params + ) return args, out -def debug_callback(callback: Callable[..., None], *args: Any, - ordered: bool = False, **kwargs: Any) -> None: + +def debug_callback( + callback: Callable[..., None], + *args: Any, + ordered: bool = False, + partitioned: bool = False, + **kwargs: Any, +) -> None: """Calls a stageable Python callback. For more explanation, see `External Callbacks`_. @@ -274,6 +293,9 @@ def debug_callback(callback: Callable[..., None], *args: Any, ordered: A keyword only argument used to indicate whether or not the staged out computation will enforce ordering of this callback w.r.t. other ordered callbacks. + partitioned: If True, then print local shards only; this option avoids an + all-gather of the operands. If False, print with logical operands; this + option requires an all-gather of operands first. **kwargs: The keyword arguments to the callback. Returns: @@ -284,7 +306,7 @@ def debug_callback(callback: Callable[..., None], *args: Any, - :func:`jax.pure_callback`: callback designed for pure functions. - :func:`jax.debug.print`: callback designed for printing. - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. 
_External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html """ if not callable(callback): raise TypeError("first argument to jax.debug.callback must be callable, " @@ -312,7 +334,10 @@ def _flat_callback(*dyn_args): return () effect = ordered_debug_effect if ordered else debug_effect - debug_callback_p.bind(*dyn_args, callback=_flat_callback, effect=effect) + debug_callback_p.bind( + *dyn_args, callback=_flat_callback, effect=effect, partitioned=partitioned + ) + class _DebugPrintFormatChecker(string.Formatter): @@ -338,7 +363,10 @@ def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs): with np.printoptions(**np_printoptions): sys.stdout.write(fmt.format(*args, **kwargs) + "\n") -def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None: + +def debug_print( + fmt: str, *args, ordered: bool = False, partitioned: bool = False, **kwargs +) -> None: """Prints values and works in staged out JAX functions. This function does *not* work with f-strings because formatting is delayed. @@ -367,6 +395,9 @@ def debug_print(fmt: str, *args, **kwargs): ordered: A keyword only argument used to indicate whether or not the staged out computation will enforce ordering of this ``jax.debug.print`` w.r.t. other ordered ``jax.debug.print`` calls. + partitioned: If True, then print local shards only; this option avoids an + all-gather of the operands. If False, print with logical operands; this + option requires an all-gather of operands first. **kwargs: Additional keyword arguments to be formatted, as if passed to ``fmt.format``. """ @@ -374,7 +405,7 @@ def debug_print(fmt: str, *args, **kwargs): formatter.format(fmt, *args, **kwargs) debug_callback(partial(_format_print_callback, fmt, np.get_printoptions()), - *args, **kwargs, ordered=ordered) + *args, **kwargs, ordered=ordered, partitioned=partitioned) # Sharding visualization diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 37f2f0264782..329491b1e8a8 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -127,12 +127,11 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-dlpack-import-legacy') register('jax-nn-one-hot-float-input') register("jax-numpy-astype-complex-to-real") -register("jax-numpy-array-none") register('jax-numpy-clip-args') register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') register('jax-numpy-reduction-non-boolean-where') register('jax-numpy-trimzeros-not-1d-array') -register('pallas-gpu-triton') register('jax-scipy-special-sph-harm') +register('jax-jit-positional-args') diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 2330f7628966..eea687145c0e 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -44,14 +44,15 @@ from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, Layout +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.mesh import AbstractMesh, Mesh from jax._src.monitoring import record_event_duration_secs, record_event_time_span from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding -from jax._src.sharding_impls import ( NamedSharding, - SingleDeviceSharding, TransferToMemoryKind, - is_single_device_sharding) +from jax._src.sharding_impls import ( + NamedSharding, SingleDeviceSharding, TransferToMemoryKind, GSPMDSharding, + PositionalSharding, 
is_single_device_sharding) import numpy as np @@ -132,11 +133,11 @@ def get_token_input( # TODO(yueshengys): This might still be buggy in a multi-process SPMD # scenario. Revise the logic later. A distributed shutdown barrier inside # the XLA program may be needed. - return jax.device_put(tok, jax.sharding.PositionalSharding(devices)) + return jax.device_put(tok, PositionalSharding(devices)) # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. - s = jax.sharding.GSPMDSharding.get_replicated(devices) + s = GSPMDSharding.get_replicated(devices) sharded_tok = core.Token(pxla.shard_args([s], [None], [None], [tok])[0]) self.current_tokens[eff] = sharded_tok return sharded_tok @@ -495,6 +496,9 @@ def _device_put_sharding_impl(x, aval, device, copy): return _DeferredShardArg(x, x.sharding, aval, x.committed, copy) elif is_single_device_sharding(x.sharding): device = x.sharding._device_assignment[0] if device is None else device + if copy == CopySemantics.COPY and jaxlib_extension_version >= 327: + return xc.batched_device_put(aval, SingleDeviceSharding(device), [x], + [device], True, True) return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], [device]) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index af50e2e9e31a..fb0aebb0e642 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -41,6 +41,7 @@ class State: client: Any | None = None preemption_sync_manager: Any | None = None coordinator_address: str | None = None + slice_index: int | None = None def initialize(self, coordinator_address: str | None = None, @@ -53,7 +54,8 @@ def initialize(self, service_heartbeat_interval_seconds: int = 10, service_max_missing_heartbeats: int = 10, client_heartbeat_interval_seconds: int = 10, - client_max_missing_heartbeats: int = 10): + client_max_missing_heartbeats: int = 10, + slice_index: int | None = None): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS')) if isinstance(local_device_ids, int): @@ -149,6 +151,10 @@ def initialize(self, self.initialize_preemption_sync_manager() + if slice_index is None and 'JAX_SLICE_INDEX' in os.environ: + slice_index = int(os.environ.get('JAX_SLICE_INDEX')) # type: ignore + self.slice_index = slice_index + def shutdown(self): if self.client: self.client.shutdown() @@ -175,7 +181,8 @@ def initialize(coordinator_address: str | None = None, local_device_ids: int | Sequence[int] | None = None, cluster_detection_method: str | None = None, initialization_timeout: int = 300, - coordinator_bind_address: str | None = None): + coordinator_bind_address: str | None = None, + slice_index: int | None = None): """Initializes the JAX distributed system. Calling :func:`~jax.distributed.initialize` prepares JAX for execution on @@ -236,6 +243,8 @@ def initialize(coordinator_address: str | None = None, all available addresses on the same port as ``coordinator_address``. On systems that have multiple network interfaces per node it may be insufficient to only have the coordinator service listen on one address/interface. + slice_index: The slice index assigned to this process' local devices. If any process sets ``slice_index``, + then all processes must do so. If ``None`` the slice indices will be chosen automatically. 
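# --- Editor's aside (illustrative sketch, not part of the patch) -----------
# Setting the new `slice_index` argument explicitly; the JAX_SLICE_INDEX
# environment variable (read in State.initialize above) is the implicit
# alternative. The coordinator address below is a hypothetical placeholder.
import jax

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # hypothetical
    num_processes=2,
    process_id=0,
    slice_index=0,  # if any process sets this, all processes must
)
# ---------------------------------------------------------------------------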
Raises: RuntimeError: If :func:`~jax.distributed.initialize` is called more than once @@ -261,7 +270,8 @@ def initialize(coordinator_address: str | None = None, "This includes any computation, but also calls to jax.devices, jax.device_put, and others.") global_state.initialize(coordinator_address, num_processes, process_id, local_device_ids, cluster_detection_method, - initialization_timeout, coordinator_bind_address) + initialization_timeout, coordinator_bind_address, + slice_index=slice_index) def is_initialized() -> bool: diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 01500c008405..d1e5b7bf430b 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,19 +90,18 @@ def type(self) -> type: ... # fp8 support -# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 -float8_e3m4: type[np.generic] | None = None -float8_e4m3: type[np.generic] | None = None -float8_e8m0fnu: type[np.generic] | None = None +float8_e3m4: type[np.generic] = ml_dtypes.float8_e3m4 +float8_e4m3: type[np.generic] = ml_dtypes.float8_e4m3 +float8_e8m0fnu: type[np.generic] = ml_dtypes.float8_e8m0fnu float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz -_float8_e3m4_dtype: np.dtype | None = None -_float8_e4m3_dtype: np.dtype | None = None -_float8_e8m0fnu_dtype: np.dtype | None = None +_float8_e3m4_dtype: np.dtype = np.dtype(float8_e3m4) +_float8_e4m3_dtype: np.dtype = np.dtype(float8_e4m3) +_float8_e8m0fnu_dtype: np.dtype = np.dtype(float8_e8m0fnu) _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -111,9 +110,9 @@ def type(self) -> type: ... 
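# --- Editor's aside (illustrative sketch, not part of the patch) -----------
# With the hasattr() guards dropped in the hunks above and below, the narrow
# float and int types are always importable from ml_dtypes (assuming
# ml_dtypes >= 0.5.0, which these changes presuppose) and participate in the
# scalar type hierarchy:
import ml_dtypes
import jax.numpy as jnp

print(jnp.issubdtype(ml_dtypes.float8_e3m4, jnp.floating))  # True
print(jnp.issubdtype(ml_dtypes.int2, jnp.signedinteger))    # True
# ---------------------------------------------------------------------------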
# fp4 support # TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 -float4_e2m1fn: type[np.generic] | None = None +float4_e2m1fn: type[np.generic] = ml_dtypes.float4_e2m1fn -_float4_e2m1fn_dtype: np.dtype | None = None +_float4_e2m1fn_dtype: np.dtype = np.dtype(float4_e2m1fn) def supports_inf(dtype: DTypeLike) -> bool: """Return true if the dtype supports infinity, else return False.""" @@ -127,6 +126,10 @@ def supports_inf(dtype: DTypeLike) -> bool: _bfloat16_dtype: np.dtype = np.dtype(bfloat16) _custom_float_scalar_types = [ + float4_e2m1fn, + float8_e3m4, + float8_e4m3, + float8_e8m0fnu, float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, @@ -135,6 +138,10 @@ def supports_inf(dtype: DTypeLike) -> bool: bfloat16, ] _custom_float_dtypes = [ + _float4_e2m1fn_dtype, + _float8_e3m4_dtype, + _float8_e4m3_dtype, + _float8_e8m0fnu_dtype, _float8_e4m3b11fnuz_dtype, _float8_e4m3fn_dtype, _float8_e4m3fnuz_dtype, @@ -143,6 +150,9 @@ def supports_inf(dtype: DTypeLike) -> bool: _bfloat16_dtype, ] _float8_dtypes = [ + _float8_e3m4_dtype, + _float8_e4m3_dtype, + _float8_e8m0fnu_dtype, _float8_e4m3b11fnuz_dtype, _float8_e4m3fn_dtype, _float8_e4m3fnuz_dtype, @@ -150,58 +160,28 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] -_float4_dtypes: list[np.dtype] = [] - -# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 -if hasattr(ml_dtypes, "float8_e4m3"): - float8_e4m3 = ml_dtypes.float8_e4m3 - _float8_e4m3_dtype = np.dtype(float8_e4m3) - _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e4m3_dtype) - _float8_dtypes.insert(0, _float8_e4m3_dtype) -if hasattr(ml_dtypes, "float8_e3m4"): - float8_e3m4 = ml_dtypes.float8_e3m4 - _float8_e3m4_dtype = np.dtype(float8_e3m4) - _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e3m4_dtype) - _float8_dtypes.insert(0, _float8_e3m4_dtype) -if hasattr(ml_dtypes, "float8_e8m0fnu"): - float8_e8m0fnu = ml_dtypes.float8_e8m0fnu - _float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu) - _custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype) - _float8_dtypes.insert(0, _float8_e8m0fnu_dtype) -if hasattr(ml_dtypes, "float4_e2m1fn"): - float4_e2m1fn = ml_dtypes.float4_e2m1fn - _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn) - _custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype) - _float4_dtypes.insert(0, _float4_e2m1fn_dtype) - -# 2-bit integer support -int2: type[np.generic] | None = None -uint2: type[np.generic] | None = None - -_int2_dtype: np.dtype | None = None -_uint2_dtype: np.dtype | None = None - -_intn_dtypes = [] - -# Remove the condition once the minimum ml_dtypes version required by JAX -# contains https://github.com/jax-ml/ml_dtypes/pull/154. 
-if hasattr(ml_dtypes, 'int2'): - int2 = ml_dtypes.int2 - uint2 = ml_dtypes.uint2 - _int2_dtype = np.dtype(int2) - _uint2_dtype = np.dtype(uint2) - _intn_dtypes.extend([_int2_dtype, _uint2_dtype]) +_float4_dtypes: list[np.dtype] = [ + _float4_e2m1fn_dtype, +] + +int2: type[np.generic] = ml_dtypes.int2 +uint2: type[np.generic] = ml_dtypes.uint2 + +_int2_dtype: np.dtype = np.dtype(int2) +_uint2_dtype: np.dtype = np.dtype(uint2) # 4-bit integer support int4: type[np.generic] = ml_dtypes.int4 uint4: type[np.generic] = ml_dtypes.uint4 _int4_dtype = np.dtype(int4) _uint4_dtype = np.dtype(uint4) -_intn_dtypes.extend([_int4_dtype, _uint4_dtype]) + +_intn_dtypes = [ + _int2_dtype, + _uint2_dtype, + _int4_dtype, + _uint4_dtype, +] # Default types. bool_ = np.bool_ @@ -472,9 +452,9 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, # to the normal scalar type hierarchy. if a_sctype in _custom_float_scalar_types: return b_sctype in {a_sctype, np.floating, np.inexact, np.number, np.generic} - if (int2 is not None and a_sctype == int2) or a_sctype == int4: + if a_sctype in [int2, int4]: return b_sctype in {a_sctype, np.signedinteger, np.integer, np.number, np.generic} - if (uint2 is not None and a_sctype == uint2) or a_sctype == uint4: + if a_sctype in [uint2, uint4]: return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic} # Otherwise, fall back to numpy.issubdtype @@ -491,6 +471,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, _unsigned_types: list[JAXType] _int_types: list[JAXType] _unsigned_types = [ + np.dtype(uint2), np.dtype(uint4), np.dtype('uint8'), np.dtype('uint16'), @@ -498,6 +479,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('uint64'), ] _signed_types = [ + np.dtype(int2), np.dtype(int4), np.dtype('int8'), np.dtype('int16'), @@ -505,11 +487,6 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('int64'), ] -if _int2_dtype is not None: - _signed_types.insert(0, _int2_dtype) -if _uint2_dtype is not None: - _unsigned_types.insert(0, _uint2_dtype) - _int_types = _unsigned_types + _signed_types _float_types: list[JAXType] = [ @@ -622,11 +599,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis This DAG maps each type to its immediately higher type on the lattice. 
""" b1, = _bool_types - if _int2_dtype is not None: - assert _uint2_dtype is not None - _uint2, uint4, u1, u2, u4, u8, _int2, int4, i1, i2, i4, i8 = _int_types - else: - uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types + uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types *f1_types, bf, f2, f4, f8 = _float_types c4, c8 = _complex_types i_, f_, c_ = _weak_types @@ -634,19 +607,13 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis out: dict[JAXType, list[JAXType]] out = { b1: [i_], - i_: [u1, uint4, i1, int4], - uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], + i_: [u1, uint2, uint4, i1, int2, int4], + uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], + int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], f_: [*f1_types, bf, f2, c_], **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], c_: [c4], c4: [c8], c8: [], } - if _int2_dtype is not None: - out[i_].append(_int2_dtype) - out[_int2_dtype] = [] - if _uint2_dtype is not None: - out[i_].append(_uint2_dtype) - out[_uint2_dtype] = [] return out elif jax_numpy_dtype_promotion == 'strict': return { diff --git a/jax/_src/effects.py b/jax/_src/effects.py index 36528c5feae5..d55333540355 100644 --- a/jax/_src/effects.py +++ b/jax/_src/effects.py @@ -47,7 +47,7 @@ for each thread the `RuntimeToken` returned by the last dispatched computation. For more details, see the design note: -https://jax.readthedocs.io/en/latest/jep/10657-sequencing-effects.html. +https://docs.jax.dev/en/latest/jep/10657-sequencing-effects.html. """ from __future__ import annotations diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 60dc2f76a5b2..b80def4fd2db 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,27 +14,30 @@ from __future__ import annotations +import dataclasses from functools import partial +import json import threading +import traceback as tb_lib +from types import TracebackType +import warnings import jax from jax._src import core from jax._src import source_info_util from jax._src import traceback_util import jax._src.mesh as mesh_lib -from jax.experimental.shard_map import shard_map +from jax.experimental import shard_map +import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P -Traceback = source_info_util.Traceback - - traceback_util.register_exclusion(__file__) class JaxValueError(ValueError): - """Exception raised for failed runtime error checks in JAX.""" + """Exception raised for runtime errors detected within JAX computations.""" #: The default error code for no error. @@ -44,8 +47,9 @@ class JaxValueError(ValueError): _NO_ERROR = jnp.iinfo(jnp.uint32).max -_error_list_lock = threading.Lock() -_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair +_error_list_lock = threading.RLock() +# (error_message, traceback) pairs. Traceback is `str` when imported from AOT. +_error_list: list[tuple[str, TracebackType | str]] = [] class _ErrorStorage(threading.local): @@ -58,38 +62,42 @@ def __init__(self): def _initialize_error_code_ref() -> None: - """Initialize error_code_ref in the current thread. + """Initialize the error code ref in the current thread. - The size of the error code array is determined by the mesh in the context. In - single-device environment, the array is a scalar. In multi-device - environment, the array has the same shape as the mesh. 
+ The shape and size of the error code array depend on the mesh in the context. + In single-device environments, the array is a scalar. In multi-device + environments, its shape and size match those of the mesh. """ - with core.eval_context(): - # Get mesh from the context. - mesh = mesh_lib.get_concrete_mesh() - - if mesh is None: # single-device case. - error_code = jnp.uint32(_NO_ERROR) - - else: # multi-device case. - sharding = NamedSharding(mesh, P(*mesh.axis_names)) - error_code = jnp.full( - mesh.axis_sizes, - jnp.uint32(_NO_ERROR), - device=sharding, - ) + # Get mesh from the context. + mesh = mesh_lib.get_concrete_mesh() + + if mesh is None: # single-device case. + error_code = jnp.uint32(_NO_ERROR) + + else: # multi-device case. + sharding = NamedSharding(mesh, P(*mesh.axis_names)) + error_code = jnp.full( + mesh.axis_sizes, + jnp.uint32(_NO_ERROR), + device=sharding, + ) - _error_storage.ref = core.mutable_array(error_code) + _error_storage.ref = core.mutable_array(error_code) class error_checking_context: - """Redefine the error checking state based on the mesh in the context. + """Redefine the internal error state based on the mesh in the context. - This context manager should be used when starting a multi-device - computation, and whenever the mesh is changed. + When using JAX in multi-device environments in explicit mode, error tracking + needs to be properly aligned with the device mesh. This context manager + ensures that the internal error state is correctly initialized based on the + current mesh configuration. - When exiting the context, the error checking state will be reset to the - original state. + This context manager should be used when starting a multi-device computation, + or when switching between different device meshes. + + On entering the context, it initializes a new error state based on the mesh in + the context. On exiting the context, it restores the previous error state. """ __slots__ = ("old_ref",) @@ -99,7 +107,8 @@ def __init__(self): def __enter__(self): self.old_ref = _error_storage.ref - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() return self def __exit__(self, exc_type, exc_value, traceback): @@ -107,19 +116,46 @@ def __exit__(self, exc_type, exc_value, traceback): def set_error_if(pred: jax.Array, /, msg: str) -> None: - """Set error if any element of pred is true. - - If the error is already set, the new error will be ignored. It will not - override the existing error. - - In auto mode, this function does not work under jit. + """Set the internal error state if any element of `pred` is `True`. + + This function is used inside JAX computations to detect runtime errors without + immediately halting execution. When this function is traced (e.g., inside + :func:`jax.jit`), the corresponding error message and its traceback are + recorded. At execution time, if `pred` contains any `True` values, the error + state is set, but execution continues without interruption. The recorded error + can later be raised using :func:`raise_if_error`. + + If the error state has already been set, subsequent errors are ignored and + will not override the existing error. + + For multi-device environments, in explicit mode, users must call + :func:`error_checking_context` to initialize a new error tracking state that + matches the device mesh. In auto mode, implicit cross-device communication may + occur inside this function, which could impact performance. A warning is + issued in such cases. 
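# --- Editor's aside (illustrative sketch, not part of the patch) -----------
# Single-device usage of the API documented in the surrounding docstring:
# record an error condition inside jit, then surface it afterwards.
import jax
import jax.numpy as jnp
from jax._src import error_check  # internal module being edited here

@jax.jit
def safe_log(x):
  error_check.set_error_if(x <= 0, "log() received a non-positive input")
  return jnp.log(x)

safe_log(jnp.float32(-1.0))
error_check.raise_if_error()  # raises JaxValueError with the message above
# ---------------------------------------------------------------------------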
+ + When exporting a function with `jax.export`, error checking must be explicitly + wrapped using :func:`wrap_for_export` before export and + :func:`unwrap_from_import` after import. + + Args: + pred: A JAX boolean array. If any element of `pred` is `True`, the internal + error state will be set. + msg: The corresponding error message to be raised later. """ if _error_storage.ref is None: - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() assert _error_storage.ref is not None + # Get the traceback. traceback = source_info_util.current().traceback assert traceback is not None + traceback = traceback.as_python_traceback() + assert isinstance(traceback, TracebackType) + traceback = traceback_util.filter_traceback(traceback) + assert isinstance(traceback, TracebackType) + with _error_list_lock: new_error_code = jnp.uint32(len(_error_list)) _error_list.append((msg, traceback)) @@ -127,41 +163,55 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: out_sharding = core.typeof(_error_storage.ref).sharding in_sharding: NamedSharding = core.typeof(pred).sharding - if out_sharding.mesh.shape_tuple == (): # single-device case. + # Reduce `pred`. + if all(dim is None for dim in out_sharding.spec): # single-device case. pred = pred.any() else: # multi-device case. has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types - if has_auto_axes: - raise NotImplementedError( - "Error checking in auto mode is not supported yet. Please use" - " explicit mode." - ) - if out_sharding.mesh != in_sharding.mesh: - raise ValueError( - "The error code state and the predicate must be on the same mesh, " - f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " - "Please use `with error_checking_context()` to redefine the error " - "code state based on the mesh." + if has_auto_axes: # auto mode. + warnings.warn( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode.", + RuntimeWarning, ) - pred = shard_map( - partial(jnp.any, keepdims=True), - mesh=out_sharding.mesh, - in_specs=in_sharding.spec, - out_specs=out_sharding.spec, - )(pred) # perform per-device reduction + pred = pred.any() # reduce to a single scalar + else: # explicit mode. + if out_sharding.mesh != in_sharding.mesh: + raise ValueError( + "The error code state and the predicate must be on the same mesh, " + f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " + "Please use `with error_checking_context()` to redefine the error " + "code state based on the mesh." + ) + pred = shard_map.shard_map( + partial(jnp.any, keepdims=True), + mesh=out_sharding.mesh, + in_specs=in_sharding.spec, + out_specs=out_sharding.spec, + )(pred) # perform per-device reduction error_code = _error_storage.ref[...] - should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) + should_update = jnp.logical_and(error_code == jnp.uint32(_NO_ERROR), pred) error_code = jnp.where(should_update, new_error_code, error_code) # TODO(ayx): support vmap and shard_map. _error_storage.ref[...] = error_code def raise_if_error() -> None: - """Raise error if an error is set. + """Raise an exception if the internal error state is set. + + This function should be called after a computation completes to check for any + errors that were marked during execution via `set_error_if()`. 
If an error + exists, it raises a `JaxValueError` with the corresponding error message. + + This function should not be called inside a traced function (e.g., inside + :func:`jax.jit`). Doing so will raise a `ValueError`. - This function should be called after the computation is finished. It should - not be called within a traced context, such as within a jitted function." + Raises: + JaxValueError: If the internal error state is set. + ValueError: If called within a traced JAX function. """ if _error_storage.ref is None: # if not initialized, do nothing return @@ -180,8 +230,136 @@ def raise_if_error() -> None: device=_error_storage.ref.sharding, ) # clear the error code - msg, traceback = _error_list[error_code] - exc = JaxValueError(msg) - traceback = traceback.as_python_traceback() - filtered_traceback = traceback_util.filter_traceback(traceback) - raise exc.with_traceback(filtered_traceback) + with _error_list_lock: + msg, traceback = _error_list[error_code] + if isinstance(traceback, str): # from imported AOT functions + exc = JaxValueError( + f"{msg}\nThe original traceback is shown below:\n{traceback}" + ) + raise exc + else: + exc = JaxValueError(msg) + raise exc.with_traceback(traceback) + + +@dataclasses.dataclass(frozen=True) +class _ErrorClass: + """A class to store error information for AOT compilation. + + This class is used internally by the wrapper functions `wrap_for_export` and + `unwrap_from_import` to encapsulate error-related data within an exported + function. + + Attributes: + error_code (jax.Array): A JAX array representing the final error state of + the function to be exported. This value is local to the wrapper function. + error_list (list[tuple[str, str]]): A list of `(error_message, traceback)` + pairs containing error messages and corresponding stack traces. This error + list is local to the wrapper function, and does not contain pairs of error + information from other functions. + """ + + error_code: jax.Array + error_list: list[tuple[str, str]] + + +jax.tree_util.register_dataclass( + _ErrorClass, data_fields=("error_code",), meta_fields=("error_list",) +) +jax.export.register_pytree_node_serialization( + _ErrorClass, + serialized_name=f"{_ErrorClass.__module__}.{_ErrorClass.__name__}", + serialize_auxdata=lambda x: json.dumps(x, ensure_ascii=False).encode( + "utf-8" + ), + deserialize_auxdata=lambda x: json.loads(x.decode("utf-8")), +) + + +def _traceback_to_str(traceback: TracebackType) -> str: + """Convert a traceback to a string for export.""" + return "".join(tb_lib.format_list(tb_lib.extract_tb(traceback))).rstrip("\n") + + +def wrap_for_export(f): + """Wrap a function with error checking to make it compatible with AOT mode. + + Error checking relies on global state, which cannot be serialized across + processes. This wrapper ensures that the error state remains within the + function scope, making it possible to export the function and later import in + other processes. + + When the function is later imported, it must be wrapped with + :func:`unwrap_from_import` to integrate the error checking mechanism of the + imported function into the global error checking mechanism of the current + process. + + This function should only be applied once to a function; wrapping the same + function multiple times is unnecessary. + """ + + def inner(*args, **kwargs): + global _error_list + + # 1. Save the old state and initialize a new state. 
+ with core.eval_context(): + old_ref = _error_storage.ref + _initialize_error_code_ref() + with _error_list_lock: + old_error_list, _error_list = _error_list, [] + + # 2. Trace the function. + out = f(*args, **kwargs) + error_code = _error_storage.ref[...].min() + + # 3. Restore the old state. + _error_list, new_error_list = old_error_list, _error_list + with core.eval_context(): + _error_storage.ref = old_ref + + new_error_list = [ + (msg, _traceback_to_str(traceback)) for msg, traceback in new_error_list + ] + return out, _ErrorClass(error_code, new_error_list) + + return inner + + +def unwrap_from_import(f): + """Unwrap a function after AOT import to restore error checking. + + When an AOT-exported function is imported in a new process, its error state is + separate from the global error state of the current process. This wrapper + ensures that errors detected during execution are correctly integrated into + the global error checking mechanism of the current process. + + This function should only be applied to functions that were previously wrapped + with :func:`wrap_for_export` before export. + """ + if _error_storage.ref is None: + with core.eval_context(): + _initialize_error_code_ref() + assert _error_storage.ref is not None + + def inner(*args, **kwargs): + out, error_class = f(*args, **kwargs) + new_error_code, error_list = error_class.error_code, error_class.error_list + + # Update the global error list. + with _error_list_lock: + offset = len(_error_list) + _error_list.extend(error_list) + + # Update the global error code array. + error_code = _error_storage.ref[...] + should_update = jnp.logical_and( + error_code == jnp.uint32(_NO_ERROR), + new_error_code != jnp.uint32(_NO_ERROR), + ) + error_code = jnp.where(should_update, new_error_code + offset, error_code) + # TODO(ayx): support vmap and shard_map. + _error_storage.ref[...] = error_code + + return out + + return inner diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 6540fd1f5d41..20b82f629f6f 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -21,7 +21,7 @@ class _JAXErrorMixin: """Mixin for JAX-specific errors""" - _error_page = 'https://jax.readthedocs.io/en/latest/errors.html' + _error_page = 'https://docs.jax.dev/en/latest/errors.html' _module_name = "jax.errors" def __init__(self, message: str): @@ -306,7 +306,7 @@ class TracerArrayConversionError(JAXTypeError): and concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`. - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html """ def __init__(self, tracer: core.Tracer): super().__init__( @@ -530,7 +530,7 @@ class UnexpectedTracerError(JAXTypeError): function ``f`` that stores, in some scope outside of ``f``, a reference to an intermediate value, that value is considered to have been leaked. Leaking values is a side effect. (Read more about avoiding side effects in - `Pure Functions `_) + `Pure Functions `_) JAX detects leaks when you then use the leaked value in another operation later on, at which point it raises an ``UnexpectedTracerError``. @@ -678,6 +678,5 @@ class KeyReuseError(JAXTypeError): This sort of key reuse is problematic because the JAX PRNG is stateless, and keys must be manually split; For more information on this see `the Pseudorandom Numbers - tutorial `_. + tutorial `_. 
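# --- Editor's aside (illustrative sketch, not part of the patch) -----------
# A plausible round trip through the wrap_for_export / unwrap_from_import
# helpers defined above in jax/_src/error_check.py. The composition with
# jax.jit and jax.export.export is the editor's assumption about intended
# usage, not taken from this diff.
import jax
import jax.numpy as jnp
from jax._src import error_check

def f(x):
  error_check.set_error_if(x <= 0, "non-positive input")
  return jnp.log(x)

exported = jax.export.export(jax.jit(error_check.wrap_for_export(f)))(
    jax.ShapeDtypeStruct((), jnp.float32))
imported_f = error_check.unwrap_from_import(exported.call)

imported_f(jnp.float32(-1.0))
error_check.raise_if_error()  # raised in the importing process
# ---------------------------------------------------------------------------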
""" - pass diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index afae3d9bcdc2..d02ef44f8318 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -43,7 +43,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lib import xla_client -from jax._src.lib import xla_extension, xla_extension_version +from jax._src.lib import xla_extension from jax._src.lib.mlir import ir, passmanager from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import func as func_dialect @@ -67,7 +67,7 @@ HloSharding = xla_client.HloSharding # The minimum and maximum supported calling convention version. -# See https://jax.readthedocs.io/en/latest/export/export.html#export-calling-convention-version +# See https://docs.jax.dev/en/latest/export/export.html#export-calling-convention-version minimum_supported_calling_convention_version = 9 maximum_supported_calling_convention_version = 9 @@ -153,16 +153,16 @@ class Exported: platforms: a tuple containing the platforms for which the function should be exported. The set of platforms in JAX is open-ended; users can add platforms. JAX built-in platforms are: 'tpu', 'cpu', 'cuda', 'rocm'. - See https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export. + See https://docs.jax.dev/en/latest/export/export.html#cross-platform-and-multi-platform-export. ordered_effects: the ordered effects present in the serialized module. - This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention + This is present from serialization version 9. See https://docs.jax.dev/en/latest/export/export.html#module-calling-convention for the calling convention in presence of ordered effects. unordered_effects: the unordered effects present in the serialized module. This is present from serialization version 9. mlir_module_serialized: the serialized lowered VHLO module. calling_convention_version: a version number for the calling convention of the exported module. - See more versioning details at https://jax.readthedocs.io/en/latest/export/export.html#calling-convention-versions. + See more versioning details at https://docs.jax.dev/en/latest/export/export.html#calling-convention-versions. module_kept_var_idx: the sorted indices of the arguments among `in_avals` that must be passed to the module. The other arguments have been dropped because they are not used. @@ -181,7 +181,7 @@ class Exported: for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs. - See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention). + See a [description of the calling convention for the `mlir_module`](https://docs.jax.dev/en/latest/export/export.html#module-calling-convention). """ fun_name: str in_tree: tree_util.PyTreeDef @@ -306,7 +306,7 @@ def call(self, *args, **kwargs): The invocation supports reverse-mode AD, and all the features supported by exporting: shape polymorphism, multi-platform, device polymorphism. - See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html). + See the examples in the [JAX export documentation](https://docs.jax.dev/en/latest/export/export.html). 
""" return call_exported(self)(*args, **kwargs) @@ -529,6 +529,7 @@ def export( *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), + _override_lowering_rules: Sequence[tuple[Any, Any]] | None = None ) -> Callable[..., Exported]: """Exports a JAX function for persistent serialization. @@ -540,7 +541,14 @@ def export( the exported code takes an argument specifying the platform. If None, then use the default JAX backend. The calling convention for multiple platforms is explained at - https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. + https://docs.jax.dev/en/latest/export/export.html#module-calling-convention. + _override_lowering_rules: an optional sequence of custom lowering rules + for some JAX primitives. Each element of the sequence is a pair + of a JAX primitive and a lowering function. Defining lowering rules + is an advanced feature using JAX internal APIs, which are subject + to change. Furthermore, the responsibility for the stability of the + MLIR emitted through these custom lowering rules, rests with the user + of these rules. disabled_checks: the safety checks to disable. See documentation for of `jax.export.DisabledSafetyCheck`. @@ -568,7 +576,8 @@ def export( Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) """ return _export_internal(fun_jit, platforms=platforms, - disabled_checks=disabled_checks) + disabled_checks=disabled_checks, + override_lowering_rules=_override_lowering_rules) # TODO(necula): remove this once we improve the integration with jax2tf. @@ -577,13 +586,14 @@ def _export_internal( *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), - _device_assignment_for_internal_jax2tf_use_only = None, + _device_assignment_for_internal_jax2tf_use_only=None, + override_lowering_rules=None, ) -> Callable[..., Exported]: """Exports native serialization for a JAX function. Note: this function exists only for internal usage by jax2tf. Use `jax.export` instead. - See https://jax.readthedocs.io/en/latest/export/export.html + See https://docs.jax.dev/en/latest/export/export.html See docstring of `export` for more details. """ @@ -604,6 +614,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: lowered = traced.lower( lowering_platforms=actual_lowering_platforms, _private_parameters=mlir.LoweringParameters( + override_lowering_rules=override_lowering_rules, for_export=True, export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) return _export_lowered( @@ -674,10 +685,8 @@ def _export_lowered( # Shardy was used during lowering if we can find the Shardy mesh in the # module. Note that the mesh should have been lifted by the # `sdy-lift-inlined-meshes` pass in mlir.py. 
- shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(mlir_module)) + shardy_enabled = xla_extension.sdy.lowered_with_shardy( + mlir.module_to_bytecode(mlir_module)) mlir_module_serialized = _module_to_bytecode(mlir_module, shardy_enabled) @@ -784,7 +793,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: _get_vjp=_get_exported_vjp) def _module_to_bytecode(module: ir.Module, shardy_enabled: bool) -> bytes: - if xla_extension_version >= 319 and shardy_enabled: + if shardy_enabled: mlir_str = xla_extension.sdy.sdy_round_trip_export_pipeline( mlir.module_to_bytecode(module)) else: @@ -828,7 +837,7 @@ def _wrap_main_func( ) -> ir.Module: """Wraps the lowered module with a new "main" handling dimension arguments. - See calling convention documentation https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. + See calling convention documentation https://docs.jax.dev/en/latest/export/export.html#module-calling-convention. Args: module: the HLO module as obtained from lowering. @@ -1073,17 +1082,12 @@ def _check_lowering(lowering) -> None: *_CPU_FFI_KERNELS, *_GPU_FFI_KERNELS, "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", + "annotate_device_placement", "cu_threefry2x32_ffi", # Triton IR does not guarantee stability. # "__gpu$xla.gpu.triton", - # cholesky on CPU - "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", # eigh on TPU "Eigh", - # eig on CPU - "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev", - # svd on CPU - "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", # qr and svd on TPU "Qr", "ProductOfElementaryHouseholderReflectors", # triangular_solve on CPU @@ -1092,8 +1096,6 @@ def _check_lowering(lowering) -> None: "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", # tridiagonal on CPU "lapack_ssytrd", "lapack_dsytrd", "lapack_chetrd", "lapack_zhetrd", - # hessenberg on CPU - "lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd", # lu on TPU "LuDecomposition", # ApproxTopK on TPU @@ -1177,7 +1179,7 @@ def walk_operations(op): disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops) msg = ("Cannot serialize code with custom calls whose targets have no " "compatibility guarantees. " - "See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls. " + "See https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls. 
" "Examples are:\n" f"{disallowed_custom_call_ops_str}.\n") raise ValueError(msg) @@ -1210,10 +1212,11 @@ def _hlo_sharding_to_xla_compatible_sharding( def _hlo_sharding_to_gspmd_sharding( hlo_sharding: HloSharding | None, - device_assignment: Sequence[jax.Device]) -> sharding.GSPMDSharding | None: + device_assignment: Sequence[jax.Device] + ) -> sharding_impls.GSPMDSharding | None: if hlo_sharding is None: return None - return sharding.GSPMDSharding(device_assignment, hlo_sharding) + return sharding_impls.GSPMDSharding(device_assignment, hlo_sharding) def _hlo_sharding_to_named_sharding( @@ -1423,10 +1426,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module()) - shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(submodule)) + shardy_enabled = xla_extension.sdy.lowered_with_shardy( + mlir.module_to_bytecode(submodule)) if shardy_enabled: submodule = ir.Module.parse(xla_extension.sdy.sdy_round_trip_import_shardings( mlir.module_to_bytecode(submodule))) @@ -1436,12 +1437,11 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, 'builtin.module(sdy-lift-inlined-meshes)') pipeline.run(submodule.operation) - # TODO(bartchr): delete this once I have JAX export support multiple meshes. mesh = None if shardy_enabled: sdy_mesh_axes = xla_extension.sdy.get_mesh(mlir.module_to_bytecode(submodule)) - mesh = mesh_lib.AbstractMesh( - *list(zip(*sdy_mesh_axes))[::-1]) if sdy_mesh_axes else None + mesh = (mesh_lib.AbstractMesh(*list(zip(*sdy_mesh_axes))[::-1]) + if sdy_mesh_axes else mesh_lib.empty_abstract_mesh) axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.ShardingContext): diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 7d3e342f1879..01cfa9944dfd 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -45,7 +45,7 @@ enum AbstractValueKind: byte { } enum DType: byte { - // Last used id: 22 + // Last used id: 29 bool = 0, i8 = 1, i16 = 2, @@ -76,6 +76,10 @@ enum DType: byte { f8_e5m2fnuz = 21, f8_e8m0fnu = 25, f4_e2m1fn = 26, + + key_fry = 27, + key_rbg = 28, + key_unsafe_rbg = 29, } table AbstractValue { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index ac97c11d1177..3d878cccc701 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -31,6 +31,7 @@ from jax._src import core from jax._src import dtypes from jax._src import effects +from jax._src import prng from jax._src import tree_util from jax._src.export import serialization_generated as ser_flatbuf from jax._src.export import _export @@ -48,6 +49,8 @@ # Version 2, Dec 16th, 2023, adds the f0 dtype. # Version 3, October 16th, 2024, adds serialization for namedtuple and custom types # This version is backwards compatible with Version 2. +# Version 4, April 7th, 2025, adds serialization for PRNGs key types. +# This version is backwards compatible with Version 2 and 3. 
_SERIALIZATION_VERSION = 2 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: @@ -357,16 +360,16 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz, dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2, dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, + dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4, + dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3, + dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu, + dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn, + + prng.KeyTy(prng.prngs["threefry2x32"]): ser_flatbuf.DType.key_fry, + prng.KeyTy(prng.prngs["rbg"]): ser_flatbuf.DType.key_rbg, + prng.KeyTy(prng.prngs["unsafe_rbg"]): ser_flatbuf.DType.key_unsafe_rbg, } -if dtypes._float8_e3m4_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 -if dtypes._float8_e4m3_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 -if dtypes._float8_e8m0fnu_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu -if dtypes._float4_e2m1fn_dtype is not None: - _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index b1fc13333777..34211c1ebe54 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -53,16 +53,19 @@ class DType(object): bf16 = 14 i4 = 15 ui4 = 16 - f8_e3m4 = 24 - f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 f8_e5m2 = 20 f8_e5m2fnuz = 21 f0 = 22 + f8_e4m3 = 23 + f8_e3m4 = 24 f8_e8m0fnu = 25 f4_e2m1fn = 26 + key_fry = 27 + key_rbg = 28 + key_unsafe_rbg = 29 class ShardingKind(object): diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 6a6ce93712ff..405592cadd2b 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -13,7 +13,7 @@ # limitations under the License. """Shape polymorphism support. -See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html. +See documentation at https://docs.jax.dev/en/latest/export/shape_poly.html. """ from __future__ import annotations @@ -70,7 +70,7 @@ class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation): are non-constant, and the result of the operation cannot be represented as a boolean value for all values of the symbolic dimensions involved. -Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported +Please see https://docs.jax.dev/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported for more details. 
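# --- Editor's aside (illustrative sketch, not part of the patch) -----------
# The shape-polymorphism API whose documentation links are updated throughout
# this file; a minimal sketch of symbolic shapes with a user constraint:
import jax
from jax import export
import jax.numpy as jnp

a, b = export.symbolic_shape("a, b", constraints=["a >= 2*b"])
spec = jax.ShapeDtypeStruct((a, b), jnp.float32)
exp = export.export(jax.jit(lambda x: x.T))(spec)
print(exp.in_avals)
# ---------------------------------------------------------------------------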
""" @@ -227,7 +227,7 @@ def evaluate(self, env: DimVarEnv, scope: SymbolicScope): return normalized_var._evaluate(env) # type: ignore err_msg = ( f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n" - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") raise UnexpectedDimVar(err_msg) else: operand_values = [opnd._evaluate(env) for opnd in self.operands] @@ -654,7 +654,7 @@ def _eq(self, other: _DimExpr) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported + # See https://docs.jax.dev/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -841,7 +841,7 @@ def __eq__(self, other: Any) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported + # See https://docs.jax.dev/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -986,7 +986,7 @@ class SymbolicScope: Holds the constraints on symbolic expressions. - See [the README](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + See [the README](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for more details. Args: @@ -1112,7 +1112,7 @@ def _check_same_scope(self, other: _DimExpr, f"Invalid mixing of symbolic scopes {when}.\n" f"Expected {self_descr}scope {self}\n" f"and found for '{other}' ({other_descr}) scope {other.scope}\n" - f"See https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.") + f"See https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.") def _clear_caches(self): self._bounds_cache.clear() @@ -1384,7 +1384,7 @@ def symbolic_shape(shape_spec: str | None, ) -> Sequence[DimSize]: """Constructs a symbolic shape from a string representation. - See https://jax.readthedocs.io/en/latest/export/shape_poly.html for examples. + See https://docs.jax.dev/en/latest/export/shape_poly.html for examples. Args: shape_spec: a symbolic shape specification. None stands for "...". @@ -1396,13 +1396,13 @@ def symbolic_shape(shape_spec: str | None, mod(e1, e2), max(e1, e2), or min(e1, e2). constraints: a sequence of constraints on symbolic dimension expressions, of the form `e1 >= e2` or `e1 <= e2`, or `e1 == e2`. - See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + See [the documentation](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for usage. scope: optionally, you can specify that the parsed symbolic expressions be created in the given scope. 
If this is missing, then a new `SymbolicScope` is created with the given `constraints`. You cannot specify both a `scope` and `constraints`. - See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + See [the documentation](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for usage. like: when `shape_spec` contains placeholders ("_", "..."), use this shape to fill in the placeholders. @@ -1437,7 +1437,7 @@ def symbolic_args_specs( """Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`. See the documentation of :func:`jax.export.symbolic_shape` and - the [shape polymorphism documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html) for details. + the [shape polymorphism documentation](https://docs.jax.dev/en/latest/export/shape_poly.html) for details. Args: args: a pytree of arguments. These can be jax.Array, or jax.ShapeDTypeSpec. @@ -1450,7 +1450,7 @@ def symbolic_args_specs( applies to all arguments), or a pytree matching a prefix of the `args`. See [how optional parameters are matched to - arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). + arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). constraints: as for :func:`jax.export.symbolic_shape`. scope: as for :func:`jax.export.symbolic_shape`. @@ -2038,7 +2038,7 @@ def _solve_dim_equations( " Using the following polymorphic shapes specifications: " + ",".join(f"{arg_name}.shape = {arg_spec}" for arg_name, arg_spec in polymorphic_shape_specs)) + "." - solution_err_msg_trailer_errors = ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + solution_err_msg_trailer_errors = ". Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." shape_constraints = ShapeConstraints() # accumulate shape constraints scope: SymbolicScope | None = None @@ -2171,6 +2171,6 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): " Unprocessed specifications: " + ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}" for eqn in eqns) + - ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + ". Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." 
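The shape_poly.py hunks in this section only retarget documentation URLs, but the surrounding docstrings describe how jax.export.symbolic_shape and user-specified constraints are meant to be used. A minimal usage sketch, assuming the public jax.export API (the function, shapes, and the constraint string are illustrative, not taken from this diff):

import jax
import jax.numpy as jnp
from jax import export

# One symbolic leading dimension "b" with a user-specified constraint.
dims = export.symbolic_shape("b, 4", constraints=["b >= 2"])
spec = jax.ShapeDtypeStruct(dims, jnp.float32)

# Export a simple reduction over the symbolic-shaped argument; the resulting
# artifact accepts any input of shape (b, 4) with b >= 2.
exp = export.export(jax.jit(lambda x: jnp.sum(x, axis=0)))(spec)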
) raise ValueError(err_msg) diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index 05697f00e945..f0c7d761ac2b 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -24,7 +24,6 @@ import jax from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import effects from jax._src import util @@ -39,11 +38,6 @@ from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray, Shape) -# TODO(dfm): Remove after 6 months or less because there aren't any offical -# compatibility guarantees for jax.extend (see JEP 15856) -# Added Oct 13, 2024 -deprecations.register("jax-ffi-call-args") - map, unsafe_map = util.safe_map, map FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None @@ -325,7 +319,7 @@ def _convert_layouts_for_ffi_call( def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata, - *deprecated_args: ArrayLike, + *, has_side_effect: bool = ..., vmap_method: str | None = ..., input_layouts: Sequence[FfiLayoutOptions] | None = ..., @@ -333,9 +327,8 @@ def ffi_call( input_output_aliases: dict[int, int] | None = ..., custom_call_api_version: int = ..., legacy_backend_config: str | None = ..., - vectorized: bool | DeprecatedArg = ..., - **deprecated_kwargs: Any, -) -> Callable[..., Array] | Array: + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), +) -> Callable[..., Array]: ... @@ -343,7 +336,7 @@ def ffi_call( def ffi_call( target_name: str, result_shape_dtypes: Sequence[ResultMetadata], - *deprecated_args: ArrayLike, + *, has_side_effect: bool = ..., vmap_method: str | None = ..., input_layouts: Sequence[FfiLayoutOptions] | None = ..., @@ -351,16 +344,15 @@ def ffi_call( input_output_aliases: dict[int, int] | None = ..., custom_call_api_version: int = ..., legacy_backend_config: str | None = ..., - vectorized: bool | DeprecatedArg = ..., - **deprecated_kwargs: Any, -) -> Callable[..., Sequence[Array]] | Sequence[Array]: + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), +) -> Callable[..., Sequence[Array]]: ... def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], - *deprecated_args: ArrayLike, + *, has_side_effect: bool = False, vmap_method: str | None = None, input_layouts: Sequence[FfiLayoutOptions] | None = None, @@ -368,9 +360,8 @@ def ffi_call( input_output_aliases: dict[int, int] | None = None, custom_call_api_version: int = 4, legacy_backend_config: str | None = None, - vectorized: bool | DeprecatedArg = DeprecatedArg(), - **deprecated_kwargs: Any, -) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]: + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), +) -> Callable[..., Array | Sequence[Array]]: """Call a foreign function interface (FFI) target. See the :ref:`ffi-tutorial` tutorial for more information. @@ -430,18 +421,11 @@ def ffi_call( to execute the FFI handler. Any keyword arguments are passed as named attributes to the FFI handler using XLA's FFI interface. """ - if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: - deprecations.warn( - "jax-callback-vectorized", - "The vectorized argument of ffi_call is deprecated and setting " - "it will soon raise an error. To avoid an error in the future, and to " - "suppress this warning, please use the vmap_method argument instead.", - stacklevel=2) - if vmap_method is not None: - raise ValueError( - "the vectorized and vmap_method arguments of ffi_call cannot " - "be used together. 
Please use the vmap_method argument.") - vmap_method = "legacy_vectorized" if vectorized else "sequential" + # TODO(danfm): Remove this check 3 months after v0.6.0 is released. + if not isinstance(vectorized, DeprecatedArg): + raise ValueError( + "The 'vectorized' argument of jax.ffi.ffi_call was removed in JAX " + "v0.6.0. Use 'vmap_method' instead.") allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims", "broadcast_all", "legacy_vectorized", None] if vmap_method not in allowed_vmap_methods: @@ -515,11 +499,10 @@ def wrapped(*args: ArrayLike, **kwargs: Any): "and an output with a different layout " f"{static_output_layouts[o_idx]}.") static_input_output_aliases += ((i_idx, o_idx),) - + args = core.standard_insert_pvary(*args) results = ffi_call_p.bind( *args, result_avals=result_avals, - vectorized=vectorized, vmap_method=vmap_method, target_name=target_name, has_side_effect=has_side_effect, @@ -537,19 +520,7 @@ def wrapped(*args: ArrayLike, **kwargs: Any): else: return results[0] - if deprecated_args or deprecated_kwargs: - deprecations.warn( - "jax-ffi-call-args", - "Calling ffi_call directly with input arguments is deprecated. " - "Instead, ffi_call should be used to construct a callable, which can " - "then be called with the appropriate inputs. For example,\n" - " ffi_call('target_name', output_type, x, argument=5)\n" - "should be replaced with\n" - " ffi_call('target_name', output_type)(x, argument=5)", - stacklevel=2) - return wrapped(*deprecated_args, **deprecated_kwargs) - else: - return wrapped + return wrapped # ffi_call must support some small non-hashable input arguments, like np.arrays @@ -638,9 +609,10 @@ def ffi_call_abstract_eval( has_side_effect: bool, **_, ): - del avals_in # unused + out_vma = core.standard_vma_rule('ffi_call', *avals_in) effects = {_FfiEffect} if has_side_effect else core.no_effects - return result_avals, effects + return tuple(r if r is core.abstract_token else r.update(vma=out_vma) + for r in result_avals), effects def ffi_call_jvp(*args, target_name, **_): @@ -684,21 +656,10 @@ def ffi_batching_rule( args, dims, *, - vectorized: bool | None | DeprecatedArg, vmap_method: str | None, result_avals: Sequence[core.ShapedArray], **kwargs: Any, ): - if isinstance(vectorized, DeprecatedArg) and vmap_method is None: - deprecations.warn( - "jax-callback-vectorized", - f"The default behavior of {prim.name} under vmap will soon " - "change. Currently, the default behavior is to generate a sequential " - "vmap (i.e. a loop), but in the future the default will be to raise " - "an error. 
To keep the current default, set vmap_method='sequential'.", - stacklevel=6) - vmap_method = "sequential" - axis_size, = {a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped} new_args = [arg if dim is batching.not_mapped else @@ -726,7 +687,6 @@ def ffi_batching_rule( for layout, d in zip(kwargs["input_layouts"], dims)) outvals = prim.bind( *new_args, - vectorized=vectorized, vmap_method=vmap_method, result_avals=batched_result_avals, **kwargs, @@ -742,7 +702,6 @@ def ffi_batching_rule( for layout in kwargs["input_layouts"]) outvals = prim.bind( *bcast_args, - vectorized=vectorized, vmap_method=vmap_method, result_avals=batched_result_avals, **kwargs, @@ -755,7 +714,6 @@ def _batch_fun(batched_args): return prim.bind( *merged_args, result_avals=result_avals, - vectorized=vectorized, vmap_method=vmap_method, **kwargs, ) diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py index ff35b8db8e25..bf3bd33f286a 100644 --- a/jax/_src/flatten_util.py +++ b/jax/_src/flatten_util.py @@ -41,7 +41,7 @@ def ravel_pytree(pytree): component of the output. For details on dtype promotion, see - https://jax.readthedocs.io/en/latest/type_promotion.html. + https://docs.jax.dev/en/latest/type_promotion.html. """ leaves, treedef = tree_flatten(pytree) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py new file mode 100644 index 000000000000..bf70df2cdb3a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py @@ -0,0 +1,73 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
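For the jax/_src/ffi.py change above: with the deprecated `vectorized` argument and the direct-invocation form removed, ffi_call now always returns a callable, and batching behaviour under vmap is selected with `vmap_method`. A minimal sketch of the resulting calling convention, assuming the public jax.ffi API; the target name "my_target" and the shapes are placeholders, not taken from this diff:

import jax
import jax.numpy as jnp

# Build the callable first; "my_target" is a hypothetical FFI target that would
# have to be registered via jax.ffi.register_ffi_target before lowering.
out_type = jax.ShapeDtypeStruct((4,), jnp.float32)
call = jax.ffi.ffi_call("my_target", out_type, vmap_method="sequential")

x = jnp.ones((4,), jnp.float32)
# y = call(x)           # replaces the removed form ffi_call("my_target", out_type, x)
# jax.vmap(call)(xs)    # batched according to vmap_method rather than `vectorized`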
+ +# ruff: noqa + +import datetime +from numpy import array, float32, int32 + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_tpu = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_cuda = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes 
{jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py index eb4143615da6..ee06d902d235 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py @@ -17,345 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_06_19 = {} - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_spotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 24.343887, 13.603932, 20.50489 , 12.063956], - [ 13.603932, 58.879757, -31.84056 , 16.328012], - [ 20.50489 , -31.84056 , 66.890755, -9.92216 ], - [ 12.063956, 16.328012, -9.92216 , 23.640734]], dtype=float32),), - expected_outputs=(array([[ 4.9339523, 0. 
, 0. , 0. ], - [ 2.7572079, 7.1608353, 0. , 0. ], - [ 4.155875 , -6.0466647, 3.6134892, 0. ], - [ 2.4450896, 1.3387254, -3.3177967, 2.2050648]], dtype=float32),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf32> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc2) - %1 = stablehlo.add %arg0, %0 : tensor<4x4xf32> loc(#loc3) - %2 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) - %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc4) - %4 = stablehlo.divide %1, %3 : tensor<4x4xf32> loc(#loc4) - %5 = stablehlo.constant dense<1> : tensor loc(#loc5) - %6 = stablehlo.constant dense<1> : tensor loc(#loc5) - %7 = stablehlo.constant dense<4> : tensor loc(#loc5) - %8:2 = stablehlo.custom_call @lapack_spotrf(%5, %6, %7, %4) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor) loc(#loc5) - %9 = stablehlo.constant dense<0> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.compare EQ, %8#1, %10, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %13 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %16 = stablehlo.select %15, %8#0, %14 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc5) - %17 = call @tril(%16) : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc6) - return %17 : tensor<4x4xf32> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xf32> loc(unknown)) -> tensor<4x4xf32> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc7) - %1 = stablehlo.constant dense<0> : tensor loc(#loc6) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc8) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc8) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) - %6 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc11) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc12) - return %8 : tensor<4x4xf32> loc(#loc6) - } loc(#loc6) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril 
keep_unused=False inline=False]"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03"\x02\xd9%\x01\x87\x0f\x17\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b\x1fO\x01\x03\x0f\x03#\x17\x0f\x0f\x17\x07\x07\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02J\x07\x1dg\x03\x177\x92\x04\x01\x1f\x05\x1f\x03\x03\x1d\xb3\x1d5\x03\x05!\x11\x01\x05\x05#\x05%\x05\'\x05)\x05+\x03\x03\x07\xb1\x05-\x1d?\x03\x05/\x051\x1de\x03\x03\x03\x07\xbf\x03\x07+\x0f-\x0f\r/\x053\x055\x057\x03\x0b\x11\x95\x13\x89\x15\xa1\r\xa7\x17\xa9\x03\x0b\x11\x8d\x13\x89\x15\x8d\r\x8f\x17\xad\x059\x05;\x03\x03\x19\xaf\x1d=\x03\x05=\x05?\x03\x03\x19\xb5\x1dE\x03\x05A\x03\x05!\x91#\xb7\x1dK\x03\x05C\x03\x03\x07\xb9\x1dQ\x03\x05E\x1dU\x03\x05G\x03\x03Y\xbb\x05I\x1d]\x03\x05K\x1da\x03\x05M\x03\x03\x07\xbd\x05O\x05Q\x03\x03\x07\xc1\x03\x11m\xc3o\x8bq\xc5s\xc7u\xc9w\xcby\xcd{\xd1\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x05!\x91#\xd3\x03\x03\x07\xd5\x03\x03\x1d\xd7\x03\x03\x85\x8f\x05c\x1f\x1d\x01#\x19\x1de\x03\x03\xab\x1dg\t\x07\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\x97\r\x05\x99\x9b\x9d\x9f\x1di\x1dk\x1dm\x1do\x03\x03\xa3\r\x03\xa5\x8b\x1dq\x1ds\x1du\r\x01\x1dw\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1b\x01\x13\x0b\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1dy\x03\x01\x05\x01\x03\t\x87\x87\x87\x93\x03\x03\xcf\x15\x03\x01\r\x01\x03\x05\x93\x87\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0f)\x01\x11)\x01\x0f)\x05\x11\x11\x11\x1d\x01\t\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x17)\x03\t\x17)\x01\r)\x05\x05\x05\r\x04\xd6\x03\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x03)O\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x0b\x06_\x03\x03\x05\x01\x03\x03\x03\x05c\x03\x07\x05\x07%\t\x03\x03\x03\x07\x15\x06%\x03\x03\x05\x05\t\x03\x03\x01\'\x03\x05\x03\x03\x01\'\x03\x05\x03\x03\x01i\x03\x05\x17\x07\x01k\x05\x03\x05\t\r\x0f\x11\x0b\x03\x03\x01\x1b\x03\x05\x05\x07\x01\t\x03\x05\x03\x17\r\x07\x01}\x03!\x05\x15\x19\x05\x07\x01\t\x03#\x03\x1b\x03\x03\x01\x7f\x03\x07\x05\x07\x01\t\x03\x03\x03\x1f\x05\x07\x01\x81\x03\x13\x03\x1d\x0f\x06\x01\x03\x03\x07#\x13!\x19\x07\x0b\x83\x03\x03\x03%\x11\x04\x05\x03\'\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x13\x05\t\x0b\x03\x03\x0bM\x03\x
07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\n\x16{\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_spotrf\x00', - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 23.022171138130666 , -16.79765603341739 , 0.9133449305189146, - -25.36636199966769 ], - [-16.79765603341739 , 31.655770252600092 , -1.5189878284433445, - 20.0344758332268 ], - [ 0.9133449305189146, -1.5189878284433445, 10.940134497877208 , - 8.169020034607513 ], - [-25.36636199966769 , 20.0344758332268 , 8.169020034607513 , - 37.054603917509596 ]]),), - expected_outputs=(array([[ 4.7981424674691215 , 0. , 0. , - 0. ], - [-3.500866459740129 , 4.404509539513645 , 0. , - 0. ], - [ 0.19035385812557523, -0.1935707899825621 , 3.2964268922333835 , - 0. 
], - [-5.286704630312426 , 0.3465604732420997 , 2.8037778311164425 , - 1.060228174247855 ]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf64> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc2) - %1 = stablehlo.add %arg0, %0 : tensor<4x4xf64> loc(#loc3) - %2 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) - %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc4) - %4 = stablehlo.divide %1, %3 : tensor<4x4xf64> loc(#loc4) - %5 = stablehlo.constant dense<1> : tensor loc(#loc5) - %6 = stablehlo.constant dense<1> : tensor loc(#loc5) - %7 = stablehlo.constant dense<4> : tensor loc(#loc5) - %8:2 = stablehlo.custom_call @lapack_dpotrf(%5, %6, %7, %4) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor) loc(#loc5) - %9 = stablehlo.constant dense<0> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.compare EQ, %8#1, %10, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %13 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc5) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %16 = stablehlo.select %15, %8#0, %14 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc5) - %17 = call @tril(%16) : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc6) - return %17 : tensor<4x4xf64> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xf64> loc(unknown)) -> tensor<4x4xf64> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc7) - %1 = stablehlo.constant dense<0> : tensor loc(#loc6) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc8) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc8) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) - %6 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc11) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc12) - return %8 : tensor<4x4xf64> loc(#loc6) - } loc(#loc6) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc7 = 
loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03"\x02\xd9%\x01\x87\x0f\x17\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b/O\x01\x03\x0f\x03#\x17\x0f\x0f\x17\x07\x07\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02z\x07\x1dg\x03\x177\x92\x04\x01\x1f\x05\x1f\x03\x03\x1d\xb3\x1d5\x03\x05!\x11\x01\x05\x05#\x05%\x05\'\x05)\x05+\x03\x03\x07\xb1\x05-\x1d?\x03\x05/\x051\x1de\x03\x03\x03\x07\xbf\x03\x07+\x0f-\x0f\r/\x053\x055\x057\x03\x0b\x11\x95\x13\x89\x15\xa1\r\xa7\x17\xa9\x03\x0b\x11\x8d\x13\x89\x15\x8d\r\x8f\x17\xad\x059\x05;\x03\x03\x19\xaf\x1d=\x03\x05=\x05?\x03\x03\x19\xb5\x1dE\x03\x05A\x03\x05!\x91#\xb7\x1dK\x03\x05C\x03\x03\x07\xb9\x1dQ\x03\x05E\x1dU\x03\x05G\x03\x03Y\xbb\x05I\x1d]\x03\x05K\x1da\x03\x05M\x03\x03\x07\xbd\x05O\x05Q\x03\x03\x07\xc1\x03\x11m\xc3o\x8bq\xc5s\xc7u\xc9w\xcby\xcd{\xd1\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x05!\x91#\xd3\x03\x03\x07\xd5\x03\x03\x1d\xd7\x03\x03\x85\x8f\x05c\x1f\x1d\x01#\x19\x1de\x03\x03\xab\x1dg\t\x07\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\x97\r\x05\x99\x9b\x9d\x9f\x1di\x1dk\x1dm\x1do\x03\x03\xa3\r\x03\xa5\x8b\x1dq\x1ds\x1du\r\x01\x1dw\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1b\x01\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1dy\x03\x01\x05\x01\x03\t\x87\x87\x87\x93\x03\x03\xcf\x15\x03\x01\r\x01\x03\x05\x93\x87\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0f)\x01\x11)\x01\x0f)\x05\x11\x11\x11\x1d\x01\x0b\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x17)\x03\t\x17)\x01\r)\x05\x05\x05\r\x04\xd6\x03\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x03)O\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x0b\x06_\x03\x03\x05\x01\x03\x03\x03\x05c\x03\x07\x05\x07%\t\x03\x03\x03\x07\x15\x06%\x03\x03\x05\x05\t\x03\x03\x01\'\x03\x05\x03\x03\x01\'\x03\x05\x03\x03\x01i\x03\x05\x17\x07\x01k\x05\x03\x05\t\r\x0f\x11\x0b\x03\x03\x01\x1b\x03\x05\x05\x07\x01\t\x03\x05\x03\x17\r\x07\x01}\x03!\x05\x15\x19\x05\x07\x01\t\x03#\x03\x1b\x03\x03\x01\x7f\x03\x07\x05\x07\x01\t\x03\x03\x03\x1f\x05\x07\x01\x81\x03\x13\x03\x1d\x0f\x06\x01\x03\x03\x07#\x13!\x19\x07\x0b\x83\x03\x03\x03%\x11\x04\x05\x03\'\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x13\x05\t\x0b\x03\x03\x0bM\x03\x07\
x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\n\x16{\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_dpotrf\x00', - xla_call_module_version=6, -) # End paste - - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 38.089394 +6.36582342e-09j, 3.3509154+3.13455486e+01j, - -0.5972489-3.80308151e+01j, -19.04205 +1.22770605e+01j], - [ 3.3509154-3.13455486e+01j, 73.875755 +4.06565448e-09j, - -12.427276 -1.23379612e+01j, 41.542507 -9.63993359e+00j], - [ -0.5972489+3.80308151e+01j, -12.427276 +1.23379612e+01j, - 73.04141 -4.18667753e-07j, 8.193126 -2.60565052e+01j], - [-19.04205 -1.22770605e+01j, 41.542507 +9.63993359e+00j, - 8.193126 +2.60565052e+01j, 52.977036 -1.09952367e-07j]], - dtype=complex64),), - expected_outputs=(array([[ 6.1716604 +0.j , 0. +0.j , - 0. +0.j , 0. +0.j ], - [ 0.542952 -5.078949j , 6.912687 +0.j , - 0. +0.j , 0. +0.j ], - [-0.09677281+6.162169j , 2.7373738 +1.3719271j, - 5.0679703 +0.j , 0. 
+0.j ], - [-3.0854013 -1.9892638j, 4.7903748 +3.8177056j, - 0.3555784 +0.5865844j, 1.2276335 +0.j ]], dtype=complex64),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc2) - %1 = stablehlo.real %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc3) - %2 = stablehlo.imag %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc4) - %3 = stablehlo.negate %2 : tensor<4x4xf32> loc(#loc5) - %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex> loc(#loc6) - %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex> loc(#loc7) - %6 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc8) - %8 = stablehlo.divide %5, %7 : tensor<4x4xcomplex> loc(#loc8) - %9 = stablehlo.constant dense<1> : tensor loc(#loc9) - %10 = stablehlo.constant dense<1> : tensor loc(#loc9) - %11 = stablehlo.constant dense<4> : tensor loc(#loc9) - %12:2 = stablehlo.custom_call @lapack_cpotrf(%9, %10, %11, %8) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor) loc(#loc9) - %13 = stablehlo.constant dense<0> : tensor loc(#loc9) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor loc(#loc9) - %15 = stablehlo.compare EQ, %12#1, %14, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %17 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc9) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc9) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc9) - %20 = stablehlo.select %19, %12#0, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc9) - %21 = call @tril(%20) : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc10) - return %21 : tensor<4x4xcomplex> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xcomplex> loc(unknown)) -> tensor<4x4xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc11) - %1 = stablehlo.constant dense<0> : tensor loc(#loc10) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc12) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc12) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc13) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc14) - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc15) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc16) - return %8 : tensor<4x4xcomplex> loc(#loc10) - } loc(#loc10) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = 
loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/real"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/imag"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/neg"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/complex"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc13 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc14 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc15 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc16 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x011\x05\x01\x03\x01\x03\x05\x03!\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\x03J\x02\xe9)\x01\x97\x17\x0f\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b/O\x01\x03\x0f\x03'\x17\x0f\x0f\x17\x07\x07\x17\x0b\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02\xe6\x07\x177\x92\x04\x01\x1dw\x01\x1f\x05'\x03\x03\x1d\xc3\x1d5\x01\x05)\x11\x01\x05\x05+\x05-\x05/\x051\x053\x03\x03\x07\xc1\x055\x1d?\x01\x057\x059\x1du\x01\x03\x03\x07\xcf\x03\x07+\x0f-\x0f\r/\x05;\x05=\x05?\x03\x0b\x11\xa5\x13\x99\x15\xb1\r\xb7\x17\xb9\x03\x0b\x11\x9d\x13\x99\x15\x9d\r\x9f\x17\xbd\x05A\x05C\x03\x03\x19\xbf\x1d=\x01\x05E\x05G\x03\x03\x19\xc5\x1dE\x01\x05I\x03\x05!\xa1#\xc7\x1dK\x01\x05K\x03\x03\x07\xc9\x1dQ\x01\x05M\x1dU\x01\x05O\x03\x03Y\xcb\x05Q\x1d]\x01\x05S\x1da\x01\x05U\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq\x01\x05]\x03\x03\x07\xcd\x05_\x05a\x03\x03\x07\xd1\x03\x11}\xd3\x7f\x9b\x81\xd5\x83\xd7\x85\xd9\x87\xdb\x89\xdd\x8b\xe1\x05c\x05e\x05g\x05i\x05k\x05m\x05o\x05q\x03\x05!\xa1#\xe3\x03\x03\x07\xe5\x03\x03\x1d\xe7\x03\x03\x95\x9f\x05s\x1f!\x01#\x1d\x1du\x03\x03\xbb\x1dw\t\x07\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\xa7\r\x05\xa9\xab\xad\xaf\x1dy\x1d{\x1d}\x1d\x7f\x03\x03\xb3\r\x03\xb5\x9b\x1d\x81\x1d\x83\x1d\x85\r\x01\x1d\x87\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1f\x01\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1d\x89\x03\x01\x05\x01\x03\t\x97\x97\x97\xa3\x03\x03\xdf\x15\x03\x01\r\x01\x03\x05\xa3\x97\x07\x01\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x11)\x01\x15)\x01\x11)\x05\x11\x11\x15\x1d\x01)\x05\x11\x11\x13\x03\x13\t\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x1b)\x03\t\
x1b)\x01\r)\x05\x05\x05\r\x04J\x04\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x031_\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x15\x06_\x03\x0f\x03\x03\x17\x06c\x03\x0f\x03\x03\x19\x06g\x03\x0f\x03\x07\x1b\x06k\x03\x03\x05\x05\t\x0b\x06o\x03\x03\x05\x01\x0b\x03\x03\x05s\x03\x07\x05\x07%\t\x03\x03\x03\x0f\x1d\x06%\x03\x03\x05\r\x11\x03\x03\x03'\x03\x05\x03\x03\x03'\x03\x05\x03\x03\x03y\x03\x05\x1f\x07\x03{\x05\x03\x05\t\x15\x17\x19\x13\x03\x03\x03\x1b\x03\x05\x05\x07\x03\t\x03\x05\x03\x1f\r\x07\x03\x8d\x03%\x05\x1d!\x05\x07\x03\t\x03'\x03#\x03\x03\x03\x8f\x03\x07\x05\x07\x03\t\x03\x03\x03'\x05\x07\x03\x91\x03\x17\x03%\x0f\x06\x03\x03\x03\x07+\x1b)!\x07\x0b\x93\x03\x03\x03-\x11\x04\x05\x03/\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x17\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\x96\x18\x8b\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99A9;;m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x17\x15\x11\x11\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/real\x00jit(cholesky)/jit(main)/imag\x00jit(cholesky)/jit(main)/neg\x00jit(cholesky)/jit(main)/complex\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_cpotrf\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 77.35445791180521 -6.4555004827448569e-16j, - 16.89356598261691 -5.4959586590823566e+00j, - -21.124380423202325+6.4431220601700787e+01j, - 55.385054340628855+2.5198457006849742e+00j], - [ 16.89356598261691 +5.4959586590823566e+00j, - 67.125263428637 -3.2921739472953976e-16j, - 25.14078382035968 +1.2783276691803774e+01j, - 
51.116221409460884-2.2635508887939348e+00j], - [-21.124380423202325-6.4431220601700787e+01j, - 25.14078382035968 -1.2783276691803774e+01j, - 107.43449297637208 -2.8959717546347756e-15j, - 12.493792156221616-5.7556567757218694e+01j], - [ 55.385054340628855-2.5198457006849715e+00j, - 51.116221409460884+2.2635508887939326e+00j, - 12.493792156221616+5.7556567757218708e+01j, - 78.9856503203742 +2.0971925518284437e-16j]]),), - expected_outputs=(array([[ 8.795138311124232 +0.j , - 0. +0.j , - 0. +0.j , - 0. +0.j ], - [ 1.9207845726825759+0.624885984127274j , - 7.940111306576433 +0.j , - 0. +0.j , - 0. +0.j ], - [-2.401824698593298 -7.325776846534311j , - 4.3238621722485755-0.026813746599595675j, - 5.413152651345813 +0.j , - 0. +0.j ], - [ 6.297235174866659 -0.28650438589440164j , - 4.936910868956218 +0.849977768846063j , - 0.7751580530200595+1.279980716041562j , - 3.451611642915363 +0.j ]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc2) - %1 = stablehlo.real %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc3) - %2 = stablehlo.imag %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc4) - %3 = stablehlo.negate %2 : tensor<4x4xf64> loc(#loc5) - %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex> loc(#loc6) - %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex> loc(#loc7) - %6 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc8) - %8 = stablehlo.divide %5, %7 : tensor<4x4xcomplex> loc(#loc8) - %9 = stablehlo.constant dense<1> : tensor loc(#loc9) - %10 = stablehlo.constant dense<1> : tensor loc(#loc9) - %11 = stablehlo.constant dense<4> : tensor loc(#loc9) - %12:2 = stablehlo.custom_call @lapack_zpotrf(%9, %10, %11, %8) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor) loc(#loc9) - %13 = stablehlo.constant dense<0> : tensor loc(#loc9) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor loc(#loc9) - %15 = stablehlo.compare EQ, %12#1, %14, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %17 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc9) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc9) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc9) - %20 = stablehlo.select %19, %12#0, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc9) - %21 = call @tril(%20) : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc10) - return %21 : tensor<4x4xcomplex> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xcomplex> loc(unknown)) -> tensor<4x4xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc11) - %1 = stablehlo.constant dense<0> : 
tensor loc(#loc10) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc12) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc12) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc13) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc14) - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc15) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc16) - return %8 : tensor<4x4xcomplex> loc(#loc10) - } loc(#loc10) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/real"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/imag"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/neg"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/complex"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc13 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc14 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc15 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc16 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x011\x05\x01\x03\x01\x03\x05\x03!\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\x03J\x02\xe9)\x01\x97\x17\x0f\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0bOO\x01\x03\x0f\x03'\x17\x0f\x0f\x17\x07\x07\x17\x0b\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02F\x08\x177\x92\x04\x01\x1dw\x01\x1f\x05'\x03\x03\x1d\xc3\x1d5\x01\x05)\x11\x01\x05\x05+\x05-\x05/\x051\x053\x03\x03\x07\xc1\x055\x1d?\x01\x057\x059\x1du\x01\x03\x03\x07\xcf\x03\x07+\x0f-\x0f\r/\x05;\x05=\x05?\x03\x0b\x11\xa5\x13\x99\x15\xb1\r\xb7\x17\xb9\x03\x0b\x11\x9d\x13\x99\x15\x9d\r\x9f\x17\xbd\x05A\x05C\x03\x03\x19\xbf\x1d=\x01\x05E\x05G\x03\x03\x19\xc5\x1dE\x01\x05I\x03\x05!\xa1#\xc7\x1dK\x01\x05K\x03\x03\x07\xc9\x1dQ\x01\x05M\x1dU\x01\x05O\x03\x03Y\xcb\x05Q\x1d]\x01\x05S\x1da\x01\x05U\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq\x01\x05]\x03\x03\x07\xcd\x05_\x05a\x03\x03\x07\xd1\x03\x11}\xd3\x7f\x9b\x81\xd5\x83\xd7\x85\xd9\x87\xdb\x89\xdd\x8b\xe1\x05c\x05e\x05g\x05i\x05k\x05m\x05o\x05q\x03\x05!\xa1#\xe3\x03\x03\x07\xe5\x03\x03\x1d\xe7\x03\x03\x95\x9f\x05s\x1f!\x01#\x1d\x1du\x03\x03\xbb\x1dw\t\x07\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\xa7\r\x05\xa9\xab\xad\xaf\x1dy\x1d{\x1d}\x1d\x7f\x03\x03\xb3\r\x03\xb5\x9b\x1d\x81\x1d\x83\x1d\x85\r\x01\x1d\x87\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1f\x01\x13\x0b\x05\x07\x05\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1d\x89\x03\x01\x05\x01\x03\t\x97\x97\x97\xa3\x03\x03\xdf\x15\x03\x01\r\x01\x03\x05\xa3\x97\x07\x01\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x11)\x01\x15)\x01\x11)\x05\x11\x11\x15\x1d\x01)\x05\x11\x11\x13\x03\x13\x0b\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x1b)\x03\t\x1b)\x01\r)\x05\x05\x05\r\x04J\x04\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x031_\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x15\x06_\x03\x0f\x03\x03\x17\x06c\x03\x0f\x03\x03\x19\x06g\x03\x0f\x03\x07\x1b\x06k\x03\x03\x05\x05\t\x0b\x06o\x03\x03\x05\x01\x0b\x03\x03\x05s\x03\x07\x05\x07%\t\x03\x03\x03\x0f\x1d\x06%\x03\x03\x05\r\x11\x03\x03\x03'\x03\x05\x03\x03\x03'\x03\x05\x03\x03\x03y\x03\x05\x1f\x07\x03{\x05\x03\x05\t\x15\x17\x19\x13\x03\x03\x03\x1b\x03\x05\x05\x07\x03\t\x03\x05\x03\x1f\r\x07\x03\x8d\x03%\x05\x1d!\x05\x07\x03\t\x03'\x03#\x03\x03\x03\x8f\x03\x07\x05\x07\x03\t\x03\x03\x03'\x05\x07\x03\x91\x03\x17\x03%\x0f\x06\x03\x03\x03\x07+\x1b)!\x07\x0b\x93\x03\x03\x03-\x11\x04\x05\x03/\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x17\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\x96\x18\x8b\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x1
9C99A9;;m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x17\x15\x11\x11\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/real\x00jit(cholesky)/jit(main)/imag\x00jit(cholesky)/jit(main)/neg\x00jit(cholesky)/jit(main)/complex\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_zpotrf\x00", - xla_call_module_version=6, -) # End paste - data_2024_05_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_05_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py index bc28857fa325..e6792dc2d1b4 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py @@ -15,279 +15,10 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 - -data_2023_06_19 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464241e+01+0.j, -2.4642489e+00+0.j, 1.4189274e-07+0.j, - -4.0686123e-07+0.j], dtype=complex64), array([[-0.40377745 +0.j, -0.82883257 +0.j, -0.06733338 +0.j, - -0.5208027 +0.j], - [-0.46480742 +0.j, -0.4371466 +0.j, 0.49492982 +0.j, - 0.82081676 +0.j], - [-0.52583724 +0.j, -0.045459956+0.j, -0.78785884 +0.j, - -0.07922471 +0.j], - [-0.5868671 +0.j, 0.3462263 +0.j, 0.36026272 +0.j, - -0.2207891 +0.j]], dtype=complex64), array([[-0.11417642+0.j, -0.73277813+0.j, 0.16960056+0.j, - -0.5435681 +0.j], - [-0.33000448+0.j, -0.28974825+0.j, 0.16204938+0.j, - 0.67456985+0.j], - [-0.54583275+0.j, 0.15328142+0.j, -0.8329006 +0.j, - 0.28156415+0.j], - [-0.761661 +0.j, 0.5963111 +0.j, 0.5012507 +0.j, - 
-0.41256607+0.j]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xf32>) -> tensor<4x4xf32> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:8 = stablehlo.custom_call @lapack_sgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.complex %6#3, %6#4 : tensor<4xcomplex> loc(#loc5) - %8 = stablehlo.constant dense<0> : tensor loc(#loc5) - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor) -> tensor loc(#loc5) - %10 = stablehlo.compare EQ, %6#7, %9, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %12 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %15 = stablehlo.select %14, %7, %13 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %16 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %17 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %20 = stablehlo.select %19, %6#5, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %21 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %22 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %24 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %25 = stablehlo.select %24, %6#6, %23 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %15, %20, %25 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 
= loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe7\x9b9\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03AO\x0f\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f+\x1f\x0f\x0b\x0b//O\x01\x03\x0f\x037\x17\x0f\x07\x13\x07\x07\x17\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02v\x06\x1d9;\x03\x03\t\x8f\x05\x1b\x1f\x05\x1d\x03\x03\x05\x95\x11\x01\x05\x05\x1f\x17\x13\xc2\x07\x01\x05!\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05#\x05%\x05'\x03\x0b#_%e'g\x0fu)w\x05)\x05+\x05-\x05/\x03\x03-y\x051\x1d1\x11\x053\x1d5\x11\x055\x03\x03\x05{\x057\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05I\x05K\x03\x03\t\x97\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x03\x01\x1dM\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07imq\r\x03ak\x1dO\r\x03ao\x1dQ\r\x03as\x1dS\x1dU\x1dW\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x15\x03V\x0b\x05\x1dY\x1d[\x05\x01\x03\x0b]]]][\x03\x11[[[cc[[]\x1f\x05\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x13)\x01#\x01)\x03\x11\x13\t\x1d)\x05\x11\x11\x0b)\x01\x13\x03\x0b)\x01%\x13)\x03\x11\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04\x92\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03Cm\x0b\x03/+\x03!\r\x063\x03\x0f\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x11\x0f\x0f\x0f\x19\x19\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x11\x06\x01\x03\t\x05\x13\x15\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x1f\x13\x07\x01S\x03/\x05\x1b!\x03\x07\x01\x03\x031\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\t\x03'\x03\x07\x01Y\x033\x03%\x07\x06\x01\x03\t\x07+\x1d)\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x03\x075\x173\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x03;\x03\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x03\x07?\x19=\x15\x04\x07\x07-7A\x06\x03\x01\x05\x01\x00&\r]\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729802e+00+0.j, - -1.5210037805054253e-15+0.j, 1.2568096307462507e-16+0.j]), array([[-0.4037774907686232 +0.j, 0.8288327563197505 +0.j, - 0.5454962288885842 +0.j, -0.2420483778598153 +0.j], - [-0.46480737115848986 +0.j, 0.43714638836388725 +0.j, - -0.7640998541831632 +0.j, -0.04349021275982002 +0.j], - [-0.5258372515483576 +0.j, 0.045460020408024715+0.j, - -0.10828897829942748 +0.j, 0.8131255590990858 +0.j], - [-0.5868671319382249 +0.j, -0.3462263475478384 +0.j, - 0.32689260359400607 +0.j, -0.5275869684794504 +0.j]]), array([[-0.11417645138733863+0.j, 0.7327780959803554 +0.j, - 0.49133754464261303+0.j, -0.04933420991901029+0.j], - [-0.33000459866554765+0.j, 0.28974835239692637+0.j, - -0.8355289351028521 +0.j, -0.3408099365295394 +0.j], - [-0.545832745943757 +0.j, -0.1532813911865017 +0.j, - 0.1970452362778633 +0.j, 0.8296225028161098 +0.j], - [-0.7616608932219663 +0.j, -0.5963111347699308 +0.j, - 0.14714615418237506+0.j, -0.43947835636755994+0.j]])), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xf64> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xf64>) -> tensor<4x4xf64> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:8 = stablehlo.custom_call @lapack_dgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4x4xf64>, tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.complex %6#3, %6#4 : tensor<4xcomplex> loc(#loc5) - %8 = stablehlo.constant dense<0> : tensor loc(#loc5) - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor) -> tensor loc(#loc5) - %10 = stablehlo.compare EQ, %6#7, %9, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %12 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor>) -> 
tensor<4xcomplex> loc(#loc5) - %14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %15 = stablehlo.select %14, %7, %13 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %16 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %17 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %20 = stablehlo.select %19, %6#5, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %21 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %22 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %24 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %25 = stablehlo.select %24, %6#6, %23 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %15, %20, %25 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe7\x9b9\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03AO\x0f\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f+\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x037\x17\x0f\x07\x13\x07\x07\x17\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\x96\x06\x1d9;\x03\x03\t\x8f\x05\x1b\x1f\x05\x1d\x03\x03\x05\x95\x11\x01\x05\x05\x1f\x17\x13\xc2\x07\x01\x05!\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05#\x05%\x05'\x03\x0b#_%e'g\x0fu)w\x05)\x05+\x05-\x05/\x03\x03-y\x051\x1d1\x11\x053\x1d5\x11\x055\x03\x03\x05{\x057\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05I\x05K\x03\x03\t\x97\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x03\x01\x1dM\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07imq\r\x03ak\x1dO\r\x03ao\x1dQ\r\x03as\x1dS\x1dU\x1dW\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x15\x03V\x0b\x05\x1dY\x1d[\x05\x01\x03\x0b]]]][\x03\x11[[[cc[[]\x1f\x05\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x13)\x01#\x01)\x03\x11\x13\x0b\x1d)\x05\x11\x11\x0b)\x01\x13\x03\x0b)\x01%\x13)\x03\x11\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04\x92\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03Cm\x0b\x03/+\x03!\r\x063\x03\x0f\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x11\x0f\x0f\x0f\x19\x19\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x11\x06\x01\x03\t\x05\x13\x15\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x1f\x13\x07\x01S\x03/\x05\x1b!\x03\x07\x01\x03\x031\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\t\x03'\x03\x07\x01Y\x033\x03%\x07\x06\x01\x03\t\x07+\x1d)\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x03\x075\x173\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x03;\x03\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x03\x07?\x19=\x15\x04\x07\x07-7A\x06\x03\x01\x05\x01\x00&\r]\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464237e+01+0.j, -2.4642489e+00+0.j, -5.7737714e-07+0.j, - 1.4719126e-07+0.j], dtype=complex64), array([[ 0.4037776 +0.j, 0.8288327 +0.j, -0.53126234 -0.j, - 0.052026853-0.j], - [ 0.46480742 +0.j, 0.43714646 -0.j, 0.80768156 +0.j, - -0.47577178 -0.j], - [ 0.52583724 +0.j, 0.045459922-0.j, -0.021575088-0.j, - 0.79546237 +0.j], - [ 0.5868671 +0.j, -0.3462263 -0.j, -0.25484383 -0.j, - -0.3717177 -0.j]], dtype=complex64), array([[ 0.114176475+0.j, 0.7327782 +0.j, -0.5452461 -0.j, - -0.13326685 -0.j], - [ 0.3300045 +0.j, 0.28974816 -0.j, 0.68821603 +0.j, - -0.2182906 -0.j], - [ 0.5458328 +0.j, -0.1532814 -0.j, 0.25930583 -0.j, - 0.8363818 +0.j], - [ 0.76166093 +0.j, -0.5963111 -0.j, -0.40227592 -0.j, - -0.4848244 -0.j]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:6 = stablehlo.custom_call @lapack_cgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<8xf32>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.constant dense<0> : tensor loc(#loc5) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor loc(#loc5) - %9 = stablehlo.compare EQ, %6#5, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %11 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %14 = stablehlo.select %13, %6#2, %12 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %16 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : 
(tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %19 = stablehlo.select %18, %6#3, %17 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %20 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %21 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %24 = stablehlo.select %23, %6#4, %22 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %14, %19, %24 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0b//O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02Z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\t)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07
\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\xfe\x0c[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x85\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464249196572965e+01+0.j, -2.4642491965729807e+00+0.j, - -1.6035677295293283e-15+0.j, 1.2218554396786611e-16+0.j]), array([[ 0.40377749076862335 +0.j, 0.8288327563197505 +0.j, - -0.5457111210844892 +0.j, -0.2322136424094458 -0.j], - [ 0.46480737115848997 +0.j, 0.4371463883638875 -0.j, - 0.7625701354883243 +0.j, -0.06012408092789514 -0.j], - [ 0.5258372515483578 +0.j, 0.045460020408024694-0.j, - 0.1119930922768192 +0.j, 0.8168890890841272 +0.j], - [ 0.5868671319382247 +0.j, -0.34622634754783854 -0.j, - -0.32885210668065423 +0.j, -0.5245513657467864 -0.j]]), array([[ 0.11417645138733871+0.j, 0.7327780959803554 +0.j, - -0.49606131100796214+0.j, -0.04689746607984153-0.j], - [ 0.3300045986655476 +0.j, 0.2897483523969264 -0.j, - 0.8344969112540657 +0.j, -0.34421909950105706-0.j], - [ 0.5458327459437571 +0.j, -0.15328139118650172-0.j, - -0.18080988948424467+0.j, 0.8291305972416383 +0.j], - [ 0.7616608932219663 +0.j, -0.5963111347699308 -0.j, - -0.1576257107618584 +0.j, -0.4380140316607401 -0.j]])), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:6 = stablehlo.custom_call @lapack_zgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : 
tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<8xf64>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.constant dense<0> : tensor loc(#loc5) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor loc(#loc5) - %9 = stablehlo.compare EQ, %6#5, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %11 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %14 = stablehlo.select %13, %6#2, %12 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %16 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %19 = stablehlo.select %18, %6#3, %17 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %20 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %21 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %24 = stablehlo.select %23, %6#4, %22 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %14, %19, %24 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\x02\r[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev\x00", - xla_call_module_version=6, -) # End paste - +from numpy import array, complex64 data_2024_08_19 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_19["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py index f0696db1aeda..cd5f5c55caf9 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py @@ -17,376 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_03_17 = dict( - # Pasted from the test output (see back_compat_test.py module docstring) - f32=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_ssyevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-0.6185769 , -0.20142993 , -0.09725195 , 0.62983674 , - -0.07926044 , 0.3605001 , -0.019093221 , -0.18446997 ], - [-0.47070873 , 0.29325768 , -0.19454119 , -0.6394365 , - 0.0622955 , 0.33249345 , 0.28112718 , -0.22856665 ], - [-0.32284075 , -0.12361939 , 0.20547704 , -0.18307868 , - 0.47294614 , -0.3170349 , -0.6373532 , -0.27266347 ], - [-0.17497246 , -0.079641335 , 0.15042791 , -0.15416273 , - -0.815209 , -0.38054234 , -0.083263926 , -0.31676024 ], - [-0.027104253 , -0.26490977 , 0.32271704 , 0.08653544 , - 0.30305928 , -0.33998996 , 0.6926741 , -0.360857 ], - [ 0.12076397 , 0.43288827 , -0.64385164 , 0.2652551 , - 0.09482376 , -0.37435007 , 0.00091664493, -0.40495378 ], - [ 0.26863196 , 0.51607686 , 0.53846526 , 0.16969058 , - -0.021670295 , 0.35755336 , -0.113144726 , -0.4490505 ], - [ 0.4165004 , -0.57262254 , -0.2814425 , -0.17463988 , - -0.01698498 , 0.3613705 , -0.12186296 , -0.49314725 ]], - dtype=float32), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, - -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], - dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf32> - %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32> - %3 = stablehlo.add %1, %2 : tensor<8x8xf32> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf32> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf32> - %7 = call @tril(%6) : (tensor<8x8xf32>) -> tensor<8x8xf32> - %8 = stablehlo.constant dense<1> : tensor - %9 = stablehlo.constant dense<1> : tensor - %10 = stablehlo.constant dense<8> : tensor - %11 = stablehlo.custom_call @lapack_ssyevd(%8, %9, %10, %7) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : 
tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xf32>) -> tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>> - %12 = stablehlo.get_tuple_element %11[0] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<8x8xf32> - %13 = stablehlo.get_tuple_element %11[1] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<8xf32> - %14 = stablehlo.get_tuple_element %11[2] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor - %15 = stablehlo.get_tuple_element %11[3] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<177xf32> - %16 = stablehlo.get_tuple_element %11[4] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<43xi32> - %17 = stablehlo.constant dense<0> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor - %19 = stablehlo.compare EQ, %14, %18, SIGNED : (tensor, tensor) -> tensor - %20 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1x1xi1> - %21 = stablehlo.constant dense<0x7FC00000> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<8x8xf32> - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %24 = stablehlo.select %23, %12, %22 : tensor<8x8xi1>, tensor<8x8xf32> - %25 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1xi1> - %26 = stablehlo.constant dense<0x7FC00000> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor<8xf32> - %28 = stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %29 = stablehlo.select %28, %13, %27 : tensor<8xi1>, tensor<8xf32> - return %24, %29 : tensor<8x8xf32>, tensor<8xf32> - } - func.func private @tril(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf32> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf32> - return %8 : tensor<8x8xf32> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03z\x02\xf77\x01\x9b\x0f\x17\x13\x0b\x07\x0f\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x1b\x13\x13\x03]\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x1f\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x037\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x13\x17\x17\x13\x17\x1f\x13\x13\x13\x0f\x17\x13\x13\x13\x02\n\t\x1du\x03\x17\x11\xf6\x04\x01\x03\x03\x13\xc5\x05#\x1f\x1d;\x03\x05%\x05'\x05)\x05+\x17\x11\xf2\x04\x01\x05-\x05/\x051\x053\x03\x03!\xc1\x055\x03\x03\x07\xc3\x1dA\x03\x057\x059\x17\x11\xea\x04\x01\x1do\x15\x03\x03\x07\xd1\x03\x03\x07\xf1\x03\x03\x0f5\x05;\x03\x0b\x17\x9f\x19\xab\x1b\xad\x0f\xb7\x1d\xb9\x03\x0b\x17\xa3\x19\xbd\x1b\xa3\x0f\xa5\x1d\xbf\x05=\x1d?\x03\x05?\x05A\x03\x03!\xc7\x1dG\x03\x05C\x03\x05'\xa7)\xc9\x1dM\x03\x05E\x03\x03\x07\xcb\x1dS\x03\x05G\x1dW\x03\x05I\x1d[+\x05K\x1d_+\x05M\x03\x03c\xcd\x05O\x1dg\x15\x05Q\x1dk\x15\x05S\x03\x03\x07\xcf\x05U\x03\x03s\xa5\x05W\x05Y\x03\x03\x07\xd3\x03\x11{\xd5}\xd7\x7f\xd9\x81\x9f\x83\xdb\x85\xdd\x87\xdf\x89\xe3\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xe5\x03\x03\r\xe7\x03\x03\r\xe9\x03\x03\r\xeb\x03\x03\r\xed\x03\x05'\xa7)\xef\x03\x03\x13\xf3\x03\x03\x13\xf5\x1f'\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dk\x03\x03\xbb\x1dm\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05\xaf\xb3\r\x03\xa1\xb1\x1do\r\x03\xa1\xb5\x1dq\x1ds\x1du\r\x01#\x1f\x1dw\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f!\x01\x13\r\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1dy\x1d{\x05\x01\x03\t\x9b\x9b\x9b\xa9\x03\x03\xe1\x15\x03\x01\r\x01\x03\x0b\xa9\x9d\x9b\x9d\x9d\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\x05\x1b)\x01\t\t)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x05\t)\x03\xad\x05\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\x0b\x01\x0b\x03\x19\x1b)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04:\x05\x05\x01\x11\t3\x07\x03\x01\t\r\x11\t7\x05\x03=}\t\x03Y\x1f\x03#\x15\x06]\x03\x01\x03\x01\x17\x07ea\x03\x01\x03\x03\x0f\x06i\x03\x01\x05\x03\x05\x05\x03\tm\x03\x07\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\x0bq\x03\x01\x03\r\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01w\x03\x03\x1d\x07\x01y\x03%\t\x11\x13\x15\x0f\x07\x07\x01\x8b\x03\x01\x03\x17\x07\x07\x01\x8d\x03\x0b\x03\x17\x07\x07\x01\x8f\x03\x03\x03\x17\x07\x07\x01\x91\x03\x19\x03\x17\x07\x07\x01\x93\x03\x1b\x03\x17\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03#\x11\x07\x01\x95\x03-\x05\x1d%\x03\x07\x01\x05\x03/\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x01\x03+\x03\x07\x01\x97\x03\x15\x03)\x0b\x06\x01\x03\x01\x07/\x19-\x03\x07\x01\x05\x031\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x0b\x035\x03\x07\x01\x99\x033\x033\x0b\x06\x01\x03\x0b\x079\x1b7\x13\x04\t\x051;\r\x1
1\x0b9\x05\x03\x15+\x03\x01\t\t\x03=\x1f\x03\x11\x05\x03\x0b#\x03\x03\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03EC\x03\x11\x11\x07KI\x03\x15\x05\t\x0b\x05\x03\x0bO\x03\x07\x03\x07Q\x05\x03\x01\x03\x0f\x0b\x06U\x03\x01\x07\r\x01\x11\x13\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xb2\x19}\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_ssyevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dsyevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-6.1857700048412056e-01, 2.4081403770912022e-01, - 3.5662489253627483e-01, -6.3034019033669797e-01, - 1.0043483479985752e-16, -2.8842036081919542e-02, - 7.7164692943283169e-25, -1.8446994643771725e-01], - [-4.7070881487314614e-01, 4.7473787464450845e-01, - -4.8036836210243367e-01, 4.3802686872516400e-01, - 1.7961797619639258e-01, 8.3080980076741355e-03, - 2.1415294457221756e-01, -2.2856669794666584e-01], - [-3.2284062926217072e-01, -5.4336490915553370e-01, - 2.2181041859724990e-01, 2.9947877954402297e-01, - -3.6491813600134632e-01, 3.2867679819727436e-01, - 3.8223299448843473e-01, -2.7266344945561438e-01], - [-1.7497244365119530e-01, -8.9251550609769414e-02, - -6.3518515114898394e-02, 1.9162997359209971e-01, - -2.2087281326110139e-01, 5.9957027043505064e-02, - -8.7632498908241274e-01, -3.1676020096456303e-01], - [-2.7104258040220038e-02, -3.3772873786627672e-01, - 2.5901386593721748e-01, 1.7032650752287815e-01, - 6.7521217612940332e-01, -4.5036136532965476e-01, - -1.2279030059078447e-02, -3.6085695247351163e-01], - [ 1.2076392757075530e-01, -3.3834734096469254e-01, - 
-6.5506827461665540e-01, -5.0472498521116749e-01, - 6.9987430903492118e-02, 1.0595648906599275e-01, - 8.3443844143082022e-02, -4.0495370398246017e-01], - [ 2.6863211318173097e-01, 2.2958613191407318e-01, - 6.3952843755683941e-02, 1.8776775771084137e-02, - -5.3523731432241317e-01, -5.9199531677602002e-01, - 1.7916671834524248e-01, -4.4905045549140887e-01], - [ 4.1650029879270661e-01, 3.6355449432857079e-01, - 2.9755313100756142e-01, 1.6826270392615944e-02, - 1.9621068035557282e-01, 5.6830030587314817e-01, - 2.9607517592514246e-02, -4.9314720700035747e-01]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, - -1.9932120610662194e-14, -5.7323356091157378e-15, - -4.5459724251334835e-16, 4.0479851042511616e-14, - 9.2325194924982089e-14, 2.7659880477613365e+02])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf64> - %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64> - %3 = stablehlo.add %1, %2 : tensor<8x8xf64> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf64> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf64> - %7 = call @tril(%6) : (tensor<8x8xf64>) -> tensor<8x8xf64> - %8 = stablehlo.constant dense<1> : tensor - %9 = stablehlo.constant dense<1> : tensor - %10 = stablehlo.constant dense<8> : tensor - %11 = stablehlo.custom_call @lapack_dsyevd(%8, %9, %10, %7) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xf64>) -> tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>> - %12 = stablehlo.get_tuple_element %11[0] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<8x8xf64> - %13 = stablehlo.get_tuple_element %11[1] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<8xf64> - %14 = stablehlo.get_tuple_element %11[2] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor - %15 = stablehlo.get_tuple_element %11[3] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<177xf64> - %16 = stablehlo.get_tuple_element %11[4] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<43xi32> - %17 = stablehlo.constant dense<0> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor - %19 = stablehlo.compare EQ, %14, %18, SIGNED : (tensor, tensor) -> tensor - %20 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1x1xi1> - %21 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<8x8xf64> - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %24 = stablehlo.select %23, %12, %22 : tensor<8x8xi1>, tensor<8x8xf64> - %25 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1xi1> - %26 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor<8xf64> - %28 = 
stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %29 = stablehlo.select %28, %13, %27 : tensor<8xi1>, tensor<8xf64> - return %24, %29 : tensor<8x8xf64>, tensor<8xf64> - } - func.func private @tril(%arg0: tensor<8x8xf64>) -> tensor<8x8xf64> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf64> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf64> - return %8 : tensor<8x8xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03z\x02\xf77\x01\x9b\x0f\x17\x13\x0b\x07\x0f\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x1b\x13\x13\x03]\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x1f\x0f\x0f\x0f\x0f\x0f\x0b/O/\x037\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x13\x17\x17\x13\x17\x1f\x13\x13\x13\x0f\x17\x13\x13\x13\x02:\t\x1du\x03\x17\x11\xf6\x04\x01\x03\x03\x13\xc5\x05#\x1f\x1d;\x03\x05%\x05'\x05)\x05+\x17\x11\xf2\x04\x01\x05-\x05/\x051\x053\x03\x03!\xc1\x055\x03\x03\x07\xc3\x1dA\x03\x057\x059\x17\x11\xea\x04\x01\x1do\x15\x03\x03\x07\xd1\x03\x03\x07\xf1\x03\x03\x0f5\x05;\x03\x0b\x17\x9f\x19\xab\x1b\xad\x0f\xb7\x1d\xb9\x03\x0b\x17\xa3\x19\xbd\x1b\xa3\x0f\xa5\x1d\xbf\x05=\x1d?\x03\x05?\x05A\x03\x03!\xc7\x1dG\x03\x05C\x03\x05'\xa7)\xc9\x1dM\x03\x05E\x03\x03\x07\xcb\x1dS\x03\x05G\x1dW\x03\x05I\x1d[+\x05K\x1d_+\x05M\x03\x03c\xcd\x05O\x1dg\x15\x05Q\x1dk\x15\x05S\x03\x03\x07\xcf\x05U\x03\x03s\xa5\x05W\x05Y\x03\x03\x07\xd3\x03\x11{\xd5}\xd7\x7f\xd9\x81\x9f\x83\xdb\x85\xdd\x87\xdf\x89\xe3\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xe5\x03\x03\r\xe7\x03\x03\r\xe9\x03\x03\r\xeb\x03\x03\r\xed\x03\x05'\xa7)\xef\x03\x03\x13\xf3\x03\x03\x13\xf5\x1f'\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dk\x03\x03\xbb\x1dm\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05\xaf\xb3\r\x03\xa1\xb1\x1do\r\x03\xa1\xb5\x1dq\x1ds\x1du\r\x01#\x1f\x1dw\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f!\x01\x13\r\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1dy\x1d{\x05\x01\x03\t\x9b\x9b\x9b\xa9\x03\x03\xe1\x15\x03\x01\r\x01\x03\x0b\xa9\x9d\x9b\x9d\x9d\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\x05\x1b)\x01\t\x0b)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x05\t)\x03\xad\x05\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\x0b\x01\x0b\x03\x19\x1b)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x01\x0f)\x
05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04:\x05\x05\x01\x11\t3\x07\x03\x01\t\r\x11\t7\x05\x03=}\t\x03Y\x1f\x03#\x15\x06]\x03\x01\x03\x01\x17\x07ea\x03\x01\x03\x03\x0f\x06i\x03\x01\x05\x03\x05\x05\x03\tm\x03\x07\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\x0bq\x03\x01\x03\r\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01w\x03\x03\x1d\x07\x01y\x03%\t\x11\x13\x15\x0f\x07\x07\x01\x8b\x03\x01\x03\x17\x07\x07\x01\x8d\x03\x0b\x03\x17\x07\x07\x01\x8f\x03\x03\x03\x17\x07\x07\x01\x91\x03\x19\x03\x17\x07\x07\x01\x93\x03\x1b\x03\x17\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03#\x11\x07\x01\x95\x03-\x05\x1d%\x03\x07\x01\x05\x03/\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x01\x03+\x03\x07\x01\x97\x03\x15\x03)\x0b\x06\x01\x03\x01\x07/\x19-\x03\x07\x01\x05\x031\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x0b\x035\x03\x07\x01\x99\x033\x033\x0b\x06\x01\x03\x0b\x079\x1b7\x13\x04\t\x051;\r\x11\x0b9\x05\x03\x15+\x03\x01\t\t\x03=\x1f\x03\x11\x05\x03\x0b#\x03\x03\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03EC\x03\x11\x11\x07KI\x03\x15\x05\t\x0b\x05\x03\x0bO\x03\x07\x03\x07Q\x05\x03\x01\x03\x0f\x0b\x06U\x03\x01\x07\r\x01\x11\x13\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xb2\x19}\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_dsyevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - c64=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cheevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-0.6185769 +0.j, -0.20142993 +0.j, -0.09725195 +0.j, - 0.62983674 +0.j, -0.07926044 +0.j, 0.3605001 -0.j, - -0.019093221 +0.j, -0.18446997 +0.j], - 
[-0.47070873 +0.j, 0.29325768 +0.j, -0.19454116 +0.j, - -0.6394365 +0.j, 0.06229549 +0.j, 0.33249345 +0.j, - 0.28112718 +0.j, -0.22856665 +0.j], - [-0.32284075 +0.j, -0.12361939 +0.j, 0.20547704 +0.j, - -0.18307868 +0.j, 0.47294614 +0.j, -0.3170349 +0.j, - -0.6373532 +0.j, -0.27266347 +0.j], - [-0.17497246 +0.j, -0.079641335 +0.j, 0.15042792 +0.j, - -0.15416273 +0.j, -0.815209 +0.j, -0.38054234 +0.j, - -0.083263926 +0.j, -0.31676024 +0.j], - [-0.027104257 +0.j, -0.26490977 +0.j, 0.32271704 +0.j, - 0.08653544 +0.j, 0.30305928 +0.j, -0.33998996 +0.j, - 0.6926741 +0.j, -0.360857 +0.j], - [ 0.120763965 +0.j, 0.43288827 +0.j, -0.64385164 +0.j, - 0.2652551 +0.j, 0.094823755 +0.j, -0.37435007 +0.j, - 0.00091664493+0.j, -0.40495378 +0.j], - [ 0.26863196 +0.j, 0.51607686 +0.j, 0.53846526 +0.j, - 0.16969058 +0.j, -0.0216703 +0.j, 0.35755336 +0.j, - -0.113144726 +0.j, -0.4490505 +0.j], - [ 0.4165004 +0.j, -0.57262254 +0.j, -0.28144246 +0.j, - -0.17463988 +0.j, -0.016984984 +0.j, 0.3613705 +0.j, - -0.12186296 +0.j, -0.49314725 +0.j]], dtype=complex64), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, - -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], - dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> - %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> - %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> - %5 = stablehlo.negate %4 : tensor<8x8xf32> - %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> - %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> - %8 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %10 = stablehlo.divide %7, %9 : tensor<8x8xcomplex> - %11 = call @tril(%10) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %12 = stablehlo.constant dense<1> : tensor - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.constant dense<8> : tensor - %15 = stablehlo.custom_call @lapack_cheevd(%12, %13, %14, %11) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xcomplex>) -> tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>> - %16 = stablehlo.get_tuple_element %15[0] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<8x8xcomplex> - %17 = stablehlo.get_tuple_element %15[1] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<8xf32> - %18 = stablehlo.get_tuple_element %15[2] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor - %19 = stablehlo.get_tuple_element %15[3] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<81xcomplex> - %20 = stablehlo.get_tuple_element %15[4] : (tuple>, 
tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<169xf32> - %21 = stablehlo.get_tuple_element %15[5] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<43xi32> - %22 = stablehlo.constant dense<0> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor - %24 = stablehlo.compare EQ, %18, %23, SIGNED : (tensor, tensor) -> tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1x1xi1> - %26 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %29 = stablehlo.select %28, %16, %27 : tensor<8x8xi1>, tensor<8x8xcomplex> - %30 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi1> - %31 = stablehlo.constant dense<0x7FC00000> : tensor - %32 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor) -> tensor<8xf32> - %33 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %34 = stablehlo.select %33, %17, %32 : tensor<8xi1>, tensor<8xf32> - return %29, %34 : tensor<8x8xcomplex>, tensor<8xf32> - } - func.func private @tril(%arg0: tensor<8x8xcomplex>) -> tensor<8x8xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xcomplex> - return %8 : tensor<8x8xcomplex> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x015\x05\x01\x05\x01\x03\x05\x03%\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%')\x03\xc6\x02\x1e\x02?\x01\xa9\x0f\x17\x13\x0b\x17\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x13\x1b\x17\x03a\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17#\x0f\x0f\x0f\x0f\x0f\x0f\x0b/O\x1f/\x01\x07\x17\x17\x17\x03?\x17\x0f\x07\x0f\x07\x13\x07\x07\x0b\x17\x17\x07\x17\x13\x17\x17\x13\x0f\x17\x17\x13\x17#\x13\x13\x13\x0f\x17\x13\x13\x13\x02&\n\x1d\x83\x03\x17\x13\xf6\x04\x01\x03\x03\x15\xd3\x05+\x17\x13\xf2\x04\x01\x05-\x1f\x1d9\x03\x05/\x051\x053\x055\x057\x059\x05;\x03\x03!\xcf\x05=\x03\x03\x07\xd1\x1d?\x03\x05?\x05A\x17\x13\xea\x04\x01\x1d}\t\x03\x03\x07\xdf\x03\x03\x113\x05C\x03\x0b\x17\xad\x19\xb9\x1b\xbb\x11\xc5\x1d\xc7\x03\x0b\x17\xb1\x19\xcb\x1b\xb1\x11\xb3\x1d\xcd\x05E\x1d=\x03\x05G\x05I\x03\x03!\xd5\x1dE\x03\x05K\x03\x05'\xb5)\xd7\x1dK\x03\x05M\x03\x03\x07\xd9\x1dQ\x03\x05O\x1dU\x03\x05Q\x1dY+\x05S\x1d]+\x05U\x03\x03a\xdb\x05W\x1de\t\x05Y\x1di\t\x05[\x1dm\t\x05]\x1dq\t\x05_\x1du\t\x05a\x1dy\t\x05c\x03\x03\x07\xdd\x05e\x03\x03\x81\xb3\x05g\x05i\x03\x03\x07\xe1\x03\x11\x89\xe3\x8b\xe5\x8d\xe7\x8f\xad\x91\xe9\x93\xeb\x95\xed\x97\xf1\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x03\x03\x0b\xf3\x03\x03\x0b\xf5\x03\x03\x0b\xf7\x03\x03\x0b\xf9\x03\x03\x0b\xfb\x03\x03\x0b\xfd\x03\x05'\xb5)\xff\x03\x03\x07\x02\x02\x1f/\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d{\x03\x03\xc9\x1d}\t\x07\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#%\x03\x05\xbd\xc1\r\x03\xaf\xbf\x1d\x7f\r\x03\xaf\xc3\x1d\x81\x1d\x83\x1d\x85\r\x01#'\x1d\x87\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\x13\r\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\xa9\xa9\xa9\xb7\x03\x03\xef\x15\x03\x01\r\x01\x03\r\xb7\xab\xa9\xab\xab\xab\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x13\x05\x15\x07\x01\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\t\x00\x00\xc0\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03\x15\x06\x02\x03\x03\x07\n\x02\x03\x03\x15\x0e\x02)\x05!!\x11)\x01\x05\x1b)\x01\x11\t)\x03!\t\x1d\x01\x03\t)\x05!!\x05)\x05!!\t\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x02\x11)\x03J\x05\t)\x03\xad\x05)\x01\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\x11/\r\x01\x0b\x03\x1d\x1f!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xda\x05\x05\x01\x11\r1\x07\x03\x01\t\r\x11\r5\x05\x03G\x91\t\x03W\x1f\x03+\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x19\x06g\x03\x15\x03\x05\x1b\x06k\x03\x15\x03\x05\x1d\x06o\x03\x15\x03\t\x1f\x06s\x03\x01\x05\x07\x0b\x0f\x06w\x03\x01\x05\x03\r\x05\x03\r{\x03\x07\x03\x07-\x05\x03\x01\x03\x11!\x06-\x03\x01\x05\x0f\x13#\x07\x0f\x7f\x03\x01\x03\x15\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01\x85\x03\x03%\x07\x01\x87\x03-\t\x19\x1b\x1d\x17\x07\x07\x01\x99\x03\x01\x03\x1f\x07\x07\x01\x9b\x03\x0b\x03\x1f\x07\x07\x01\x9d\x03\x03\x03\x1f\
x07\x07\x01\x9f\x03\x1d\x03\x1f\x07\x07\x01\xa1\x03\x1f\x03\x1f\x07\x07\x01\xa3\x03!\x03\x1f\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03-\x11\x07\x01\xa5\x035\x05%/\x03\x07\x01\x05\x037\x031\x05\x03\x01\xa7\x03\x07\x03\x07\x01\x05\x03\x01\x035\x03\x07\x01\x12\x02\x03\x19\x033\x0b\x06\x01\x03\x01\x079!7\x03\x07\x01\x05\x039\x031\x05\x03\x01\x16\x02\x03#\x03\x07\x01\x05\x03\x0b\x03?\x03\x07\x01\x1a\x02\x03;\x03=\x0b\x06\x01\x03\x0b\x07C#A\x13\x04\r\x05;E\r\x11\x0f7\x05\x03\x15+\x03\x01\r\t\x03;\x1f\x03\x13\x05\x03\x0f#\x03\x03\x03\x07%\x05\x03\x13\x03\x05\x0f\x06%\x03\x13\x05\x03\x07\t\x03CA\x03\x13\x11\x07IG\x03\x19\x05\t\x0b\x05\x03\x0fM\x03\x07\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\x0f\x03\x13\x06\x03\x01\x05\x01\x00F\x1c\x8d\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99A9;;m\x19\x85\x8dW\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_cheevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - c128=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zheevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-6.1857700048412056e-01+0.j, 2.4081403770912022e-01+0.j, - 3.5662489253627483e-01+0.j, -6.3034019033669797e-01+0.j, - 1.0043483479985752e-16+0.j, -2.8842036081919542e-02+0.j, - 7.7164692943283169e-25+0.j, -1.8446994643771725e-01+0.j], - [-4.7070881487314609e-01+0.j, 4.7473787464450828e-01+0.j, - -4.8036836210243361e-01+0.j, 4.3802686872516400e-01+0.j, - 1.7961797619639255e-01+0.j, 8.3080980076741355e-03+0.j, - 2.1415294457221759e-01+0.j, 
-2.2856669794666584e-01+0.j], - [-3.2284062926217072e-01+0.j, -5.4336490915553370e-01+0.j, - 2.2181041859724987e-01+0.j, 2.9947877954402286e-01+0.j, - -3.6491813600134637e-01+0.j, 3.2867679819727436e-01+0.j, - 3.8223299448843473e-01+0.j, -2.7266344945561438e-01+0.j], - [-1.7497244365119527e-01+0.j, -8.9251550609769331e-02+0.j, - -6.3518515114898352e-02+0.j, 1.9162997359209963e-01+0.j, - -2.2087281326110142e-01+0.j, 5.9957027043505008e-02+0.j, - -8.7632498908241274e-01+0.j, -3.1676020096456303e-01+0.j], - [-2.7104258040220017e-02+0.j, -3.3772873786627688e-01+0.j, - 2.5901386593721754e-01+0.j, 1.7032650752287815e-01+0.j, - 6.7521217612940321e-01+0.j, -4.5036136532965476e-01+0.j, - -1.2279030059078447e-02+0.j, -3.6085695247351163e-01+0.j], - [ 1.2076392757075533e-01+0.j, -3.3834734096469249e-01+0.j, - -6.5506827461665529e-01+0.j, -5.0472498521116760e-01+0.j, - 6.9987430903492132e-02+0.j, 1.0595648906599270e-01+0.j, - 8.3443844143082035e-02+0.j, -4.0495370398246017e-01+0.j], - [ 2.6863211318173102e-01+0.j, 2.2958613191407312e-01+0.j, - 6.3952843755683969e-02+0.j, 1.8776775771084192e-02+0.j, - -5.3523731432241317e-01+0.j, -5.9199531677602002e-01+0.j, - 1.7916671834524250e-01+0.j, -4.4905045549140887e-01+0.j], - [ 4.1650029879270667e-01+0.j, 3.6355449432857068e-01+0.j, - 2.9755313100756148e-01+0.j, 1.6826270392616000e-02+0.j, - 1.9621068035557282e-01+0.j, 5.6830030587314817e-01+0.j, - 2.9607517592514260e-02+0.j, -4.9314720700035747e-01+0.j]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, - -1.9932120610662194e-14, -5.7323356091157378e-15, - -4.5459724251334835e-16, 4.0479851042511616e-14, - 9.2325194924982089e-14, 2.7659880477613365e+02])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> - %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> - %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> - %5 = stablehlo.negate %4 : tensor<8x8xf64> - %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> - %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> - %8 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %10 = stablehlo.divide %7, %9 : tensor<8x8xcomplex> - %11 = call @tril(%10) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %12 = stablehlo.constant dense<1> : tensor - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.constant dense<8> : tensor - %15 = stablehlo.custom_call @lapack_zheevd(%12, %13, %14, %11) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xcomplex>) -> tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>> - %16 = stablehlo.get_tuple_element %15[0] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<8x8xcomplex> - %17 = 
stablehlo.get_tuple_element %15[1] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<8xf64> - %18 = stablehlo.get_tuple_element %15[2] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor - %19 = stablehlo.get_tuple_element %15[3] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<81xcomplex> - %20 = stablehlo.get_tuple_element %15[4] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<169xf64> - %21 = stablehlo.get_tuple_element %15[5] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<43xi32> - %22 = stablehlo.constant dense<0> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor - %24 = stablehlo.compare EQ, %18, %23, SIGNED : (tensor, tensor) -> tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1x1xi1> - %26 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %29 = stablehlo.select %28, %16, %27 : tensor<8x8xi1>, tensor<8x8xcomplex> - %30 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi1> - %31 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %32 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor) -> tensor<8xf64> - %33 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %34 = stablehlo.select %33, %17, %32 : tensor<8xi1>, tensor<8xf64> - return %29, %34 : tensor<8x8xcomplex>, tensor<8xf64> - } - func.func private @tril(%arg0: tensor<8x8xcomplex>) -> tensor<8x8xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xcomplex> - return %8 : tensor<8x8xcomplex> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x015\x05\x01\x05\x01\x03\x05\x03%\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%')\x03\xc6\x02\x1e\x02?\x01\xa9\x0f\x17\x13\x0b\x17\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x13\x1b\x17\x03a\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17#\x0f\x0f\x0f\x0f\x0f\x0f\x0bOO//\x01\x07\x17\x17\x17\x03?\x17\x0f\x07\x0f\x07\x13\x07\x07\x0b\x17\x17\x07\x17\x13\x17\x17\x13\x0f\x17\x17\x13\x17#\x13\x13\x13\x0f\x17\x13\x13\x13\x02\x96\n\x1d\x83\x03\x17\x13\xf6\x04\x01\x03\x03\x15\xd3\x05+\x17\x13\xf2\x04\x01\x05-\x1f\x1d9\x03\x05/\x051\x053\x055\x057\x059\x05;\x03\x03!\xcf\x05=\x03\x03\x07\xd1\x1d?\x03\x05?\x05A\x17\x13\xea\x04\x01\x1d}\t\x03\x03\x07\xdf\x03\x03\x113\x05C\x03\x0b\x17\xad\x19\xb9\x1b\xbb\x11\xc5\x1d\xc7\x03\x0b\x17\xb1\x19\xcb\x1b\xb1\x11\xb3\x1d\xcd\x05E\x1d=\x03\x05G\x05I\x03\x03!\xd5\x1dE\x03\x05K\x03\x05'\xb5)\xd7\x1dK\x03\x05M\x03\x03\x07\xd9\x1dQ\x03\x05O\x1dU\x03\x05Q\x1dY+\x05S\x1d]+\x05U\x03\x03a\xdb\x05W\x1de\t\x05Y\x1di\t\x05[\x1dm\t\x05]\x1dq\t\x05_\x1du\t\x05a\x1dy\t\x05c\x03\x03\x07\xdd\x05e\x03\x03\x81\xb3\x05g\x05i\x03\x03\x07\xe1\x03\x11\x89\xe3\x8b\xe5\x8d\xe7\x8f\xad\x91\xe9\x93\xeb\x95\xed\x97\xf1\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x03\x03\x0b\xf3\x03\x03\x0b\xf5\x03\x03\x0b\xf7\x03\x03\x0b\xf9\x03\x03\x0b\xfb\x03\x03\x0b\xfd\x03\x05'\xb5)\xff\x03\x03\x07\x02\x02\x1f/\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d{\x03\x03\xc9\x1d}\t\x07\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#%\x03\x05\xbd\xc1\r\x03\xaf\xbf\x1d\x7f\r\x03\xaf\xc3\x1d\x81\x1d\x83\x1d\x85\r\x01#'\x1d\x87\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\x13\r\x05\x07\x05\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\xa9\xa9\xa9\xb7\x03\x03\xef\x15\x03\x01\r\x01\x03\r\xb7\xab\xa9\xab\xab\xab\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x13\x05\x15\x07\x01\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03\x15\x06\x02\x03\x03\x07\n\x02\x03\x03\x15\x0e\x02)\x05!!\x11)\x01\x05\x1b)\x01\x11\x0b)\x03!\t\x1d\x01\x03\t)\x05!!\x05)\x05!!\t\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x02\x11)\x03J\x05\t)\x03\xad\x05)\x01\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\x11/\r\x01\x0b\x03\x1d\x1f!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xda\x05\x05\x01\x11\r1\x07\x03\x01\t\r\x11\r5\x05\x03G\x91\t\x03W\x1f\x03+\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x19\x06g\x03\x15\x03\x05\x1b\x06k\x03\x15\x03\x05\x1d\x06o\x03\x15\x03\t\x1f\x06s\x03\x01\x05\x07\x0b\x0f\x06w\x03\x01\x05\x03\r\x05\x03\r{\x03\x07\x03\x07-\x05\x03\x01\x03\x11!\x06-\x03\x01\x05\x0f\x13#\x07\x0f\x7f\x03\x01\x03\x15\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01\x85\x03\x03%\x07\x01\x87\x03-\t\x19\x
1b\x1d\x17\x07\x07\x01\x99\x03\x01\x03\x1f\x07\x07\x01\x9b\x03\x0b\x03\x1f\x07\x07\x01\x9d\x03\x03\x03\x1f\x07\x07\x01\x9f\x03\x1d\x03\x1f\x07\x07\x01\xa1\x03\x1f\x03\x1f\x07\x07\x01\xa3\x03!\x03\x1f\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03-\x11\x07\x01\xa5\x035\x05%/\x03\x07\x01\x05\x037\x031\x05\x03\x01\xa7\x03\x07\x03\x07\x01\x05\x03\x01\x035\x03\x07\x01\x12\x02\x03\x19\x033\x0b\x06\x01\x03\x01\x079!7\x03\x07\x01\x05\x039\x031\x05\x03\x01\x16\x02\x03#\x03\x07\x01\x05\x03\x0b\x03?\x03\x07\x01\x1a\x02\x03;\x03=\x0b\x06\x01\x03\x0b\x07C#A\x13\x04\r\x05;E\r\x11\x0f7\x05\x03\x15+\x03\x01\r\t\x03;\x1f\x03\x13\x05\x03\x0f#\x03\x03\x03\x07%\x05\x03\x13\x03\x05\x0f\x06%\x03\x13\x05\x03\x07\t\x03CA\x03\x13\x11\x07IG\x03\x19\x05\t\x0b\x05\x03\x0fM\x03\x07\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\x0f\x03\x13\x06\x03\x01\x05\x01\x00J\x1c\x8d\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99A9;;m\x19\x85\x8fW\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_zheevd\x00", - xla_call_module_version=4, - ), # End paste -) - data_2024_08_19 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_19["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py index 204af8f55396..8d87c2524e64 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py @@ -17,275 
+17,8 @@ import datetime from numpy import array, float32, complex64 -data_2024_08_30 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 0.7137638961069523 +2.4533812415320035e+00j, - -0.3272236912989258 -3.2003874808591863e+00j, - -3.065817294924296 +1.6978219378771007e+00j, - -3.3971558164664 +2.6931967836060400e-01j], - [ 6.346214936866542 +0.0000000000000000e+00j, - 2.083218259144673 -1.2191838498692813e+00j, - 1.9552582313969427 -3.3216313521481879e+00j, - 2.7451664155727293 +2.5460553490974451e+00j], - [-0.16133388943502391 +3.6906265775683444e-01j, - -4.698636849217318 +0.0000000000000000e+00j, - 2.5396292124414077 -3.3038474840573420e+00j, - 2.5410992366186456 +4.1958389320867528e-01j], - [ 0.47396123039280513 +3.9524384493417053e-03j, - 0.058880409351504966-7.8934332132630333e-02j, - 0.9469634796174572 +0.0000000000000000e+00j, - -3.130422531669044 -8.8070401977461810e-01j]], - - [[-6.7065483048969465 -4.1981401054281309e-01j, - -0.21813268822330256 -3.8602920478381799e+00j, - -0.8248337528620167 -2.9073223456990824e+00j, - -3.597231249446879 +2.7626541679004930e+00j], - [-6.812126638479044 +0.0000000000000000e+00j, - -0.20651586628458585 -1.0948249928988512e+00j, - -1.6675586608354327 +4.2553627621795744e+00j, - -2.410110723267707 +3.6065122124698634e-01j], - [ 0.038235817369200516-3.7823713529009173e-01j, - -8.508141062606947 +0.0000000000000000e+00j, - 4.260708077719245 -6.8052584397204630e-02j, - 5.345997177836541 -1.1955161503390279e+00j], - [-0.18541509608158574 -1.2016051097247168e-01j, - -0.02698777746917469 -4.4847463691672246e-01j, - 6.149305574585603 +0.0000000000000000e+00j, - -2.483131585236393 +2.8524912589603817e+00j]]]), array([[1.2286220194325557+0.5121060656500841j , - 1.9529937219183482-0.23299856112387676j, - 1.5940499664125072-0.8044281430962614j ], - [1.6682114302246909-0.11372755955977935j, - 1.4075913155446236-0.6008708461880701j , - 1.5086928152468893-0.8609480935086589j ]])), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(0.71376389610695234,2.4533812415320035), (-1.0686093138739379,-1.885041510645256), (3.2629529488994033,-0.87160041258342402), (2.4332168907311504,3.4960248990882183)], [(-1.450884474619478,-3.249935163088522), (0.53920035905924757,-5.0056840575116066), (0.13157186736298554,2.5015499854549939), (-1.2451270607408882,0.24345856951924827)], [(2.457366083193417,-2.3532935513245605), (-0.37595429769485644,1.5729223427874068), (3.5877693970448052,-0.30904304334212157), (-1.685615117470264,2.6148811836470265)], [(-3.6826776618664727,-1.5711608241015744), (-0.12407609317204518,-4.7137561145212281), (1.3298255603911306,-1.6739172003954141), (-2.6345448161870149,-0.089008252847513236)]], [[(-6.7065483048969465,-0.41981401054281309), (-2.1586544949255457,0.34815132010709054), (-5.1462488701272413,3.440817752555807), (1.0301804086076078,-0.6994760434270566)], [(4.551940883969797,-0.77472653800638502), (4.4485186470774796,-0.0024458890677252756), 
(0.66610302132250898,2.5976571401862039), (-5.0693248202533674,-5.7405538897950699)], [(0.14148406399087146,-4.3279346473525058), (-2.353557113110897,2.0880432773400326), (-3.2524452107293618,-0.42398740171508631), (3.7200566224095519,-0.56951559566037058)], [(-2.2001612082232613,-1.2218661647417151), (0.72437359623190833,8.6381970213061301), (0.72314820631775734,0.058458198280771749), (0.37498718985014962,2.1160469724471378)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_zgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) - return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x10\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/OoO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xce\x0f\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x08p\t\xdba\'\xd7\xe6?\xa8\xff\'X\x86\xa0\x03@\x0c\xa2t\x14\x06\x19\xf1\xbfT.}I!)\xfe\xbf\x0fG_\x13\x87\x1a\n@\xae:g\x8c&\xe4\xeb\xbf\xeb\x1e\xcej:w\x03@N\xaf\xfc\xe6\xdb\xf7\x0b@\x9f<\x8c\xa3\xd26\xf7\xbf^\xaf\xbc\x01\xde\xff\t\xc0b\xd4\x84\x1c!A\xe1?\xd6{\xa4\n\xd2\x05\x14\xc0\xf0\xe6\xb2\xd1X\xd7\xc0?2\xb5\x86\xa3,\x03\x04@\x91\xf2SZ\n\xec\xf3\xbf\x04\x10\x02\x81\xa6)\xcf?8\xec\x8c\x8c\xaf\xa8\x03@\r\x9d\xc6\x91\x8b\xd3\x02\xc0\xb0\xf6X\x9d\xa2\x0f\xd8\xbf\xbd\xb6V\x9e\xb0*\xf9?7-\x0fq\xc0\xb3\x0c@{|\ry\\\xc7\xd3\xbf\x04\xd9\xb2\x8eG\xf8\xfa\xbf\x9b\x84u\xd3F\xeb\x04@\xf4h\xbb\xb4\x1fv\r\xc0\xdc\\D\x88y#\xf9\xbf\x9a\xaecjs\xc3\xbf\xbf<\xc1\x04\xe2\xe2\xda\x12\xc0\x89<\xb4*\xf7F\xf5?\x1b\x90\xfef]\xc8\xfa\xbf\xdc\xf4\x8a;\x8c\x13\x05\xc0\xf8\xdd\r\xaf>\xc9\xb6\xbfvN\x1af\x81\xd3\x1a\xc0Z\xc6k\x95;\xde\xda\xbf\x87\x8c\xd8\xa5\xecD\x01\xc0\xdd\xd3zy\x1cH\xd6?\x04\x18\x89C\xc2\x95\x14\xc0\x8c\xc95u\xcb\x86\x0b@\x881\xbfs\x9e{\xf0?\x92Y[\x95\x1bb\xe6\xbf\x06\xe7\xb7\xfd/5\x12@L\x95\x02O\x8f\xca\xe8\xbf2`\xe3xH\xcb\x11@>\xda\xc6\xb1f\td\xbfZ\x1a\x8bH\xb7P\xe5?\xa8\x90zw\x00\xc8\x04@<(\xef\x15\xfdF\x14\xc0\xb4aF\xc2S\xf6\x16\xc0\xc1{\xdfY&\x1c\xc2?\xcfj\xa6\x19\xceO\x11\xc0\xc4\xa2p\xc0\x15\xd4\x02\xc0\xfcv\xa6\x08P\xb4\x00@^\xea\xa0\xfe\x01\x05\n\xc0^\x11\x12\x0e\x9c"\xdb\xbfR#\xe4\x0b\xad\xc2\r@F\x8b=\xc5x9\xe2\xbfZ\xf9\x99\x1e\xee\x99\x01\xc0My\x1a\x89\xc3\x8c\xf3\xbf\xd1\xdc<\x89\x11.\xe7?2\xd4\x8d\xc2\xc1F!@mw\t\xb5\x07$\xe7?G\x16\x99\xa3;\xee\xad?M\xd24E\xca\xff\xd7?\xa2\xae\xfb\x08\xaa\xed\x00@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\x0b)\x03\x02\x86\x0b)\x03\x01\x0f)\x0
3\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 5.2023945 -0.878671j , -2.8841915 -0.47488597j , - 1.3024182 +0.6651789j , 4.9291854 -1.9147056j ], - [ 6.3457894 +0.j , 1.6869383 -4.6557646j , - 0.88955224-1.7617276j , 2.9149916 +4.342665j ], - [-0.2465725 -0.5776757j , -5.3007755 +0.j , - -0.9786545 -0.0633831j , -1.3690261 -1.5921416j ], - [ 0.35462287+0.35993803j , -0.38403815-0.46558398j , - 2.8020499 +0.j , 0.5636822 -6.218306j ]], - - [[ 1.0687767 -3.88293j , -4.0144 -2.5885587j , - 5.3900986 -0.8850739j , 2.079677 +3.5515747j ], - [ 7.5675693 +0.j , 0.5971966 -3.6699948j , - 2.246994 -1.0858283j , -0.8870981 -0.022960603j], - [-0.2183232 +0.10552277j , 5.860886 +0.j , - -5.091036 +6.2841997j , 5.008773 +1.8765848j ], - [ 0.1378771 +0.427895j , 0.63263524-0.3470098j , - 6.4528017 +0.j , -4.233642 -0.84165764j ]]], - dtype=complex64), array([[1.0933675-0.3605358j , 1.1987956+0.5659744j , - 1.9999101-0.013409062j], - [1.4504763-0.44363326j , 1.3110259-0.07426627j , - 1.227255 +0.97383535j ]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(5.20239449,-0.87867099), (-0.211780012,-0.923053801), (-5.25181627,1.90887547), (-1.61342144,-1.98000157)], [(-5.924900e-01,2.28788424), (-1.74142945,-3.25563216), 
(3.08765078,-3.25260139), (-3.35189271,-0.571629047)], [(3.032444,3.44394636), (1.22205484,0.808871626), (2.58686161,-7.47011566), (1.9139297,-2.57945323)], [(-3.28396916,-1.68601465), (2.62759161,-0.953538239), (-2.78763294,-0.0429570749), (0.426534384,-0.211706176)]], [[(1.06877673,-3.882930e+00), (-0.0192247611,5.96663713), (1.15329504,-5.0599103), (-1.76508892,-1.98541296)], [(-3.40901089,3.35722542), (-6.13531398,2.55851483), (-4.8095789,0.164206699), (-0.247624069,-3.13545418)], [(2.04217815,-1.89123917), (-1.18974173,-1.69466627), (-2.28673625,-0.487834573), (3.01541853,-1.85637176)], [(-2.9499588,-4.23393869), (8.44624137,5.57274485), (-1.09048736,2.4864223), (-0.305431545,-0.298133373)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_cgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) - return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xae\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\x04z\xa6@\x95\xf0`\xbf\xdc\xdcX\xbeAMl\xbf\xe1\x0e\xa8\xc0\x08V\xf4?\x98\x84\xce\xbf\xb1p\xfd\xbfm\xad\x17\xbf\xb2l\x12@)\xe7\xde\xbfG\\P\xc0\x12\x9cE@\x9f*P\xc0i\x85V\xc0HV\x12\xbf\x90\x13B@\x9ei\\@Kl\x9c?6\x12O?$\x8f%@0\x0b\xef\xc0\xa6\xfb\xf4?\xc3\x15%\xc0\x8d,R\xc0T\xcf\xd7\xbfv*(@\x15\x1bt\xbf\x94h2\xc0\xc2\xf3/\xbd\xb7b\xda>\x81\xc9X\xbe\xad\xcd\x88?\xed\x81x\xc0?}\x9d\xbc\xb1\xee\xbe@,\x9f\x93?\xc9\xea\xa1\xc0o\xee\xe1\xbf\x03"\xfe\xbf<-Z\xc0\xc8\xdcV@~T\xc4\xc0\xb5\xbe#@\x12\xe8\x99\xc0\xcd%(>*\x91}\xbeH\xabH\xc0\x0c\xb3\x02@ \x14\xf2\xbfuI\x98\xbf\xd3\xea\xd8\xbf\xe3Y\x12\xc0t\xc5\xf9\xbe\x9e\xfc@@\x97\x9d\xed\xbf 
\xcc<\xc0m|\x87\xc0\xce#\x07A\xedS\xb2@\x17\x95\x8b\xbf\x8b!\x1f@\x86a\x9c\xbe\xf0\xa4\x98\xbe\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\t)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_cgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[-3.5237675 , -6.1161256 , -0.549011 , -4.7706876 ], - [ 5.8401766 , 3.424213 , 0.3059119 , 2.3492367 ], - [ 0.63135445 , 2.7238827 , -0.106214404, -0.82470125 ], - [-0.27146497 , 0.09917235 , 0.2545611 , -0.5113605 ]], - - [[ 4.297168 , -1.8758869 , 0.33528137 , 5.867136 ], - [-7.129698 , -3.3118155 , -1.3492918 , -2.8959117 ], - [-0.7266852 , -3.506432 , 4.77164 , -4.0780373 ], - [ 0.14084078 , 0.3389384 , 2.3910007 , -0.79807365 ]]], - dtype=float32), array([[1.3584172, 1.9805213, 0. ], - [1.2920669, 1.7939165, 0. 
]], dtype=float32)), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[-3.52376747, -0.758410036, 4.85795927, -6.0243597], [-2.09321976, -1.27957773, -0.956288218, -1.11928439], [-5.00878525, 0.51314038, 3.53047514, -2.91282868], [2.15363932, 0.635739565, -0.21264787, 0.555740714]], [[4.29716778, -3.86209464, -2.39021468, 4.17441607], [2.08234859, -1.03958249, 4.09025383, 5.22586823], [-6.69425774, 3.43749118, -0.691099107, 1.59547663], [1.29743183, -2.00156212, 3.08750296, 2.39243269]]]> : tensor<2x4x4xf32> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_sgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<4288xf32>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc2) - return %6, %10 : tensor<2x4x4xf32>, tensor<2x3xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x04\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/\x1foO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\x96\t\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x02h\x85a\xc0)\'B\xbfgt\x9b@\x8e\xc7\xc0\xc0P\xf7\x05\xc04\xc9\xa3\xbfN\xcft\xbf\xb6D\x8f\xbf\xf8G\xa0\xc0+]\x03?N\xf3a@\xc9k:\xc0:\xd5\t@\xd4\xbf"?]\xc0Y\xbe\x06E\x0e?f\x82\x89@\x8f,w\xc0G\xf9\x18\xc0\xd1\x94\x85@3E\x05@\n\x11\x85\xbf\\\xe3\x82@P:\xa7@\\7\xd6\xc0\xdb\xff[@\xdf\xeb0\xbf\x948\xcc??\x12\xa6?\x98\x19\x00\xc0\xa6\x99E@\x9e\x1d\x19@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\t\x00\x00\xc0\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\t)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_t
ype\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_sgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 0.9307390587491866 , -0.35692982324474015 , - -0.1271353200176119 , -0.43952156917870067 ], - [ 2.2633695323673964 , 0.9965090965971986 , - -1.3244131008423046 , 1.7324542351344163 ], - [ 0.24558316247256504 , 2.922776762811796 , - 3.630059093036474 , 1.4330664619737252 ], - [-0.2856727718012896 , -0.4601276537179077 , - -2.8602148466873802 , 1.9928744545245372 ]], - - [[-0.5351339571818844 , 5.753313169426148 , - 0.1385440281649789 , 2.8445493054193807 ], - [ 4.676815781213274 , 2.920688567170204 , - -2.610159425457712 , 4.0359806870679655 ], - [-0.16963242599901043 , -2.342935131066633 , - 4.179999589709703 , -0.6810604472011716 ], - [ 0.030645999613174775, -0.2271804227402005 , - -2.2755242550977153 , 0.7136684502626782 ]]]), array([[1.751436143556826 , 1.6505497938190505, 0. ], - [1.9422862513069978, 1.9018440331997255, 0. ]])), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[0.93073905874918661, 0.18483901505653183, -0.11804347408930886, -0.53725392025434981], [-1.700777672846173, 1.3531570270421245, -2.4375034855727518, 2.2945174202226699], [-0.97352780716312858, -0.8319788592736328, 2.4986640885328582, -2.8118637941861766], [1.1324489199416958, -1.9301638714393787, 1.5523821278819048, 2.7676215285832253]], [[-0.53513395718188439, -5.2137633671981938, 2.9644475919777618, 2.2891023676266191], [-4.4068992105328642, 1.2751848926168665, -2.8947257279736456, -2.6817410994805888], [1.5408926111334784, -0.85423691880254915, 6.4217874587762065, -0.43997818045540715], [-0.27837952612324207, 1.1509460853774549, -0.21686805683301608, 0.11738425574951133]]]> : tensor<2x4x4xf64> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_dgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : 
tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<4288xf64>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc2) - return %6, %10 : tensor<2x4x4xf64>, tensor<2x3xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xa6\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\xa6\x00NG\x9d\xc8\xed?\xf2\xa8X\n\xce\xa8\xc7?#E\xb8\xdc\x188\xbe\xbf\xb8|$"/1\xe1\xbf\xc4B*\xa6b6\xfb\xbf\xe8\xf9\x97\xfb\x87\xa6\xf5?)^\xd3\xd3\x01\x80\x03\xc0T\xab\xff\xf2+[\x02@4d\xb0\xc9#\'\xef\xbf~e\xf1 
\x92\x9f\xea\xbf\x96\x81\xff\x98C\xfd\x03@W\xb0\xe6q\xb2~\x06\xc0F\xa48\xc2\x82\x1e\xf2?\xcc\x0b\xfc\x82\xf3\xe1\xfe\xbf\xdc\\b\xa4\x8e\xd6\xf8?\x8c\xc3\x87\xc1\x16$\x06@\x83h\xa2?\xd1\x1f\xe1\xbf\xdc\xcb\xbc\xc8\xe4\xda\x14\xc0\xe6\x00\x92L0\xb7\x07@Q8\xf1\xe6\x14P\x02@\t\x07\xc8/\xaa\xa0\x11\xc0\x8eH"F(g\xf4?\xf5Jd\xf6e(\x07\xc0\x9e\xddt\xad4t\x05\xc0\x1cv\xb7\x02\x7f\xa7\xf8?B^\xa9\xa9\xe8U\xeb\xbf\x1e:5\r\xe9\xaf\x19@\xa2\x9c\x00>\x9a(\xdc\xbf\xc1\xd1$\\\xf8\xd0\xd1\xbf}|BqFj\xf2?6\x8b\xd2\x1dU\xc2\xcb\xbfdk\x82\x03\xe5\x0c\xbe?\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x0b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_dgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - data_2024_08_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py index 72d97df53a4f..2290db62e436 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py @@ -17,527 +17,8 @@ 
import datetime from numpy import array, int32, float32, complex64 -data_2023_06_14 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['f32'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. , 7. , 8. ], - [0. , 1. , 2. ], - [0.5, 0.5, 0. ]], dtype=float32), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_sgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) 
- %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor 
loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xf32>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":550:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":551:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xa6\x02\x0e\x023\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b\x1fO/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x031\x0f\x0f\x13\x13\x0f\x17\x07\x07\x07\x07\x07\x13\x0f\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02J\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\x9e\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\x9a\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f%\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f+\x01\x05\x03\x03\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\t\x00\x00\xc0\x7f\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\t\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03!\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03#\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03-\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x03/\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)
\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00v%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x87\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True 
mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['f64'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. , 7. , 8. ], - [0. , 1. , 2. ], - [0.5, 0.5, 0. ]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_dgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : 
(tensor) -> tensor<3x3xf64> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, 
tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xf64>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xa6\x02\x0e\x023\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b/O/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x031\x0f\x0f\x13\x13\x0f\x17\x07\x07\x07\x07\x07\x13\x0f\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02Z\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f%\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f+\x01\x05\x03\x03\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x0b\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03!\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03#\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03-\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x03/\x033\r\x06\x01\x03\r\x0
79!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00v%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x87\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True 
mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['c64'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. +0.j, 7. +0.j, 8. +0.j], - [0. +0.j, 1. +0.j, 2. +0.j], - [0.5+0.j, 0.5+0.j, 0. +0.j]], dtype=complex64), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_cgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant 
dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) 
{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xaa\x02\x0e\x025\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b/O/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x033\x0f\x0f\x13\x13\x0f\x17\x07\x07\x0b\x07\x07\x13\x0f\x13\x1b\x07\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02b\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f'\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f-\x01\x05\x03\x03\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x03!\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07\t)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03#\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03%\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03/\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x031\x033\r\x06\x01\x
03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00~%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x8b\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True 
indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['c128'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. +0.j, 7. +0.j, 8. +0.j], - [0. +0.j, 1. +0.j, 2. +0.j], - [0.5+0.j, 0.5+0.j, 0. +0.j]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_zgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant 
dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor 
loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xaa\x02\x0e\x025\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0bOO/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x033\x0f\x0f\x13\x13\x0f\x17\x07\x07\x0b\x07\x07\x13\x0f\x13\x1b\x07\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02\x82\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f'\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f-\x01\x05\x03\x03\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x03!\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07\x0b)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03#\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03%\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03/\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\
x05\x07\x01}\x031\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00\x82%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x8d\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) 
unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - data_2024_05_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_05_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py index 94314a7ae518..bf41f3c3445c 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py @@ -17,259 +17,13 @@ import datetime from numpy import array, float32, complex64 -data_2023_03_17 = {} +data_2025_04_02 = {} -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["f32"] = dict( +data_2025_04_02['c128'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['lapack_sgeqrf', 'lapack_sorgqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. , 0.91287076, 0.4082487 ], - [-0.44721356, 0.36514866, -0.8164965 ], - [-0.8944271 , -0.18257445, 0.40824816]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954441e+00, 2.1908894e+00], - [ 0.0000000e+00, 0.0000000e+00, 7.1525574e-07]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_sgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xf32>) -> tuple, tensor<3xf32>, tensor, tensor<96xf32>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor<3x3xf32> - %8 = stablehlo.get_tuple_element %6[1] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor<3xf32> - %9 = stablehlo.get_tuple_element %6[2] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor<96xf32> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, 
dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<0x7FC00000> : tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<3x3xf32> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xf32> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<0x7FC00000> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<3xf32> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xf32> - %24 = stablehlo.constant dense<0.000000e+00> : tensor - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_sorgqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf32>, tensor<3xf32>) -> tuple, tensor, tensor<96xf32>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple, tensor, tensor<96xf32>>) -> tensor<3x3xf32> - %33 = stablehlo.get_tuple_element %31[1] : (tuple, tensor, tensor<96xf32>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple, tensor, tensor<96xf32>>) -> tensor<96xf32> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<0x7FC00000> : tensor - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<3x3xf32> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xf32> - %43 = call @triu(%18) : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %42, %43 : tensor<3x3xf32>, tensor<3x3xf32> - } - func.func private @triu(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf32> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> - return %8 : tensor<3x3xf32> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa2\x02\n\x027\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b\x1fO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x037\x0f\x17\x0f\x07\x07\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xae\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f)\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01#!\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f#\x01\x13\x0b\x05\x07\x05\x1f\x05\t\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\t\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03%\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03'\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x03/\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x031\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03
\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x035\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xc6\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf\x00lapack_sorgqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgeqrf', 'lapack_dorgqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. 
, 0.9128709291752773 , 0.40824829046386235], - [-0.447213595499958 , 0.3651483716701102 , -0.8164965809277263 ], - [-0.894427190999916 , -0.1825741858350548 , 0.40824829046386324]]), array([[-6.7082039324993694e+00, -8.0498447189992444e+00, - -9.3914855054991175e+00], - [ 0.0000000000000000e+00, 1.0954451150103341e+00, - 2.1908902300206665e+00], - [ 0.0000000000000000e+00, 0.0000000000000000e+00, - -8.8817841970012523e-16]])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]"}, tensor<3x3xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf64> - %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_dgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xf64>) -> tuple, tensor<3xf64>, tensor, tensor<96xf64>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor<3x3xf64> - %8 = stablehlo.get_tuple_element %6[1] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor<3xf64> - %9 = stablehlo.get_tuple_element %6[2] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor<96xf64> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<3x3xf64> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xf64> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<3xf64> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xf64> - %24 = stablehlo.constant dense<0.000000e+00> : tensor - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_dorgqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases 
= [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf64>, tensor<3xf64>) -> tuple, tensor, tensor<96xf64>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple, tensor, tensor<96xf64>>) -> tensor<3x3xf64> - %33 = stablehlo.get_tuple_element %31[1] : (tuple, tensor, tensor<96xf64>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple, tensor, tensor<96xf64>>) -> tensor<96xf64> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<3x3xf64> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xf64> - %43 = call @triu(%18) : (tensor<3x3xf64>) -> tensor<3x3xf64> - return %42, %43 : tensor<3x3xf64>, tensor<3x3xf64> - } - func.func private @triu(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf64> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf64> - return %8 : tensor<3x3xf64> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa2\x02\n\x027\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b/O/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x037\x0f\x17\x0f\x07\x07\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xce\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f)\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01#!\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f#\x01\x13\x0b\x05\x07\x05\x1f\x05\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x0b\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03%\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03'\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x03/\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x031\x03'\t\x
06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x035\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xc6\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf\x00lapack_dorgqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgeqrf', 'lapack_cungqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. 
+0.j, 0.91287076+0.j, 0.4082487 +0.j], - [-0.44721356-0.j, 0.36514866+0.j, -0.8164965 +0.j], - [-0.8944271 -0.j, -0.18257445+0.j, 0.40824816+0.j]], - dtype=complex64), array([[-6.7082043e+00+0.j, -8.0498438e+00+0.j, -9.3914852e+00+0.j], - [ 0.0000000e+00+0.j, 1.0954441e+00+0.j, 2.1908894e+00+0.j], - [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 7.1525574e-07+0.j]], - dtype=complex64)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3x3xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_cgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xcomplex>) -> tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %8 = stablehlo.get_tuple_element %6[1] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3xcomplex> - %9 = stablehlo.get_tuple_element %6[2] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xcomplex> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor>) -> tensor<3xcomplex> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xcomplex> - %24 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_cungqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : 
tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> tuple>, tensor, tensor<96xcomplex>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %33 = stablehlo.get_tuple_element %31[1] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xcomplex> - %43 = call @triu(%18) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> - return %42, %43 : tensor<3x3xcomplex>, tensor<3x3xcomplex> - } - func.func private @triu(%arg0: tensor<3x3xcomplex>) -> tensor<3x3xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> - return %8 : tensor<3x3xcomplex> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b/O/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xd6\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\t\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x
03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xce\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8bW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf\x00lapack_cungqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeqrf', 'lapack_zungqr'], - serialized_date=datetime.date(2023, 3, 17), + custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], + serialized_date=datetime.date(2025, 4, 2), inputs=(), expected_outputs=(array([[ 0. 
+0.j, 0.9128709291752773 +0.j, 0.40824829046386235+0.j], @@ -283,531 +37,199 @@ [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, -8.8817841970012523e-16+0.j]])), mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3x3xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_zgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xcomplex>) -> tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %8 = stablehlo.get_tuple_element %6[1] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3xcomplex> - %9 = stablehlo.get_tuple_element %6[2] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xcomplex> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor>) -> tensor<3xcomplex> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xcomplex> - %24 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_zungqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, 
dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> tuple>, tensor, tensor<96xcomplex>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %33 = stablehlo.get_tuple_element %31[1] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xcomplex> - %43 = call @triu(%18) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> - return %42, %43 : tensor<3x3xcomplex>, tensor<3x3xcomplex> - } - func.func private @triu(%arg0: tensor<3x3xcomplex>) -> tensor<3x3xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> - return %8 : tensor<3x3xcomplex> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0bOO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x16\n\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x0b\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17
\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xd2\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8dW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf\x00lapack_zungqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -data_2024_08_22 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['c128'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), - inputs=(), - expected_outputs=( - array([ - [0.0 + 0.0j, 0.9128709291752773 + 0.0j, 0.40824829046386235 + 0.0j], - [ - -0.447213595499958 - 0.0j, - 0.3651483716701102 + 0.0j, - -0.8164965809277263 + 0.0j, - ], - [ - -0.894427190999916 - 0.0j, - -0.1825741858350548 + 0.0j, - 0.40824829046386324 + 0.0j, - ], - ]), - array([ - [ - -6.7082039324993694e00 + 0.0j, - -8.0498447189992444e00 + 0.0j, - -9.3914855054991175e00 + 0.0j, - ], - [ - 0.0000000000000000e00 + 
0.0j, - 1.0954451150103341e00 + 0.0j, - 2.1908902300206665e00 + 0.0j, - ], - [ - 0.0000000000000000e00 + 0.0j, - 0.0000000000000000e00 + 0.0j, - -8.8817841970012523e-16 + 0.0j, - ], - ]), - ), - mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "result[0]"}, tensor<3x3xcomplex> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_zungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> 
tensor<3x3xcomplex> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) - return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @lapack_zungqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xcomplex>, tensor<3xcomplex>) -> tensor<3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc13) + return %4, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) - return %6 : tensor<3x3xcomplex> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = 
loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0bOO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xf2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x0
0\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x0b\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xaa\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8dW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 
0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf_ffi\x00lapack_zungqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc7\x8b)\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1fO\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03%\x17\x0b\x07\x17\x0f\x07\x0f\x07\x17\x07\x13\x13\x13\x13\x13\x13\x17\x07\x02\xd2\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f#\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x03\x17\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05\x0b)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r\'\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x19\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_zgeqrf_ffi\x00lapack_zungqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['c64'] = dict( +data_2025_04_02['c64'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 
2), inputs=(), - expected_outputs=( - array( - [ - [0.0 + 0.0j, 0.91287076 + 0.0j, 0.4082487 + 0.0j], - [-0.44721356 - 0.0j, 0.36514866 + 0.0j, -0.8164965 + 0.0j], - [-0.8944271 - 0.0j, -0.18257445 + 0.0j, 0.40824816 + 0.0j], - ], - dtype=complex64, - ), - array( - [ - [ - -6.7082043e00 + 0.0j, - -8.0498438e00 + 0.0j, - -9.3914852e00 + 0.0j, - ], - [0.0000000e00 + 0.0j, 1.0954441e00 + 0.0j, 2.1908894e00 + 0.0j], - [ - 0.0000000e00 + 0.0j, - 0.0000000e00 + 0.0j, - 7.1525574e-07 + 0.0j, - ], - ], - dtype=complex64, - ), - ), + expected_outputs=(array([[ 0. +0.j, 0.91287076+0.j, 0.4082487 +0.j], + [-0.44721356-0.j, 0.36514866+0.j, -0.8164965 +0.j], + [-0.8944271 -0.j, -0.18257445+0.j, 0.40824816+0.j]], + dtype=complex64), array([[-6.7082043e+00+0.j, -8.0498438e+00+0.j, -9.3914852e+00+0.j], + [ 0.0000000e+00+0.j, 1.0954441e+00+0.j, 2.1908894e+00+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 7.1525574e-07+0.j]], + dtype=complex64)), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "result[0]"}, tensor<3x3xcomplex> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_cungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], 
result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) - return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @lapack_cungqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xcomplex>, tensor<3xcomplex>) -> tensor<3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc13) + return %4, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant 
dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) - return %6 : tensor<3x3xcomplex> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - 
b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xb2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\t\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03
\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xa6\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8bW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf_ffi\x00lapack_cungqr\x00callee\x00' - ), + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc7\x8b)\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f/\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03%\x17\x0b\x07\x17\x0f\x07\x0f\x07\x17\x07\x13\x13\x13\x13\x13\x13\x17\x07\x02\xb2\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f#\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x03\x17\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05\t)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r\'\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x19\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_cgeqrf_ffi\x00lapack_cungqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['f32'] = dict( +data_2025_04_02['f32'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 2), inputs=(), - expected_outputs=( - array( - [ - [0.0, 0.91287076, 0.4082487], - [-0.44721356, 0.36514866, -0.8164965], - [-0.8944271, -0.18257445, 0.40824816], - ], - dtype=float32, - ), - array( - [ - [-6.7082043e00, -8.0498438e00, 
-9.3914852e00], - [0.0000000e00, 1.0954441e00, 2.1908894e00], - [0.0000000e00, 0.0000000e00, 7.1525574e-07], - ], - dtype=float32, - ), - ), + expected_outputs=(array([[ 0. , 0.91287076, 0.4082487 ], + [-0.44721356, 0.36514866, -0.8164965 ], + [-0.8944271 , -0.18257445, 0.40824816]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914852e+00], + [ 0.0000000e+00, 1.0954441e+00, 2.1908894e+00], + [ 0.0000000e+00, 0.0000000e+00, 7.1525574e-07]], dtype=float32)), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "result[0]"}, tensor<3x3xf32> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_sorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<96xf32>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) 
- %cst_10 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc10) - return %10, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc7) + %4 = stablehlo.custom_call @lapack_sorgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> tensor<3x3xf32> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc13) + return %4, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc16) - return %6 : tensor<3x3xf32> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = 
loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b\x1fO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\x8a\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\t\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05
\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\t\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 
0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf_ffi\x00lapack_sorgqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc5\x8b\'\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03#\x17\x07\x07\x17\x0f\x07\x0f\x07\x17\x13\x13\x13\x13\x13\x13\x17\x07\x02\x9a\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f!\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\t\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\t\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r%\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x17\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x19\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03#\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_sgeqrf_ffi\x00lapack_sorgqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['f64'] = dict( +data_2025_04_02['f64'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 2), inputs=(), - expected_outputs=( - array([ - [0.0, 
0.9128709291752773, 0.40824829046386235], - [-0.447213595499958, 0.3651483716701102, -0.8164965809277263], - [-0.894427190999916, -0.1825741858350548, 0.40824829046386324], - ]), - array([ - [ - -6.7082039324993694e00, - -8.0498447189992444e00, - -9.3914855054991175e00, - ], - [ - 0.0000000000000000e00, - 1.0954451150103341e00, - 2.1908902300206665e00, - ], - [ - 0.0000000000000000e00, - 0.0000000000000000e00, - -8.8817841970012523e-16, - ], - ]), - ), + expected_outputs=(array([[ 0. , 0.9128709291752773 , 0.40824829046386235], + [-0.447213595499958 , 0.3651483716701102 , -0.8164965809277263 ], + [-0.894427190999916 , -0.1825741858350548 , 0.40824829046386324]]), array([[-6.7082039324993694e+00, -8.0498447189992444e+00, + -9.3914855054991175e+00], + [ 0.0000000000000000e+00, 1.0954451150103341e+00, + 2.1908902300206665e+00], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -8.8817841970012523e-16]])), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "result[0]"}, tensor<3x3xf64> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_dorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} 
: (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf64>, tensor<3xf64>) -> (tensor<3x3xf64>, tensor, tensor<96xf64>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xf64>) -> tensor<3x3xf64> loc(#loc10) - return %10, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc7) + %4 = stablehlo.custom_call @lapack_dorgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf64>, tensor<3xf64>) -> tensor<3x3xf64> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc13) + return %4, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc16) - return %6 : 
tensor<3x3xf64> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - 
b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xaa\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x0b\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\
x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf_ffi\x00lapack_dorgqr\x00callee\x00' - ), + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc5\x8b\'\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f/\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03#\x17\x07\x07\x17\x0f\x07\x0f\x07\x17\x13\x13\x13\x13\x13\x13\x17\x07\x02\xaa\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f!\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x0b\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r%\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x17\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x19\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03#\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_dgeqrf_ffi\x00lapack_dorgqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py index 2d71308caeda..995847a03a60 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py @@ -17,435 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_06_19 = {} - - - -# Pasted from the test output (see back_compat_test.py module docstring) 
-data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 1.5410905 , -2.775912 , -2.374003 , 4.028736 ], - [-0.56933475, 1.6115232 , 0.9041465 , -0.8321383 ], - [-5.382895 , 4.734856 , 2.1972926 , 1.5553856 ], - [ 0.5109847 , -1.1969309 , 3.3766198 , -1.3678027 ]], - - [[ 2.2637439 , 3.406768 , 4.809871 , 2.8010902 ], - [-1.9981416 , -0.6599986 , 0.5138156 , 4.5982494 ], - [-2.335944 , -9.151717 , -1.0481138 , 2.272443 ], - [-8.257684 , 1.8223318 , 0.38403794, 5.0769973 ]]], - dtype=float32),), - expected_outputs=(array([[[-0.48540133 , 0.6682397 , -0.48819906 , -0.28196266 ], - [ 0.2180054 , -0.13631375 , 0.14819765 , -0.95495003 ], - [ 0.8457052 , 0.44643915 , -0.27943406 , 0.08597418 ], - [ 0.040523227, -0.57928085 , -0.8133977 , -0.03429017 ]], - - [[-0.21146733 , 0.46376425 , 0.786309 , 0.34917438 ], - [ 0.3461469 , 0.21883713 , 0.3399653 , -0.84659094 ], - [ 0.6526192 , -0.5834038 , 0.3972404 , 0.2755518 ], - [ 0.6399631 , 0.6298203 , -0.32915345 , 0.2922879 ]]], - dtype=float32), array([[ 8.551608 , 5.3574076, 2.8073738, 0.5226082], - [11.457576 , 10.041606 , 5.6716514, 1.4754109]], dtype=float32), array([[[-0.6319046 , 0.6612254 , 0.39110154 , -0.102553196], - [-0.2971051 , 0.13673358 , -0.50112 , 0.80119365 ], - [ 0.08969147 , 0.4433047 , -0.73647296 , -0.5030348 ], - [-0.7101976 , -0.5895471 , -0.23135659 , -0.30745354 ]], - - [[-0.6964344 , -0.5023085 , -0.11150039 , 0.50023323 ], - [-0.32121164 , 0.7889568 , 0.3183193 , 0.41598475 ], - [ 0.5096958 , -0.31399378 , 0.60193455 , 0.5284816 ], - [-0.3898877 , -0.16322286 , 0.7238198 , -0.5453721 ]]], - dtype=float32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xf32> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xf32> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x4x4xf32> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<268> : tensor loc(#loc2) - %6:7 = stablehlo.custom_call @lapack_sgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>, tensor<32xi32>, tensor<268xf32>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = 
stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa57\x01Q\x0f\x0b\x07\x13\x0b\x13\x13\x0f\x0b\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03U\x0fo\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b'\x0f\x17'O\x1f\x0f\x0b\x0b/\x1fOo\x01\x03\x0f\x035\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x1b\x1b\x1f\x13\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xb6\x07\x1d+-\x05\x15\x1f\x03\x03\t\x97\x05\x17\x03\x03\t\x9d\x03\x03\x03\x9f\x11\x01\x05\x05\x19\x03\x03\x03y\x03\x03\x03}\x03\x03\t\xa3\x03\x07\x1b\x0f\x1d\x0f\x11\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#Y%e'g\x11u)w\x05!\x05#\x05%\x05'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03{\x03\x03\x03\x7f\x03\x117\x819\x83;\x85=\x87?\x89A\x8bC\x8dE\x91\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x95\x03\x05K\x99M\x9b\x05=\x05?\x03\x03\t\xa1\x1f!\x01\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03[\r\x05]_ac\x1dC\x1dE\x1dG\x1dI#\x1b\x03\x07imq\r\x03Uk\x1dK\r\x03Uo\x1dM\r\x03Us\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x0c\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fQQQQQQS\x03\x03\x8f\x15\x03\x01\x19\x01\x03\x0fS\x93SSWWW\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\t\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\t\x01\t)\x05\t\x11\t\x13\x1d)\x01\t\x1b)\x03\t\x13)\x
07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03b\x08\t)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04~\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ak\x03\x05\x05\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x0f\x05\x0b\x05\x05\x15\x1d\x1f\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x15\x03\x1d\r\x07\x01I\x03+\x05\x17\x1f\x05\x07\x01\x0b\x03-\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x0b\x03%\x05\x07\x01O\x031\x03#\x07\x06\x01\x03\x0b\x07)\x11'\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x03/\x05\x07\x01\x17\x03\x19\x03-\x07\x06\x01\x03\x05\x073\x131\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x039\x05\x07\x01\x17\x03\x19\x037\x07\x06\x01\x03\x05\x07=\x15;\x0f\x04\x05\x075+?\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgesdd\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 0.3445689867809981 , 3.5114993759427104 , - 4.702602090972179 , -0.2702264758497052 ], - [ 2.209901632583705 , -2.6286702510632773 , - 4.591276599385847 , 3.4465035398844828 ], - [-1.5083742421154478 , 3.3225165204269635 , - 1.2596205557926703 , 3.524804355848018 ], - [ 1.5118969169108838 , 1.838885943509677 , - 2.818520751293422 , 3.06002540493494 ]], - - [[-2.4045510943950843 , -1.5657555633438576 , - -0.6061472334580296 , -0.23926156407779164], - [ 4.087879920053448 , -3.2507640936811715 , - -2.2556577657517476 , 6.090369998330348 ], - [ 1.1165401344486945 , 2.2134726894037247 , - 5.225178515435584 , 1.9794693474107725 ], - [-4.127878192684534 , -0.37313660200336163, - 0.7893465897510026 , -2.0315217791342848 ]]]),), - expected_outputs=(array([[[-0.5109626909166218 , -0.41744996156105785, - -0.731253241567692 , 0.1729779025790829 ], - [-0.5623501368035175 , 0.7608931604238581 , - 0.03470920608540986, 0.32186828528169453], - [-0.39585755254587435, -0.4954770291405409 , - 0.6561880513437818 , 0.4089212062978684 ], - [-0.5157288533916834 , -0.03577207859388855, - 0.18297871183094833, -0.8362194085221047 ]], - - [[-0.12124821978030875, -0.30260506534356213, - -0.5817463045715607 , -0.7451847292758064 ], - [ 0.8877417367326685 , -0.15794001239879188, - -0.3761180739267688 , 
0.2133184375808915 ], - [ 0.03055221675864994, 0.9244545314395409 , - -0.3686107533067095 , -0.09260936183071355], - [-0.44303503260363514, -0.16990864078317836, - -0.619864940232637 , 0.624994775612963 ]]]), array([[8.951386926411189 , 5.762891699811626 , 3.839104008889441 , - 1.2696468971033248 ], - [9.21500688857692 , 6.477297670883227 , 3.24626945855818 , - 0.05112101994354587]]), array([[[-0.17890276924244797 , -0.2881812520705063 , - -0.7749616998111006 , -0.5332726590950898 ], - [ 0.38712159387038353 , -0.8985113987184378 , - 0.1397618670046424 , 0.15258033445914954 ], - [-0.23140697924040152 , -0.03708202130554661 , - -0.5045854966104308 , 0.8309447696839614 ], - [-0.8744034999217865 , -0.32901938548360005 , - 0.35396957633060866 , -0.043246992182741084]], - - [[ 0.6276106632546885 , -0.26728735347872895 , - -0.22995258718774078 , 0.6941067163520401 ], - [ 0.2802931697592562 , 0.4781137804659157 , - 0.808362569504731 , 0.19847646746808023 ], - [ 0.6187014005224262 , 0.47714095343944474 , - -0.3740686697560633 , -0.49961757159793246 ], - [-0.3804591585793503 , 0.6872417290515944 , - -0.3921025301835001 , 0.47875384105714014 ]]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xf64> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xf64> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x4x4xf64> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<268> : tensor loc(#loc2) - %6:7 = stablehlo.custom_call @lapack_dgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>, tensor<32xi32>, tensor<268xf64>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant 
dense<0x7FF8000000000000> : tensor loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa57\x01Q\x0f\x0b\x07\x13\x0b\x13\x13\x0f\x0b\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03U\x0fo\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b'\x0f\x17'O\x1f\x0f\x0b\x0b//Oo\x01\x03\x0f\x035\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x1b\x1b\x1f\x13\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xc6\x07\x1d+-\x05\x15\x1f\x03\x03\t\x97\x05\x17\x03\x03\t\x9d\x03\x03\x03\x9f\x11\x01\x05\x05\x19\x03\x03\x03y\x03\x03\x03}\x03\x03\t\xa3\x03\x07\x1b\x0f\x1d\x0f\x11\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#Y%e'g\x11u)w\x05!\x05#\x05%\x05'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03{\x03\x03\x03\x7f\x03\x117\x819\x83;\x85=\x87?\x89A\x8bC\x8dE\x91\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x95\x03\x05K\x99M\x9b\x05=\x05?\x03\x03\t\xa1\x1f!\x01\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03[\r\x05]_ac\x1dC\x1dE\x1dG\x1dI#\x1b\x03\x07imq\r\x03Uk\x1dK\r\x03Uo\x1dM\r\x03Us\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x0c\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fQQQQQQS\x03\x03\x8f\x15\x03\x01\x19\x01\x03\x0fS\x93SSWWW\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\t\x01\x0b)\x05\t\x11\t\x13\x1d)\x01\t\x1b)\x03\t\x13)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03b\x08\t)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04~\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ak\x03\x05\x05\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x0f\x05\x0b\x05\x05\x15\x1d\x1f\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03
\x03\x05\x07\x01\x07\x03\x15\x03\x1d\r\x07\x01I\x03+\x05\x17\x1f\x05\x07\x01\x0b\x03-\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x0b\x03%\x05\x07\x01O\x031\x03#\x07\x06\x01\x03\x0b\x07)\x11'\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x03/\x05\x07\x01\x17\x03\x19\x03-\x07\x06\x01\x03\x05\x073\x131\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x039\x05\x07\x01\x17\x03\x19\x037\x07\x06\x01\x03\x05\x07=\x15;\x0f\x04\x05\x075+?\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgesdd\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 1.6052934 +0.45878917j, 4.587192 -4.5177283j , - 0.4177733 -1.9419309j , -2.2248359 -4.5042715j ], - [-7.083374 -8.127356j , 2.7596245 -4.991001j , - -0.52622825+5.033981j , -0.35441273-1.8215327j ], - [-0.7996552 -2.4052901j , -0.8506142 -3.164714j , - -0.3090829 +2.2020447j , 1.2367196 +2.8830793j ], - [ 1.4633094 -0.5451007j , -3.7833478 +6.6770763j , - -3.1279542 -2.2322626j , -2.1099617 -2.9661314j ]], - - [[ 1.2560439 -5.4743752j , -2.0085676 +2.0063214j , - -0.8132642 -3.4407883j , -0.17360081+0.6419895j ], - [ 2.3756726 +6.3315964j , -0.31447247-1.9387872j , - 4.6732006 -4.286903j , 1.7702469 -1.4957623j ], - [ 1.6918924 -0.52161306j, 0.49963537+4.7751374j , - -1.9243752 -4.5870543j , 2.8829405 +1.7382988j ], - [ 1.4884951 -0.44194785j, -1.3645276 -2.8733373j , - -0.39430943+2.4366508j , -0.76268387+5.2014065j ]]], - dtype=complex64),), - expected_outputs=(array([[[ 0.016725361+0.19210356j , 0.5452691 +0.5572638j , - 0.41363996 +0.18964858j , -0.26152334 -0.28195143j ], - [ 0.53678626 +0.64057267j , -0.21783225 -0.21288812j , - 0.28426644 +0.30535883j , 0.15201284 +0.10768581j ], - [ 0.21286921 +0.154735j , 0.066471666-0.25652882j , - -0.4074613 -0.10356682j , -0.11794163 -0.81844836j ], - [-0.39079374 -0.20583564j , -0.18335931 -0.4421772j , - 0.63489586 +0.19758748j , 0.038680226-0.36351213j ]], - - [[-0.3178596 +0.39032036j , -0.1273337 -0.30841744j , - 0.26394194 +0.26815224j , -0.21332254 -0.66947937j ], - [-0.39241245 -0.60790956j , -0.14006221 +0.41040683j , - -0.0830612 -0.10184447j , -0.45091942 -0.2603987j ], - [-0.36103728 +0.2876153j , -0.4965461 +0.10084368j , - -0.13752826 -0.6203828j , 0.35439825 -0.028546419j], - [ 0.062335093-0.078214265j, 0.35014474 -0.5668197j , - -0.42214075 -0.5090833j , -0.2889288 -0.15894148j ]]], - dtype=complex64), array([[15.135655 , 9.373035 
, 7.444931 , 0.41523397], - [12.316969 , 8.661011 , 5.005059 , 2.115905 ]], - dtype=float32), array([[[-0.6537865 +0.j , -0.20306697 -0.6166746j , - 0.29948467 +0.24257992j , -0.007604365+0.04945353j ], - [ 0.52712685 +0.j , -0.11291563 -0.7116954j , - -0.089219 -0.36348897j , -0.23654723 -0.08269388j ], - [-0.31538543 +0.j , -0.014410622+0.15958191j , - -0.17958623 -0.13690898j , -0.6930434 -0.58613425j ], - [-0.44185135 +0.j , 0.17604677 -0.050492246j, - -0.4213856 -0.69485146j , 0.22373371 +0.2465445j ]], - - [[-0.64551586 +0.j , 0.32932255 -0.11672116j , - -0.093527466+0.6710145j , -0.038554154+0.02716677j ], - [ 0.4241116 +0.j , 0.031135002-0.539813j , - -0.26271763 +0.22760014j , -0.63609654 -0.04817467j ], - [-0.4577485 +0.j , -0.15202768 +0.2734652j , - 0.18931003 -0.3297506j , -0.7331101 -0.10269702j ], - [ 0.44034657 +0.j , 0.29474002 +0.63307834j , - 0.31271848 +0.4216674j , -0.20595454 -0.020532424j]]], - dtype=complex64)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<264> : tensor loc(#loc2) - %6:8 = stablehlo.custom_call @lapack_cgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>, tensor<32xi32>, tensor<100xf32>, tensor<264xcomplex>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %18 
= stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b/\x1fO/o\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\x1e\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\t\x00\x00\xc0\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\t)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\
x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgesdd\x00', - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[-0.9247611722912019-1.3615157109291343j , - -1.0663457975211892+4.73170030936092j , - -1.4918732811689488-2.880861991859318j , - -1.111356346434667 -2.869701609083459j ], - [-4.71291623424314 -1.5444012898828912j , - -5.232967549101415 -0.41287816948482003j, - 0.8905737109262459+9.50245186328329j , - 4.397722119094926 -6.842005210371916j ], - [ 1.9369405063276903+2.3496014107398917j , - -1.5609345742256133+4.2102103739897805j , - 0.6596030248996742+5.195353435247212j , - 0.6315014498240328-1.2778849649354402j ], - [ 5.115159214503849 -0.8856276268773485j , - 1.3719934567460779-2.236070491368575j , - 0.4974504006612811-3.0462081956756637j , - -0.2620346712025989+4.424682727912594j ]], - - [[-1.8242711798401063-0.8543252170262536j , - -2.724527211360488 +2.256038331706666j , - -1.2777487543905157+0.976556823566376j , - 3.7438974536713223-0.4994301527847589j ], - [-0.6359051102028691+2.730662301129662j , - -1.2877728943263032+3.9124921723649053j , - -3.4618573226579894+1.7835551986994034j , - -1.4710491660152465+2.144967500163963j ], - [-3.6013691182532828+2.8182351980619034j , - 2.0045935428878803+1.1146211993017152j , - -2.332213857689336 -0.874915651404938j , - -1.5393862406530452+0.6852883119580928j ], - [-2.674897392856801 +2.0724239502976984j , - -3.349108041292141 -1.0215359152295307j , - 0.2603515088197114-1.9093411474619364j , - 5.41252457188561 +8.634368042893094j ]]]),), - expected_outputs=(array([[[-0.04173678258633362+0.10796693731538423j , - 0.6813428383170976 +0.34327979589293334j , - -0.41770229002865755+0.20028957850808823j , - -0.43443513665085287+0.034743251442636465j], - [-0.8408468609573512 -0.1326064604464803j , - -0.21674151028481228+0.015170556885426551j, - 0.17147327711152338+0.1531041615298256j , - -0.3568765623609291 +0.21904384306708768j ], - [-0.2673618144044136 +0.1379833616281103j , - 
-0.17534278352558025-0.378992615769627j , - -0.8179957069096054 -0.037506032257391624j, - 0.25392637883428526-0.009771014463849802j], - [ 0.40569239968065934-0.08297706578106905j , - -0.4321527034953765 +0.09791545663574397j , - -0.23439193826962654-0.08427130532228161j , - -0.42348296145608866+0.6251448114949291j ]], - - [[ 0.0272684373986653 +0.36312055550335454j , - 0.270297713559288 +0.1304616587162563j , - 0.04286867013923673-0.4765859417602139j , - 0.7242702256119968 +0.15420620503522459j ], - [-0.08593436615104483+0.1189990183325552j , - 0.37050286109355285-0.6240865462984536j , - 0.46902056878806025-0.34747949920770266j , - -0.31667671459632074-0.10340064369932994j ], - [-0.07914843440873574-0.033487314943774035j, - 0.4110353453489128 -0.455090805566563j , - -0.431131803930273 +0.40910871949632j , - 0.13782730102420274+0.49428280062680086j ], - [-0.7478497242333215 +0.5283836938016964j , - -0.08345894989956631+0.011807690067190268j, - -0.27178304569905287+0.056526279406748176j, - -0.09911954913441999-0.2598859654000683j ]]]), array([[16.80132997488892 , 7.744755614558116 , 5.831221808032041 , - 1.1195288361137765], - [12.39537594694893 , 8.218551160453814 , 4.683634850274079 , - 1.8820915363839188]]), array([[[ 0.35796251040556704 +0.j , - 0.40179383774178046 -0.1269359716702074j , - -0.0751486661300563 -0.6109813931761136j , - -0.23049271148274278 +0.51209309438597j ], - [-0.4682861415308549 +0.j , - -0.013958972669495105+0.4210606476774211j , - -0.6006888466394119 -0.3766516564723718j , - -0.24264518623237025 -0.20408557153193485j ], - [-0.6392945524816095 +0.j , - 0.2432388607602898 -0.6679928485374246j , - 0.18168178910997038 -0.08126854868489754j , - -0.2030612067046724 -0.07124733621915219j ], - [-0.49383540371426055 +0.j , - -0.010402968929686592+0.3734624991410737j , - 0.27994282704104956 +0.01949406216762731j , - 0.32588905219319236 +0.6569569657140543j ]], - - [[ 0.2666920370516844 +0.j , - 0.24929033811571413 +0.27271089049933883j , - -0.012922512768026735+0.16383354123801513j , - 0.07388201893235022 -0.8717175469187741j ], - [-0.6156140469162428 +0.j , - -0.33787077397020143 +0.37797154650923376j , - -0.3916043058726119 -0.2839601305776179j , - -0.2714888604157674 -0.23729034093304682j ], - [ 0.5618758038857617 +0.j , - -0.5788776267734554 -0.13833058883452312j , - -0.48995086206819644 +0.19259594116096765j , - -0.22967101640965012 -0.012926826751577613j], - [-0.48393210641613593 +0.j , - -0.1049229605428438 -0.4911419972025977j , - -0.07782239226461217 +0.6751317817750165j , - 0.11941657609231515 -0.19354808489959852j ]]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<264> : tensor loc(#loc2) - %6:8 = stablehlo.custom_call @lapack_zgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : 
tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>, tensor<32xi32>, tensor<100xf64>, tensor<264xcomplex>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b//OOo\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02N\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\x0b)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(ma
in)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgesdd\x00', - xla_call_module_version=6, -) # End paste - data_2024_08_13 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_13["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py index 12285a45b77a..8063d9f44722 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py @@ -16,11 +16,11 @@ from numpy import array, int32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_08 = dict( +data_2025_04_01 = dict( testdata_version=1, platform='cuda', custom_call_targets=['cu_lu_pivots_to_permutation'], - serialized_date=datetime.date(2024, 8, 8), + serialized_date=datetime.date(2025, 4, 1), inputs=(), expected_outputs=(array([[[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3], @@ -31,25 +31,22 @@ [0, 1, 2, 3, 4, 5, 6, 7]]], dtype=int32),), mlir_module_text=r""" module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "result"}) { %0 = stablehlo.iota dim = 0 : tensor<24xi32> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<24xi32>) -> tensor<2x3x4xi32> loc(#loc5) - %c = stablehlo.constant dense<2> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc6) - %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {permutation_size = 8 : i32}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) + %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "2"}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) return %2 : tensor<2x3x8xi32> loc(#loc) } loc(#loc) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:14) -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":348:11) -#loc4 = loc("jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]"(#loc2)) -#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]"(#loc3)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:14) 
+#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":409:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation"(#loc3)) """, - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1d\x05\x01\x03\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\xa7}\x17\x01Q\x0f\x07\x0b\x0b\x0f\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03-\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f///\x0b\x0b\x0b\x13\x0b\x0fo\x01\x05\x0b\x0f\x03\x13\x0f\x07\x1b\x07\x13\x13\x1b\x13\x07\x02Z\x04\x1d57\x1f\x05\x13\x05\x15\x11\x03\x05\x05\x17\x03\t\x0f\x11\x13\t\x15\t\x0b\x17\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1bQ\x1dW\x1fY\x0bc!e\x05!\x05#\x05%\x05'\x03\x03%g\x05)\x1d)+\x05+\x17\x05n\x055\x1d/1\x05-\x17\x05n\x05\x1d\x03\x03\x07i\x05/\x17\x05r\x05\x17\x03\x03\x07k\x03\x03\x07m\x03\x13?oASCqEQGsIuKUMQOU\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x01\x1dC\x03\x03{#\r\x03\x03[\r\x05]S_a\x1dE\x1dG\x1dI\x1dK\x1dM\x13\x0b\x01\x1f\x05\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dO\x05\x01\r\x03wy\x1dQ\x13\x07!\x1f\x131\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x0b\x1b)\x07\t\r!\x07\x1d\x11\x01\x03\t)\x03a\x07)\x07\t\r\x11\x07)\x03\r\x15\x13\x04{\x05\x01\x11\x03\r\x07\x03\x01\x05\x05\x11\x03\x19\x07\x03\r\x1d\x07\x03'#\x03\x0f\t\x06-\x03\x11\x03\x01\x03\x03\x013\x03\x05\x03\x03\x019\x03\x05\x03\x03\x01;\x03\x05\x0b\x07\x01=\x03\t\x03\x03\r\x04\x03\x03\x0b\x06\x03\x01\x05\x01\x00f\x0cS#9\x0f\x0b\x11#!\x03\x1f/!)!)#\x1f\x19\x8b\x8b\x85\x1f\x1f\x15\x1d\x15\x1b%)9\x13\ri\x15\x1f\x17\x11\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00\x00jax.result_info\x00mhlo.layout_mode\x00default\x00main\x00public\x00cu_lu_pivots_to_permutation\x00permutation_size\x00", + 
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03yQ\x15\x01+\x07\x0b\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x17\x0f\x0b\x17\x1b\x0b\x0b\x0f\x0b\x17\x03'\x0b\x0f\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0bo\x01\x05\x0b\x0f\x03\x11\x07\x1b\x13\x13\x07\x1b\x13\x07\x02\x9e\x02\x1f\x05\x11\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x13\x11\x01\x00\x05\x15\x05\x17\x05\x19\x1d\x15\x17\x05\x1b\x17\x03b\x065\x1d\x1b\x1d\x05\x1d\x17\x03b\x06\x1d\x03\x05!?#A\x05\x1f\x05!\x1d')\x05#\x17\x03f\x06\x17\x03\x01\x03\x03O#\t\x03\x033\r\x0357\x1d%\x1d'\x1d)\x1d+\x13\r\x01\r\x01\r\x03CE\x1d-\x1d/\x0b\x03\x1d1\x1d3\x05\x01\x1f\x111\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02\x1b)\x07\t\r!\x05\x11\x01\x03\x07)\x03a\x05\x1d)\x07\t\r\x11\x05)\x03\r\x13\x13\x04c\x05\x01Q\x01\x07\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x07\x11\x05B\x13\x05\x03\x0b\x07\x06\x19\x03\x0f\x03\x01\tG%\x1f\x07\x03\x07\x03\x03\x0b\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00J\x0759\x03\x05\x1f\x0f\x0b\x0f!c3)A;\x1b%)9i\x15\x1f\x17\x11\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/lu_pivots_to_permutation\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x002\x00\x00cu_lu_pivots_to_permutation\x00\x08+\t\x05#\x01\x0b+/19;\x03=\x11GIK+M-+-", xla_call_module_version=9, nr_devices=1, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py index be5c6e01f8d8..00ced41a0492 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py @@ -15,149 +15,10 @@ # ruff: noqa import datetime -from numpy import array, float32, float64, complex64, complex128 +from numpy import array, float32, complex64 -data_2023_03_18 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_18["unbatched"] = dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_geqrf', 'cusolver_orgqr'], - serialized_date=datetime.date(2023, 3, 18), - inputs=(), - expected_outputs=(array([[ 0. 
, 0.9128705 , 0.40824863], - [-0.44721356, 0.36514878, -0.8164964 ], - [-0.8944271 , -0.18257457, 0.40824813]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914843e+00], - [ 0.0000000e+00, 1.0954436e+00, 2.1908882e+00], - [ 0.0000000e+00, 0.0000000e+00, 5.6703755e-08]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> - %2 = stablehlo.custom_call @cusolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\00\03\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> tuple, tensor<3xf32>, tensor, tensor<196608xf32>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<196608xf32> - %7 = stablehlo.constant dense<0> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor - %9 = stablehlo.compare EQ, %5, %8, SIGNED : (tensor, tensor) -> tensor - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> - %11 = stablehlo.constant dense<0x7FC00000> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<3x3xf32> - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %14 = stablehlo.select %13, %3, %12 : tensor<3x3xi1>, tensor<3x3xf32> - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> - %16 = stablehlo.constant dense<0x7FC00000> : tensor - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<3xf32> - %18 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %19 = stablehlo.select %18, %4, %17 : tensor<3xi1>, tensor<3xf32> - %20 = stablehlo.constant dense<0.000000e+00> : tensor - %21 = stablehlo.pad %14, %20, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - %22 = stablehlo.custom_call @cusolver_orgqr(%21, %19) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> tuple, tensor, tensor<33056xf32>> - %23 = stablehlo.get_tuple_element %22[0] : (tuple, tensor, tensor<33056xf32>>) -> tensor<3x3xf32> - %24 = stablehlo.get_tuple_element %22[1] : (tuple, tensor, tensor<33056xf32>>) -> tensor - %25 = stablehlo.get_tuple_element %22[2] : (tuple, tensor, tensor<33056xf32>>) -> tensor<33056xf32> - %26 = stablehlo.constant dense<0> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor - %28 = stablehlo.compare EQ, %24, %27, SIGNED 
: (tensor, tensor) -> tensor - %29 = stablehlo.broadcast_in_dim %28, dims = [] : (tensor) -> tensor<1x1xi1> - %30 = stablehlo.constant dense<0x7FC00000> : tensor - %31 = stablehlo.broadcast_in_dim %30, dims = [] : (tensor) -> tensor<3x3xf32> - %32 = stablehlo.broadcast_in_dim %29, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %33 = stablehlo.select %32, %23, %31 : tensor<3x3xi1>, tensor<3x3xf32> - %34 = call @triu(%14) : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %33, %34 : tensor<3x3xf32>, tensor<3x3xf32> - } - func.func private @triu(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf32> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> - return %8 : tensor<3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03~\x02\xf79\x01\x99\x0f\x0f\x17\x13\x0f\x07\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03_O/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x039\x17\x0f\x0f\x07\x07\x07\x07\x17\x13\x17\x07\x1b\x0f\x17\x13\x1b\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x06\t\x1d{\x05\x1d\x93\x05\x17\x1f\n\x06\x01\x03\x03\x13\xcb\x1dS\x05\x1f\x05!\x05#\x05%\x05'\x03\x03\r\xe9\x05)\x05+\x05-\x05/\x051\x03\x03#\xc7\x053\x1d[\x05\x055\x057\x03\x03\r\xd1\x17\x1f\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x0f\xdd\x03\x03\x0f\xdf\x03\x03\x0f\xe1\x03\x03\r\xe5\x03\x05'\xa7)\xe7\x03\x03\x13\xeb\x03\x03\x11M\x05I\x03\x0b\x17\x9d\x19\xb1\x1b\xb3\x11\xbd\x1d\xbf\x03\x0b\x17\xa3\x19\xc3\x1b\xa3\x11\xa5\x1d\xc5\x05K\x1dW\x05\x05M\x03\x03\r\xc9\x05O\x03\x03#\xcd\x1da\x05\x05Q\x03\x05'\xa7)\xcf\x1dg\x05\x05S\x1dk\x05\x05U\x1do\x05\x05W\x1ds-\x05Y\x1dw-\x05[\x03\x11/\xa91\xd33\xd55\x9d7\xab9\xd7;\xad=\xdb\x05]\x03\x03\x0f\xe3\x03\x03\x13\xed\x1d\x83\x05\x05_\x03\x07\x87\x9f\x89\x9f\x8b\x9f\x05a\x05c\x05e\x1d\x8f\x05\x05g\x03\x11/\xa91\xef3\xf15\x9d7\xab9\xf3;\xad=\xf5\x05i\x03\x03\x97\xa5\x05k\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dm\x03\x03\xc1\x1do\t\x07\x0b\x05\x05\x01\x03\x03\xd9\x1f/\x01#!\x03\x05\xb5\xb9\r\x03\xa1\xb7\x1dq\r\x03\xa1\xbb\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x03\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05\t\x00\x00\x00\x00\x1d{\x1d}\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xaf\x9b\x13\t\x01\x13\t\x05\x13\t\t\x13\t\r\x1f\x03\t\x00\x00\x00\x00\x07\x01\x1f\x05\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x7f\x1d\x81\x03\x05\x99\x9b\x03\x07\x99\xaf\x9b)\x05\r\r\x07)\x01\t)\x01\x07
\t\x1b\x1d\x01)\x05\r\r\t)\x03\r\x07)\x05\r\r\r\x13)\x03\x04\x000\x07)\x01\r)\x05\x05\x05\r)\x03\t\x0b)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x0b)\x03%\x07/\t\x01\x11\x03\x17)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x01\x03\x1f\x04\xe6\x05\x05\x01\x11\x0bK\x07\x03\x01\t\x0f\x11\x0bO\x05\x03G\x91\x0b\x03q!\x03'\x17\x06u\x03\x01\x03\x01\x13\x07\x01y\x03)\x03\x03\x07\x07\x01?\x03\x01\x03\x05\x07\x07\x01A\x03\x11\x03\x05\x07\x07\x01C\x03\x03\x03\x05\x07\x07\x01}\x03\x17\x03\x05\x05\x03\x01E\x03\x03\x03\x07\x01\x07\x03\x03\x03\x0f\r\x07\x01G\x03\x19\x05\x0b\x11\x03\x07\x01\x07\x03\x1b\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x01\x03\x17\x03\x07\x01I\x03\x13\x03\x15\t\x06\x01\x03\x01\x07\x1b\x07\x19\x03\x07\x01\x07\x031\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x11\x03!\x03\x07\x01\x7f\x033\x03\x1f\t\x06\x01\x03\x11\x07%\t#\x05\x03\x81+\x03\x05\x19\x07\x8d\x85\x03\x01\x05\x1d)\x13\x07\x03\x91\x037\x05+'\x07\x07\x03?\x03\x01\x03-\x07\x07\x03A\x03\x03\x03-\x07\x07\x03C\x03\x1f\x03-\x05\x03\x03E\x03\x03\x03\x07\x03\x07\x03\x03\x035\r\x07\x03G\x03\x19\x0517\x03\x07\x03\x07\x03\x1b\x039\x05\x03\x03\x15\x03\x05\x03\x07\x03\x07\x03\x01\x03=\x03\x07\x03I\x03\x13\x03;\t\x06\x03\x03\x01\x07A/?\x1b\x07\t\x95\x03\x01\x03\x1d\x11\x04\x0b\x05CE\x0f\x11\tQ\x05\x03\x15+\x03\x01\x0b\x0b\x03U!\x03\x0f\x05\x03\tY\x03\x03\x03\x07%\x07\x03\x0f\x03\x05\x15\x06%\x03\x0f\x05\x03\x07\x0b\x03_]\x03\x0f\r\x07ec\x03\x13\x05\t\x0b\x05\x03\t+\x03\x05\x03\x07i\x07\x03\x01\x03\x0f\t\x06m\x03\x01\x07\r\x11\x01\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\x86\x19\x83\x1f3\x1f+\x11\x0f\x0b\t\t\x0b!\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 
0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x00\x03\x00\x00cusolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_18["batched"] = dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cublas_geqrf_batched', 'cusolver_orgqr'], - serialized_date=datetime.date(2023, 3, 18), - inputs=(), - expected_outputs=(array([[[ 0. , 0.91287094, 0.40824836], - [-0.4472136 , 0.36514843, -0.81649655], - [-0.8944272 , -0.18257417, 0.4082483 ]], - - [[-0.42426407, 0.80828977, 0.40824953], - [-0.5656854 , 0.11547142, -0.8164964 ], - [-0.7071068 , -0.5773508 , 0.4082474 ]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 4.8374091e-08]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607319e+01], - [ 0.0000000e+00, 3.4641042e-01, 6.9282258e-01], - [ 0.0000000e+00, 0.0000000e+00, 1.4548683e-06]]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]"}, tensor<2x3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> - %2 = stablehlo.custom_call @cublas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.pad %3, %7, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.custom_call @cusolver_orgqr(%8, %4) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> tuple, tensor<2xi32>, tensor<33056xf32>> - %10 = stablehlo.get_tuple_element %9[0] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2x3x3xf32> - %11 = stablehlo.get_tuple_element %9[1] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2xi32> - %12 = stablehlo.get_tuple_element %9[2] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> 
tensor<33056xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<2xi32> - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %16 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<2x3x3xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> - %20 = stablehlo.select %19, %10, %18 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - %21 = call @triu(%3) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> - return %20, %21 : tensor<2x3x3xf32>, tensor<2x3x3xf32> - } - func.func private @triu(%arg0: tensor<2x3x3xf32>) -> tensor<2x3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.broadcast_in_dim %5, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.select %6, %8, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - return %9 : tensor<2x3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff=\x01\x9f\x17\x0f\x0f\x0f\x07\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03ao/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x03=\x1b\x07\x07\x07\x0f\x17\x0f\x07\x13\x07\x13\x1b\x17\x13\x1b\x17\x17\x13\x17\x13\x13\x1b\x07\x13\x13\x13\x17\x13\x1b\x13\x02\x1a\n\x17\x1d\n\x06\x01\x1d\x8f\x01\x1dK\x01\x1dy\x01\x1f\x05!\x03\x03\x0f\xd1\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x03\x03!\xcd\x053\x1dS\x01\x055\x057\x03\x03\x0b\xd9\x17\x1d\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\x13E\x05I\x03\x0b\x15\xa3\x17\xb7\x19\xb9\x13\xc3\x1b\xc5\x03\x0b\x15\xa9\x17\xc9\x19\xa9\x13\xab\x1b\xcb\x05K\x1dO\x01\x05M\x03\x03\x0b\xcf\x05O\x03\x03!\xd3\x1dY\x01\x05Q\x03\x05%\xad'\xd5\x1d_\x01\x05S\x03\x03\x0f\xd7\x1de\x01\x05U\x1di\x01\x05W\x1dm\x01\x05Y\x1dq+\x05[\x1du+\x05]\x03\x11-\xaf/\xdb1\xdd3\xa35\xb17\xdf9\xb3;\xe3\x05_\x03\x03\x11\xeb\x1d\x7f\x01\x05a\x03\x07\x83\xa5\x85\xa5\x87\xa5\x05c\x05e\x05g\x1d\x8b\x01\x05i\x03\x11-\xaf/\xed1\xef3\xa35\xb17\xf19\xb3;\xf3\x05k\x03\x03\x0b\xf5\x03\x05%\xad'\xf7\x03\x03\x0f\xf9\x03\x03\x0b\xfb\x03\x03\x0f\xfd\x03\x03\x9d\xab\x05m\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1do\x03\x03\xc7\x1dq\t\x07\x0b\x05\x05\x01\x03\x03\xe1\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbb\xbf\r\x03\xa7\xbd
\x1ds\r\x03\xa7\xc1\x1du\x1dw\x1dy\r\x01#!\x1d{\x13\x05\x01\x1f\r\t\xff\xff\xff\xff\x1f#\x01\x13\x05\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\x00\x00\x1d}\x1d\x7f\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb5\xa1\xa1\x13\x03\x01\x13\x03\x05\x13\x03\t\x13\x03\r\x1d\x81\x1d\x83\x03\x05\x9f\xb5\x03\x07\x9f\xa1\xa1\x1f\r\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00)\x07\t\r\r\x07\x1b\x1d\t)\x01\x07)\x05\r\r\x03)\x01\x03\x01)\x03A-\x13)\x03\t\x03)\x07\t\r\r\x0f)\x05\t\r\x07)\x03\r\x05)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x05)\x05\r\r\x0f)\x03\t\x05)\x03I\x07/\t\x01\x19\x11\x11\x17)\x03\r\x13)\x03\t\x13)\x03\x05\x13/\x07\x01\x15\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f)\x03\x05\x05\x04r\x04\x05\x01\x11\tC\x07\x03\x01\t\x0b\x11\tG\x05\x03-]\t\x03o\x1f\x03)\x17\x06s\x03\x01\x03\x01\x13\x07\x07w\x03+\x03\x03\x05\x07\x07=\x03\x01\x03\x05\x05\x07\x07?\x03\x19\x03\x05\x05\x07\x07A\x03\x11\x03\x05\x05\x07\x07{\x03\x11\x03\x05\x07\x03})\x03\t\x19\x07\x89\x81\x03\x01\x05\x07\x0f\x13\x07\x03\x8d\x035\x05\x11\t\x05\x07\x03=\x03\x01\x03\x13\x05\x07\x03?\x03\x15\x03\x13\x05\x07\x03A\x03\x1d\x03\x13\x07\x03\x03\x91\x03\r\x03\x07\x03\r\x03\x15\x03\x1b\r\x07\x03\x93\x037\x05\x17\x1d\x03\x07\x03\x95\x039\x03\x1f\x07\x03\x03\x97\x03\t\x03\x07\x03\r\x03\x01\x03#\x03\x07\x03\x99\x03\x17\x03!\x0f\x06\x03\x03\x01\x07'\x15%\x1b\x07\x05\x9b\x03\x01\x03\x07\x11\x04\t\x05)+\x0b\x11\x05I\x05\x03\x17/\x03\x01\t\t\x03M\x1f\x03\x0b\x07\x03\x05Q\x03\r\x03\x07#\r\x03\x0b\x03\x05\x15\x06#\x03\x0b\x05\x03\x07\t\x03WU\x03\x0b\r\x07][\x03%\x05\t\x0b\x03\x07ca\x03\x17\x03\r\x07\x03\x05)\x03\t\x03\x07g\r\x03\x01\x03\x11\x0f\x06k\x03\x01\x07\x0f\x13\x01\x11\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00Z\x1b\x85\x1f3+#\x11\x0f\x0b\t\t\x0b!\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15\x13\r+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19+)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00index\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) 
dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00cublas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste data_2024_09_26 = {} - data_2024_09_26["f32"] = dict( testdata_version=1, platform='cuda', diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py deleted file mode 100644 index bd5fa628741e..000000000000 --- a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -from numpy import array, float32 - -data_2024_08_05 = {} - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["unbatched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipsolver_geqrf', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[ 0. 
, 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], dtype=float32), array([[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipsolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\01\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>, tensor, tensor<256xf32>) loc(#loc5) - %c = stablehlo.constant dense<0> : tensor loc(#loc5) - %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc5) - %4 = stablehlo.compare EQ, %2#2, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc5) - %7 = stablehlo.broadcast_in_dim %5, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %8 = stablehlo.select %7, %2#0, %6 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc5) - %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<3xf32> loc(#loc5) - %11 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> loc(#loc5) - %12 = stablehlo.select %11, %2#1, %10 : tensor<3xi1>, tensor<3xf32> loc(#loc5) - %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %13 = stablehlo.pad %8, %cst_1, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc7) - %14:3 = stablehlo.custom_call @hipsolver_orgqr(%13, %12) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<128xf32>) loc(#loc8) - %c_2 = stablehlo.constant dense<0> : tensor loc(#loc8) - %15 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor loc(#loc8) - %16 = stablehlo.compare EQ, %14#1, %15, SIGNED 
: (tensor, tensor) -> tensor loc(#loc8) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc8) - %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %18 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc8) - %19 = stablehlo.broadcast_in_dim %17, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc8) - %20 = stablehlo.select %19, %14#0, %18 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc8) - %21 = call @triu(%8) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc9) - return %20, %21 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc14) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc15) - return %6 : tensor<3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03~\x02\xf39\x01\x99\x0f\x17\x13\x0f\x0f\x0b\x0b\x07\x0b\x13\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03[O/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x01\x05\x0b\x0f\x035\x17\x0f\x0f\x07\x07\x07\x17\x17\x13\x07\x07\x0f\x17\x13\x17\x17\x13\x13\x17\x13\x13\x13\x13\x13\x13\x17\x02\xde\x08\x1d}\x03\x17\x1fj\x05\x01\x03\x03\x11\xcf\x1d\x93\x03\x1dU\x03\x05\x1f\x05!\x1f\x05#\x03\x03\x0b\xe5\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03#\xcb\x05/\x1d]\x03\x051\x053\x03\x03\x0b\xd5\x17\x1ff\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03\x0b\xe1\x03\x05'\xab)\xe3\x03\x03\x11\xe7\x03\tGIK\x15M\x15\rO\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x17\x9d\x19\xb5\x1b\xb7\r\xc1\x1d\xc3\x03\x0b\x17\xa7\x19\xc7\x1b\xa7\r\xa9\x1d\xc9\x05M\x1dY\x03\x05O\x03\x03\x0b\xcd\x05Q\x03\x03#\xd1\x1dc\x03\x05S\x03\x05'\xab)\xd3\x1di\x03\x05U\x1dm\x03\x05W\x1dq\x03\x05Y\x1du-\x05[\x1dy-\x05]\x03\x11/\xad1\xd73\xd95\x9d7\xaf9\xdb;\xb1=\xdf\x05_\x03\x03\x11\xe9\x1d\x83\x03\x05a\x03\x07\x87\xa3\x89\xa3\x8b\xa3\x05c\x05e\x05g\x1d\x8f\x03\x05i\x03\x11/\xad1\xeb3\xed5\x9d7\xaf9\xef;\xb1=\xf1\x05k\x03\x03\x97\xa9\x05m\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc5\x1du\t\x07\x0b\x05\x05\x01\x03\x03\xdd\x1f/\x01#!\x03\x05\xb9\xbd\r\x05\xa5\xbb\x9f\xa1\x1dw\r\x05\xa5\xbf\x9f\xa1\x1dy\x1d{\x1d}\r\x03\x9f\xa1##\x1d\x7f\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f%\x01\x13\r\x05\x07\x05\x1f\t\t\x00\x00\x00\x00\x1d\x81\x1d\x83\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xb3\x9b\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\t\t\x00\x00\xc0\x7f\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x85\x1d\x87\x03\x05\x99\x9b\x03\x07\x99\xb3\x9b\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x19)\x01\x0b\t\x1d\x01)\x05\r\r\x19)\x05\r\r\x0f)\x03\r\x0b\x13\x1b)\x01\x0f)\x05\x05\x05\x0f)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\x02\x08\x0b)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x05\x0f)\x03\r\x0f)\x03\x05\r)\x03\x02\x04\x0b\x04\x1a\x05\x05\x01\x11\x0fE\x07\x03\x01\t\r\x11\x0fQ\x07\x03Cu\t\x03s!\x03'\x15\x06w\x03\x05\x03\x01\x11\x07\x01{\t\x05\x15\x07)\x03\x03\x05\x03\x01?\x03\x07\x03\x07\x01\x05\x03\x07\x03\r\x0b\x07\x01A\x03\x1b\x05\t\x0f\x03\x07\x01\x05\x03\x1d\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x05\x03\x15\x03\x07\x01C\x03\x13\x03\x13\x07\x06\x01\x03\x05\x07\x19\x05\x17\x03\x07\x01\x05\x031\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x15\x03\x1f\x03\x07\x01\x7f\x033\x03\x1d\x07\x06\x01\x03\x15\x07#\x07!\x05\x03\x81+\x03\t\x17\x07\x8d\x85\x03\x05\x05\x1b'\x11\x07\x07\x91\x07\x05\x077\x05)%\x05\x03\x07?\x03\x07\x03\x07\x07\x05\x03\x07\x031\x0b\x07\x07A\x03\x1b\x05-3\x03\x07\x07\x05\x03\x1d\x035\x05\x03\x07\x13\x03\t\x03\x07\x07\x05\x03\x05\x039\x03\x07\x07C\x03\x13\x037\x07\x06\x07\x03\x05\x07=+;\x19\x07\t\x95\x03\x05\x03\x1b\x0f\x04\x0f\x05?A\r\x11\tS\x07\x03\x15+\x03\x05\t\t\x03W!\x03\x11\x05\x03\t[\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x
13\x06%\x03\x11\x05\x03\x07\t\x03a_\x03\x11\x0b\x07ge\x03\x13\x05\t\x0b\x05\x03\t+\x03\t\x03\x07k\x05\x03\x05\x03\x0f\x07\x06o\x03\x05\x07\r\x11\x01\x0f\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xea\x1a\x89!3!+\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15+\x13\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x01\x00\x00\x00hipsolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["batched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipblas_geqrf_batched', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[[ 0. 
, 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], - - [[-0.42426407, 0.8082888 , 0.4082513 ], - [-0.5656854 , 0.11547317, -0.81649613], - [-0.7071068 , -0.5773518 , 0.40824607]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607313e+01], - [ 0.0000000e+00, 3.4641036e-01, 6.9281983e-01], - [ 0.0000000e+00, 0.0000000e+00, 8.3555670e-07]]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipblas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> (tensor<2x3x3xf32>, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>) loc(#loc5) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> loc(#loc7) - %4:3 = stablehlo.custom_call @hipsolver_orgqr(%3, %2#1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> (tensor<2x3x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc8) - %c = stablehlo.constant dense<0> : tensor loc(#loc8) - %5 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc8) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc8) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc8) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc8) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> loc(#loc8) - %10 = stablehlo.select %9, %4#0, %8 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc8) - %11 = call @triu(%2#0) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> loc(#loc9) - return %10, %11 : tensor<2x3x3xf32>, tensor<2x3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: 
tensor<2x3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<2x3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %5 = stablehlo.broadcast_in_dim %4, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc15) - %7 = stablehlo.select %5, %6, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc16) - return %7 : tensor<2x3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\x96\x02\xfb=\x01\x9f\x17\x0f\x0f\x0b\x13\x0b\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0f\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03]o/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x01\x05\x0b\x0f\x039\x1b\x07\x07\x0f\x17\x0f\x07\x07\x07\x1b\x13\x13\x13\x17\x17\x13\x17\x13\x13\x17\x07\x13\x13\x13\x17\x13\x1b\x13\x02\xf6\t\x17\x1bj\x05\x01\x1d\x8f\x01\x1dK\x01\x05\x1f\x03\x03\x0b\xd5\x05!\x05#\x1f\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03\x1f\xd1\x05/\x1dS\x01\x051\x053\x03\x03\x07\xdd\x17\x1bf\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\t=?A\x11C\x11\rE\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x13\xa3\x15\xbb\x17\xbd\r\xc7\x19\xc9\x03\x0b\x13\xad\x15\xcd\x17\xad\r\xaf\x19\xcf\x05M\x1dO\x01\x05O\x03\x03\x07\xd3\x05Q\x03\x03\x1f\xd7\x1dY\x01\x05S\x03\x05#\xb1%\xd9\x1d_\x01\x05U\x03\x03\x0b\xdb\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq)\x05]\x1du)\x05_\x03\x11+\xb3-\xdf/\xe11\xa33\xb55\xe37\xb79\xe7\x1d{\x01\x05a\x1d\x7f\x01\x05c\x03\x07\x83\xa9\x85\xa9\x87\xa9\x05e\x05g\x05i\x1d\x8b\x01\x05k\x03\x11+\xb3-\xe9/\xeb1\xa33\xb55\xed7\xb79\xef\x05m\x03\x03\x07\xf1\x03\x05#\xb1%\xf3\x03\x03\x0b\xf5\x03\x03\x07\xf7\x03\x03\x0b\xf9\x03\x03\x9d\xaf\x05o\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dq\x1ds\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1du\x03\x03\xcb\x1dw\t\x07\x0b\x05\x05\x01\x03\x03\xe5\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbf\xc3\r\x05\xab\xc1\xa5\xa7\x1dy\r\x05\xab\xc5\xa5\xa7\x1d{\x1d}\x1d\x7f\r\x03\xa5\xa7#!\x1d\x81\x13\x07\x01\x1f\x0f\t\xff\xff\xff\xff\x1f#\x01\x13\x07\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\x00\x00\x1d\x83\x1d\x85\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb9\xa1\xa1\x1d\x87\x1d\x89\x03\x05\x9f\xb9\x03\x07\x9f\xa1\xa1\x1f\x0f\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\t)\x01\t)\x05\r\r\x13)\x01\x13\x01\x1b\x13)\x07\t\r\r\x11)\x03A-)\x03\r\x07)\x03\t\x13\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\x07)\x05\r\r\x11)\x03\t\x07)\x03I\t)\x05\t\r\t\x17)\x03\r\x15)\x03\t\x15)\x03\x05\x15)\x03\x02\x04\t)\x03\t\x11)\x07\t\x05\x05\x11)\x03\x05\x07\x04\xa6\x03\x05\x01\x11\x0f;\x07\x03\x01\t\t\x11\x0fG\x07\x03)A\x07\x03o\x1d\x03)\x15\x06s\x03\x05\x03\x01\x11\x07yw\t\x05+\x19\x19\x03\x03\x05\x03}'\x03\x0b\x17\x07\x89\x81\x03\x05\x05\x05\r\x11\x07\x03\x8d\x07\x05\x1d5\x05\x0f\x07\x05\x03\x03\x91\x03\x0f\x03\x07\x03\t\x03\x1d\x03\x17\x0b\x07\x03\x93\x037\x05\x13\x19\x03\x07\x03\x95\x039\x03\x1b\x05\x03\x03\x97\x03\x0b\x03\x07\x03\t\x03\x05\x03\x1f\x03\x07\x03\x99\x03\x17\x03\x1d\r\x06\x03\x03\x05\x07#\x11!\x19\x07\x05\x9b\x03\x05\x03\x05\x0f\x04\x0f\x05%'\t\x11\x05I\x07\x03\x17/\x03\x05\x05\x07\x03M\x1d\x03\r\x05\x03\x05Q\x03\x0f\x03\x07!\t\x03\r\x03\x05\x13\x06!\x03\r\x05\x03\x07\x07\x03WU\x03\r\x0b\x07][\x03%\x05\t\x0b\
x03\x07ca\x03\x17\x03\r\x05\x03\x05'\x03\x0b\x03\x07g\t\x03\x05\x03\x11\r\x06k\x03\x05\x07\x0f\x13\x01\x0f\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00\xbe\x1c\x8b!3-#\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15\x13+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00hipblas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 5d5e95b5cb9a..b86b24e2b4fc 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -90,6 +90,7 @@ def func(...): ... from jax.experimental import pjit from jax._src import core +from jax._src import stages from jax._src import test_util as jtu from jax._src import xla_bridge as xb @@ -165,7 +166,8 @@ def load_testdata_nested(self, testdata_nest) -> Iterable[CompatTestData]: else: assert False, testdata_nest - def run_one_test(self, func: Callable[..., jax.Array], + def run_one_test(self, + func: Callable[..., jax.Array] | stages.Wrapped, data: CompatTestData, polymorphic_shapes: Sequence[str] | None = None, rtol: float | None = None, @@ -176,7 +178,8 @@ def run_one_test(self, func: Callable[..., jax.Array], """Run one compatibility test. Args: - func: the JAX function to serialize and run + func: the JAX function to serialize and run, either as a Python Callable + or as a `jax.jit(callable)`. 
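# Illustrative sketch, not part of the patch (jnp = jax.numpy): `jax.jit`
# returns a `stages.Wrapped`, so the harness can be passed either plain or
# pre-jitted; run_one_test only wraps plain callables in jax.jit itself, e.g.
#
#   def harness(x):
#     return jnp.linalg.qr(x, mode='reduced')
#
#   self.run_one_test(harness, data)             # plain Callable
#   self.run_one_test(jax.jit(harness), data)    # already a stages.Wrapped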
data: the test data polymorphic_shapes: when using shape polymorphism, the specification for each argument of `func`. @@ -269,19 +272,22 @@ def run_one_test(self, func: Callable[..., jax.Array], expect_current_custom_calls = data.custom_call_targets self.assertItemsEqual(expect_current_custom_calls, current_custom_call_targets) - def run_current(self, func: Callable, data: CompatTestData): + def run_current(self, + func: Callable | stages.Wrapped, + data: CompatTestData): """Lowers and runs the test function at the current JAX version.""" - return jax.jit(func)(*data.inputs) + jit_func = func if isinstance(func, stages.Wrapped) else jax.jit(func) + return jit_func(*data.inputs) def serialize(self, - func: Callable, data: CompatTestData, *, + func: Callable | stages.Wrapped, data: CompatTestData, *, polymorphic_shapes: Sequence[str] | None = None, allow_unstable_custom_call_targets: Sequence[str] = () ) -> tuple[bytes, str, int, int]: """Serializes the test function. Args: - func: the function to serialize + func: the function to serialize. polymorphic_shapes: the polymorphic_shapes to use for serialization allow_unstable_custom_call_targets: whether to allow additional custom call targets besides those known as stable. @@ -292,8 +298,9 @@ def serialize(self, """ # Use the native exporter, to make sure we get the proper serialization. args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes) + jit_func = func if isinstance(func, stages.Wrapped) else jax.jit(func) exported = export.export( - jax.jit(func), + jit_func, platforms=(self.default_jax_backend(),), disabled_checks=tuple( export.DisabledSafetyCheck.custom_call(target) diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index 4e28791e9cee..767b41dc8ba0 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -304,7 +304,7 @@ def lax_ops(): float_dtypes, test_util.rand_uniform, { - np.float32: 1e-5, + np.float32: 2e-5, np.float64: 1e-12, }, ), diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 48c645c4d033..b557434ac7f3 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -408,11 +408,11 @@ def parameterized(harnesses: Iterable[Harness], ############################################################################### -def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype): +def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype, **kwargs): define( str(prim), f"shape={jtu.format_shape_dtype_string(shape, dtype)}", - prim.bind, [RandArg(shape, dtype)], + lambda x: prim.bind(x, **kwargs), [RandArg(shape, dtype)], prim=prim, dtype=dtype, shape=shape) @@ -429,19 +429,19 @@ def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype): _make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype, accuracy=None) _make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype) - 
_make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype, accuracy=None) _make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype, accuracy=None) for dtype in jtu.dtypes.all_floating: _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype) @@ -3375,8 +3375,9 @@ def _make_conv_harness(name, define( lax.rng_bit_generator_p, f"{key_dtype=}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{algorithm=}", - lambda key, shape, dtype, algorithm: lax.rng_bit_generator(key, shape, dtype=dtype, - algorithm=algorithm), + lambda key, shape, dtype, algorithm, out_sharding=None: lax.rng_bit_generator( + key, shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), [RandArg(key_shape, key_dtype), StaticArg(shape), StaticArg(dtype), StaticArg(algorithm)], shape=shape, diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ddf96af6a010..98cda2df4964 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -194,6 +194,11 @@ def new_arg(trace, primal_aval, nz): tangent_trace.invalidate() if attrs_tracked: raise NotImplementedError("TODO: attrs") + tangent_jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + tangent_jaxpr, [True] * len(tangent_jaxpr.outvars), + [False] * len(tangent_jaxpr.constvars) + [True] * len(tangent_jaxpr.invars)) + tangent_consts = [c for c, used in zip(tangent_consts, used_consts) if used] + residuals_and_primals = (*tangent_consts, *out_primals) residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) @@ -407,6 +412,12 @@ def write_primal(v, val): try: cts_out = get_primitive_transpose(eqn.primitive)( cts_in, *invals, **eqn.params) + except core.ShardingTypeError as e: + extra_msg = ("This is a potential JAX bug. 
Please file an issue at" + " https://github.com/jax-ml/jax/issues") + if extra_msg in str(e): + raise + raise core.ShardingTypeError(f"{str(e)}\n{extra_msg}") except (FloatingPointError, ZeroDivisionError) as e: msg = "When differentiating the code at the top of the callstack:" if msg not in e.args[0]: @@ -865,6 +876,10 @@ def make_zero(aval): for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info) + jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr, [True] * len(jaxpr.outvars), + [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) + out_consts = [c for used, c in zip(used_consts, out_consts) if used] def linearized(residuals, *tangents): nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] @@ -1238,3 +1253,11 @@ def __init__(self): # TODO(mattjj): remove this vestigial dict reducing_transposes: dict[core.Primitive, Callable] = {} + +########################### pvary ################################## + +def _pvary_transpose_rule(cts, *_, axes, axis_index_groups): + from jax._src.lax import parallel as lax_parallel + return lax_parallel.psum_invariant_p.bind( + *cts, axes=axes, axis_index_groups=axis_index_groups) +deflinear2(core.pvary_p, _pvary_transpose_rule) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 03c9a95105d7..c97c8d558608 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -29,9 +29,7 @@ from jax._src.partition_spec import PartitionSpec as P from jax._src.sharding_impls import NamedSharding from jax._src import mesh as mesh_lib -from jax._src.ad_util import (Zero, instantiate, SymbolicZero, - replace_rule_output_symbolic_zeros, - add_jaxvals, add_jaxvals_p) +from jax._src.ad_util import Zero, SymbolicZero, add_jaxvals, add_jaxvals_p from jax._src.core import Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_unflatten, tree_flatten, @@ -408,6 +406,10 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, @property def aval(self): aval = core.get_aval(self.val) + if self._trace.axis_data.spmd_name is not None: + if config._check_rep.value: + aval = aval.update( + vma=aval.vma - frozenset(self._trace.axis_data.spmd_name)) if self.batch_dim is not_mapped: return aval elif type(self.batch_dim) is int: @@ -565,12 +567,9 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims) - out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp) + tuple(in_vals), + out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp, *in_vals), dict(symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) - if not fst: - assert out_dims == out_dims[:len(out_dims) // 2] * 2 - out_dims = out_dims[:len(out_dims) // 2] src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] @@ -616,17 +615,15 @@ def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals): trace = BatchTrace(parent_trace, tag, axis_data) idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0, source_info_util.current())) - in_tracers = 
map(partial(to_elt, trace, idx), in_vals, in_dims) + with core.set_current_trace(parent_trace): + in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), core.add_spmd_axis_names(axis_data.spmd_name)): outs = f(*in_tracers) - - out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests - out_vals = map(partial(from_elt, trace, axis_data.size, - axis_data.explicit_mesh_axis), - range(len(outs)), outs, out_dim_dests) - + out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests + out_vals = map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis), + range(len(outs)), outs, out_dim_dests) return out_vals, trace # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. @@ -771,10 +768,17 @@ def _batch_jaxpr2( handle_ragged(closed_jaxpr.in_avals, dim, aval) if isinstance(dim, RaggedAxis) else (dim, aval) for dim, aval in zip(in_axes, closed_jaxpr.in_avals)]) - avals_in2 = [core.unmapped_aval(axis_data.size, b, aval, - axis_data.explicit_mesh_axis) - if b is not not_mapped else aval - for aval, b in unsafe_zip(avals_in, in_axes2)] + avals_in2 = [] + for aval, b in unsafe_zip(avals_in, in_axes2): + if b is not_mapped: + avals_in2.append(aval) + else: + aval = core.unmapped_aval( + axis_data.size, b, aval, axis_data.explicit_mesh_axis) + if axis_data.spmd_name is not None: + if config._check_rep.value: + aval = aval.update(vma=aval.vma | frozenset(axis_data.spmd_name)) # type: ignore + avals_in2.append(aval) jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2) return core.ClosedJaxpr(jaxpr_out, consts), out_axes() @@ -888,20 +892,16 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals): if type(val) is SymbolicZero else BatchTracer(trace, val, dim) for val, dim in zip(in_vals, in_dims * 2)] with core.set_current_trace(trace): - outs = f(*in_tracers) - # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can - # be wasteful in the rare case it actually triggers; handle symbolically! 
- outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] - - out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) + out_tracers: list[BatchTracer | SymbolicZero] = f(*in_tracers) + out_vals, out_dims = unzip2(map(trace.to_batch_info, out_tracers)) out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2]) out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds) out_primals = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis), out_primal_bds, out_dims, out_primals) - out_tangents = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis), + out_tangents = map(partial(_matchaxis_symzeros, trace.axis_data.name, size, mesh_axis), out_tangent_bds, out_dims, out_tangents) - store.store(out_dims * 2) + store.store(out_dims) return out_primals + out_tangents def batch_custom_vjp_bwd(bwd: lu.WrappedFun, tag: core.TraceTag, @@ -929,12 +929,11 @@ def _match_axes_and_sum(f, axis_size, axis_name, mesh_axis, out_dims_thunk, out_dim_dests, *in_vals): # this is like _match_axes, but we do reduce-sums as needed out_vals = f(*in_vals) - return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, mesh_axis, - axis_name, sum_match=True), + return map(partial(_matchaxis_symzeros, axis_name, axis_size, mesh_axis, + sum_match=True), out_dims_thunk(), out_dim_dests, out_vals) -def _matchaxis_symbolic_zeros(axis_name, sz, mesh_axis, name, src, dst, x, - sum_match=False): +def _matchaxis_symzeros(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): # Just like `matchaxis`, but handles symbolic zeros using ad_util.py # TODO(mattjj): dedup with matchaxis if isinstance(x, (Zero, SymbolicZero)): @@ -1111,8 +1110,15 @@ def broadcast(x, sz, axis, mesh_axis=None): # TODO(dougalm, yashkatariya): Delete this context manager once we figure # out how to ensure jaxpr arguments always have the context mesh. with mesh_lib.use_abstract_mesh(sharding.mesh): - return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, - out_sharding=sharding) + x = jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) + if config._check_rep.value: + # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026 + spmd_names = core.get_axis_env().spmd_axis_names + if len(spmd_names) > 1: + raise NotImplementedError + if spmd_names: + x = core.pvary(x, tuple(spmd_names)) + return x def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): if dst == jumble_axis: @@ -1169,3 +1175,17 @@ def add_batched(batched_args, batch_dims): x = moveaxis(x, bdx, bdy) return add_jaxvals(x, y), bdy primitive_batchers[add_jaxvals_p] = add_batched + +########################### core. 
################################## + +def _pvary_batcher(vals_in, dims_in, *, axes, axis_index_groups): + if any(type(axis) is int for axis in axes): + raise NotImplementedError + vals_out = core.pvary_p.bind(*vals_in, axes=axes, + axis_index_groups=axis_index_groups) + return vals_out, dims_in +primitive_batchers[core.pvary_p] = _pvary_batcher + +### mutable arrays + +defvectorized(core.mutable_array_p) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1369f72ac74c..9979ea151b76 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -49,13 +49,14 @@ from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout from jax._src.partition_spec import PartitionSpec +from jax._src.mesh import AxisType from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import (AUTO, NamedSharding, modify_sdy_sharding_wrt_axis_types, SdyArraySharding, SdyArrayShardingList) from jax._src.util import foreach from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension, xla_extension_version +from jax._src.lib import xla_extension from jax._src.lib.mlir import dialects, ir, passmanager from jax._src.lib.mlir.dialects import func as func_dialect, hlo from jax._src.lib.mlir import register_jax_dialects @@ -96,7 +97,6 @@ def _is_not_block_argument(x: IrValues) -> bool: - """Returns true if `x` is not a block argument.""" return not isinstance(x, ir.BlockArgument) @@ -185,24 +185,14 @@ def _is_ir_values(x: IrValues) -> bool: np.dtype(np.float64): ir.F64Type.get, np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()), np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()), + np.dtype(dtypes.int2): partial(ir.IntegerType.get_signless, 2), + np.dtype(dtypes.uint2): partial(ir.IntegerType.get_unsigned, 2), + np.dtype(dtypes.float8_e3m4): ir.Float8E3M4Type.get, + np.dtype(dtypes.float8_e4m3): ir.Float8E4M3Type.get, + np.dtype(dtypes.float8_e8m0fnu): ir.Float8E8M0FNUType.get, + np.dtype(dtypes.float4_e2m1fn): ir.Float4E2M1FNType.get, } - -if dtypes.int2 is not None: - assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) - -if dtypes.float8_e3m4 is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get -if dtypes.float8_e4m3 is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get -if dtypes.float8_e8m0fnu is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get - -if dtypes.float4_e2m1fn is not None: - _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get - def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): # TODO Support different-size underlying dtypes to take advantage of the @@ -593,6 +583,15 @@ def module_to_bytecode(module: ir.Module) -> bytes: # Translation rules +# Create one global thread pool that can be shared between multiple ir.Contexts +# and enabling multi-threading +# TODO: remove this check after jaxlib 0.5.4 +if hasattr(ir, "ThreadPool"): + global_thread_pool = ir.ThreadPool() +else: + global_thread_pool = None + + class JaxIrContext(ir.Context): def __init__(self, *args, **kwargs): # Note: we're very intentionally *not* calling the __init__() of our @@ -607,15 +606,17 @@ def make_ir_context() -> ir.Context: 
context.append_dialect_registry(upstream_dialects) context.load_all_available_dialects() - # If threading is enabled, each MLIR context will keep alive a thread pool. - # Since we cache MLIR modules (and hence contexts), this means we might keep - # several threads alive for each cache entry. This is a terrible idea. However - # we don't do any heavy computation on MLIR modules from Python anyway, so we - # just disable threading. - context.enable_multithreading(False) - # TODO(bartchr): Once JAX is released with SDY, remove the if. - if dialects.sdy: - dialects.sdy.register_dialect(context) + # TODO: remove this check after v0.5.4 jaxlib + if global_thread_pool is not None: + context.set_thread_pool(global_thread_pool) + else: + # If threading is enabled, each MLIR context will keep alive a thread pool. + # Since we cache MLIR modules (and hence contexts), this means we might keep + # several threads alive for each cache entry. This is a terrible idea. However + # we don't do any heavy computation on MLIR modules from Python anyway, so we + # just disable threading. + context.enable_multithreading(False) + dialects.sdy.register_dialect(context) dialects.mhlo.register_mhlo_dialect(context) dialects.chlo.register_dialect(context) dialects.hlo.register_dialect(context) @@ -662,7 +663,7 @@ def __init__(self, @dataclasses.dataclass(frozen=True) class LoweringParameters: # A mapping between primitives and user-defined LoweringRules. - # When lowering a primitive, give priorioty to the rule in this map over + # When lowering a primitive, give priority to the rule in this map over # existing Jax rules. override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None @@ -676,7 +677,7 @@ class LoweringParameters: # Signals that we are lowering for exporting. for_export: bool = False - # See usage in https://jax.readthedocs.io/en/latest/export/export.html#ensuring-forward-and-backward-compatibility + # See usage in https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility # We have this here to ensure it is reflected in the cache keys export_ignore_forward_compatibility: bool = False @@ -834,10 +835,17 @@ def is_forward_compat(self) -> bool: """Returns true if the lowering parameters are in forward compatibility mode. 
""" lowering_parameters = self.module_context.lowering_parameters - return ( - lowering_parameters.for_export - and not lowering_parameters.export_ignore_forward_compatibility + + check_platforms: Sequence[str] = ( + self.platforms or self.module_context.platforms ) + force_forward_compat = any( + p in xb.FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS for p in check_platforms + ) + + return ( + lowering_parameters.for_export or force_forward_compat + ) and not lowering_parameters.export_ignore_forward_compatibility if not MYPY: @@ -1010,18 +1018,29 @@ class LoweringResult(NamedTuple): def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim): - mesh = axis_ctx.mesh + mesh = axis_ctx.mesh.abstract_mesh + sharding_mesh = sharding.mesh.abstract_mesh if (isinstance(sharding, sharding_impls.NamedSharding) and - sharding.mesh.shape == mesh.shape): - return sharding_impls.NamedSharding( - sharding.mesh, sharding.spec, memory_kind=sharding.memory_kind, - _manual_axes=axis_ctx.manual_axes) + sharding_mesh.shape == mesh.shape): + out_mesh, spec = sharding_mesh, sharding.spec else: - spec = sharding_impls.parse_flatten_op_sharding( + out_mesh, spec = mesh, sharding_impls.parse_flatten_op_sharding( sharding._to_xla_hlo_sharding(ndim), mesh)[0] - return sharding_impls.NamedSharding( - mesh, spec, memory_kind=sharding.memory_kind, - _manual_axes=axis_ctx.manual_axes) + + out_mesh = out_mesh.update_axis_types( + {a: AxisType.Manual for a in axis_ctx.manual_axes}) + out = sharding_impls.NamedSharding(out_mesh, spec, + memory_kind=sharding.memory_kind) + manual_axes = out.mesh.manual_axes + if any(p in manual_axes for s in out.spec + if s is not None and s is not PartitionSpec.UNCONSTRAINED + for p in (s if isinstance(s, tuple) else (s,))): + raise ValueError( + f'pspec {out.spec} contains a manual axes {manual_axes} of mesh' + f' which is not allowed. If you are using a' + ' with_sharding_constraint under a shard_map, only use the' + ' mesh axis in PartitionSpec which are not manual.') + return out def _to_physical_op_sharding( @@ -1171,7 +1190,7 @@ def lower_jaxpr_to_module( donated_args[input_id] = False if any(donated_args): unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d] - msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation." + msg = "See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation." if not platforms_with_donation: msg = f"Donation is not implemented for {platforms}.\n{msg}" if unused_donations: @@ -1788,7 +1807,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: s = SdyArraySharding( mesh_shape=None, dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=i >= aval.ndim) + sharding_impls.SdyDimSharding(axes=[], is_open=i < aval.ndim) for i in range(physical_ndim) ]) return wrap_with_sharding_op(ctx, val, aval, s) @@ -2326,21 +2345,24 @@ def core_call_lowering(ctx: LoweringRuleContext, register_lowering(core.closed_call_p, partial(core_call_lowering, name=None)) -def map_compute_type(c_type): - if c_type == 'device_host': - return 'host' - elif c_type == 'device': - return 'dense' - elif c_type == 'tpu_sparsecore': - return 'sparse' - raise ValueError(f'Invalid compute type {c_type}. 
Current supported values ' - 'are `device_host`, `device` and `tpu_sparsecore') - -def wrap_compute_type_in_place(ctx, op): +def map_compute_type(c_type: str) -> str: + if c_type == "device_host": + return "host" + elif c_type == "device": + return "dense" + elif c_type == "tpu_sparsecore": + return "sparse" + raise ValueError(f"Invalid compute type {c_type}. Current supported values " + "are `device_host`, `device` and `tpu_sparsecore`") + +def wrap_compute_type_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> None: if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None: if ctx.jaxpr_eqn_ctx.compute_type.startswith("gpu_stream:"): stream = ctx.jaxpr_eqn_ctx.compute_type.split(":")[1] - dict_attr = {"_xla_stream_annotation": ir.StringAttr.get(stream)} + dict_attr = { + "_xla_stream_annotation": ir.StringAttr.get(stream), + "inlineable": ir.StringAttr.get("false"), + } op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) else: dict_attr = {"_xla_compute_type": ir.StringAttr.get( @@ -2348,7 +2370,7 @@ def wrap_compute_type_in_place(ctx, op): op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) -def wrap_xla_metadata_in_place(ctx, op): +def wrap_xla_metadata_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> None: ctx_attributes = {} existing_attributes = {} if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.xla_metadata: @@ -2746,11 +2768,6 @@ def cached_lowering(ctx, *args, **params): return cached_lowering -def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation - ) -> ir.Module: - module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation) - return ir.Module.parse(module_str) - def merge_mlir_modules(dst_module: ir.Module, sym_name: str, src_module: ir.Module, @@ -3031,11 +3048,8 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: mlir_module=module_to_bytecode(module), enable_shape_assertions=True, validate_static_shapes=True) - if xla_extension_version >= 319: - refined_module_str = refine_polymorphic_shapes( - enable_shardy=config.use_shardy_partitioner.value) - else: - refined_module_str = refine_polymorphic_shapes() + refined_module_str = refine_polymorphic_shapes( + enable_shardy=config.use_shardy_partitioner.value) except Exception as e: raise ValueError( "Error refining shapes. 
" + @@ -3044,3 +3058,7 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: context = make_ir_context() with context: return ir.Module.parse(refined_module_str) + +########################### pvary ################################## + +register_lowering(core.pvary_p, lambda ctx, *x, axes, axis_index_groups: x) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 07c516fd95c7..f8ce92e7f97f 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -15,7 +15,7 @@ from collections import namedtuple from collections.abc import Callable, Sequence, Hashable -from contextlib import contextmanager +import contextlib from functools import partial import itertools as it import operator as op @@ -42,8 +42,8 @@ mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.state.types import AbstractRef, ReadEffect -from jax._src.tree_util import (PyTreeDef, treedef_tuple, - tree_flatten, tree_structure) +from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, + tree_structure, register_static) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, subs_list, @@ -58,6 +58,13 @@ def identity(x): return x AvalId = int ConstId = int +AttrKind = Any +PyTree = Any + +# Attrs flavors, see jax/experimental/attrs.py +ReadWrite = type('ReadWrite', (), {})() +Append = type('Append', (), {})() + def _update_annotation_known( f: lu.WrappedFun, orig_type: InputType | None, @@ -395,8 +402,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): vals = [t.pval[1] for t in tracers] return prim.bind(fun, jvp, *vals, symbolic_zeros=symbolic_zeros) # We assume non-trivial partial evaluation is only performed to build linear - # functions, and hence we don't need to keep the custom JVP rule around - # anymore. + # functions, and hence we don't need to keep the custom JVP rule around. 
del jvp, symbolic_zeros with core.set_current_trace(self): return fun.call_wrapped(*tracers) @@ -494,6 +500,24 @@ def partial_eval_wrapper_nounits( store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env)) return (*out_consts, *res) +@lu.transformation_with_aux2 +def partial_eval_wrapper_nounits2( + f: Callable, + store: lu.Store, + in_knowns: Sequence[bool], + in_avals: Sequence[AbstractValue], + *in_consts: Any): + in_avals_, in_consts_ = iter(in_avals), iter(in_consts) + in_pvals = [PartialVal.known(next(in_consts_)) if known else + PartialVal.unknown(next(in_avals_)) for known in in_knowns] + sentinel = object() + assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel + jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals) + out_knowns, _, out_consts = partition_pvals(out_pvals) + res_avals = [core.typeof(r) for r in res] + store.store((*maybe_fwds, out_knowns, res_avals, jaxpr, env)) + return (*out_consts, *res) + custom_partial_eval_rules: dict[Primitive, Callable] = {} call_partial_eval_rules: dict[Primitive, Callable] = {} call_param_updaters: dict[Primitive, Callable] = {} @@ -1229,14 +1253,12 @@ def _default_res_aval_updater( params: dict[str, Any], aval: AbstractValue) -> AbstractValue: return aval -@contextmanager -def trivial_ctx(_): yield def call_partial_eval_custom_rule( jaxpr_param_name: str, params_updater: ParamsUpdater, saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool], eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ctx = trivial_ctx, + ctx = contextlib.nullcontext, ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]: jaxpr = eqn.params[jaxpr_param_name] with ctx(eqn.params): @@ -1537,10 +1559,14 @@ def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] 
) -> ClosedJaxpr: assert len(closed_jaxpr.in_avals) == len(to_move) new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) + id_map = {id(v): i for i, v in enumerate(new_invars)} + idx_map = {i: id_map[id(v)] for i, v in enumerate(closed_jaxpr.jaxpr.invars)} + new_effs = {e.replace(input_index=idx_map[e.input_index]) + if isinstance(e, effects.JaxprInputEffect) else e + for e in closed_jaxpr.jaxpr.effects} new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, - closed_jaxpr.jaxpr.effects, - closed_jaxpr.jaxpr.debug_info) + new_effs, closed_jaxpr.jaxpr.debug_info) new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) return new_closed_jaxpr @@ -1553,6 +1579,17 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] """Reorder `invars` by moving those indicated in `to_move` to the back.""" return move_binders_to_front(closed_jaxpr, map(op.not_, to_move)) +def move_outvars_to_back(jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr: + return _move_outvars_to_back(jaxpr, tuple(to_move)) + +@weakref_lru_cache +def _move_outvars_to_back(jaxpr, to_move): + new_outvars = ([e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if not m] + + [e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if m]) + return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars)) + + + class DynamicJaxprTracer(core.Tracer): __slots__ = ['aval', '_debug_info'] @@ -1657,7 +1694,7 @@ class JaxprStackFrame: eqns: list[JaxprEqn] invars: list[Var] effects: core.Effects - attrs_tracked: list[tuple[Any, str]] + attrs_tracked: list[tuple[Any, str, AttrKind]] attrs_inits: list attrs_vars: list[Var] debug_info: core.DebugInfo @@ -1679,10 +1716,14 @@ def __init__(self, debug_info: core.DebugInfo): def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) - def to_jaxpr(self, trace: DynamicJaxprTrace, - out_tracers: Sequence[Tracer], - debug_info: core.DebugInfo, - ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + def reset_states(self): + reset_states(self.attrs_tracked, self.attrs_inits) + + def to_jaxpr( + self, trace: DynamicJaxprTrace, + out_tracers: Sequence[Tracer], + debug_info: core.DebugInfo, + ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) invars = self.attrs_vars + self.invars @@ -1699,7 +1740,6 @@ def to_jaxpr(self, trace: DynamicJaxprTrace, jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _inline_literals(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] - set_states(self.attrs_tracked, self.attrs_inits) return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], @@ -1840,10 +1880,9 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) new_outvars = [lit_or_var(v) for v in jaxpr.outvars] - jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars, - new_eqns) - new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, - jaxpr_effects, jaxpr.debug_info) + effs = make_jaxpr_effects(new_constvars, new_invars, new_outvars, new_eqns) + new_jaxpr = Jaxpr(new_constvars, 
new_invars, new_outvars, new_eqns, effs, + jaxpr.debug_info) return new_jaxpr, new_constvals @@ -1858,6 +1897,7 @@ def invalidate(self): # avoid cyclic refs self.frame.tracers = [] self.frame.constid_to_tracer = {} + self.frame.constvar_to_val = {} def to_jaxpr_tracer(self, x): as_local_var = self.frame.tracer_to_var.get(id(x)) @@ -2172,19 +2212,23 @@ def trace_to_jaxpr_dynamic( *, keep_inputs: list[bool] | None = None, ) -> tuple[Jaxpr, list[AbstractValue], list[Any], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - with core.set_current_trace(trace): - ans = fun.call_wrapped(*in_tracers) + try: + with core.set_current_trace(trace): + ans = fun.call_wrapped(*in_tracers) - out_tracers = map(trace.to_jaxpr_tracer, ans) - _check_no_returned_refs(fun.debug_info, out_tracers) - jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) - del trace, fun, in_tracers, out_tracers, ans + out_tracers = map(trace.to_jaxpr_tracer, ans) + _check_no_returned_refs(fun.debug_info, out_tracers) + jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) + del fun, in_tracers, out_tracers, ans + finally: + trace.frame.reset_states() + del trace config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked @@ -2242,14 +2286,18 @@ def trace_to_jaxpr_dynamic2( tuple[AbstractedAxisName, ...], ] -AttrsTracked = list[tuple[Any, str]] +AttrsTracked = list[tuple[Any, str, AttrKind]] AttrStates = list -def set_states(attrs_tracked: AttrsTracked, vals: AttrStates): - for ((obj, attr), val) in zip(attrs_tracked, vals): - setattr(obj, attr, val) +def reset_states(attrs_tracked: AttrsTracked, init_vals: AttrStates) -> None: + for ((obj, attr, _), val) in zip(attrs_tracked, init_vals): + setattr(obj, attr, val) if val is not dne_sentinel else delattr(obj, attr) + +def get_states(attrs_tracked: AttrsTracked) -> list[PyTree]: + return [getattr(obj, attr) for (obj, attr, kind) in attrs_tracked] -def get_states(attrs_tracked: AttrsTracked): - return [getattr(obj, attr) for (obj, attr) in attrs_tracked] +@register_static +class DoesNotExist: ... 
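# Aside, not part of the diff: `register_static` registers DoesNotExist as a
# leafless, static pytree node, so the `dne_sentinel` instance below can be
# carried through tree_flatten as hashable treedef structure without adding
# any array leaves. The sentinel marks attributes that did not exist before
# tracing, which is why reset_states() above calls delattr() for it instead
# of restoring a value.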
+dne_sentinel = DoesNotExist() def infer_lambda_input_type( diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c06eda5214ed..d5a18ad2f439 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -57,6 +57,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla +from jax._src.lib import jaxlib_extension_version from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir @@ -257,7 +258,7 @@ def _shard_abstract_array(size, axis: int, x): raise ValueError(f"Axis size {size} does not match dimension {axis} of " f"shape {x.shape}") except IndexError: - raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None + raise ValueError(f"Cannot split a {x.dim}D value along axis {axis}") from None if config.pmap_no_rank_reduction.value: return x.update(shape=tuple_update(x.shape, axis, 1)) else: @@ -338,8 +339,8 @@ def xla_pmap_impl_lazy( donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, ) -> Callable: - if (config.disable_jit.value and config.eager_pmap.value and - not is_explicit_global_axis_size and not any(d for d in donated_invars)): + if (config.disable_jit.value and + not is_explicit_global_axis_size and not any(donated_invars)): def _emap_apply_fn(*args): return _emap_impl(fun, *args, backend=backend, axis_name=axis_name, axis_size=axis_size, global_axis_size=global_axis_size, @@ -921,7 +922,7 @@ def _pmap_unmapped_aval(size: core.AxisSize, axis: int | None, raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}") -class PmapComputation(stages.XlaLowering): +class PmapComputation(stages.Lowering): _hlo: ir.Module _executable: PmapExecutable | None @@ -930,7 +931,7 @@ def __init__(self, hlo: ir.Module, **compile_args): self._hlo = hlo self.compile_args = compile_args - # -- stages.XlaLowering overrides + # -- stages.Lowering overrides def stablehlo(self) -> ir.Module: return self._hlo @@ -1115,7 +1116,7 @@ def from_hlo(hlo: ir.Module, jaxpr_debug_info=jaxpr_debug_info).load() -class PmapExecutable(stages.XlaExecutable): +class PmapExecutable(stages.Executable): __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call", "fingerprint", "in_avals", "_unloaded_executable"] @@ -1135,7 +1136,7 @@ def unsafe_call(self) -> Callable[..., Any]: self._unsafe_call = self.build_unsafe_call() return self._unsafe_call # type: ignore - # -- stages.XlaExecutable overrides + # -- stages.Executable overrides def xla_extension_executable(self): return self.xla_executable @@ -1314,7 +1315,10 @@ def __call__(self, *args): out_ = [] for i, o in zip(self.mut.out_mut, out): if i is not None: - args[i]._buf = o + if jaxlib_extension_version < 330: + args[i]._buf = o + else: + args[i]._buf._replace_with(o) else: out_.append(o) return out_ @@ -2281,6 +2285,16 @@ def lower_sharding_computation( devices_from_context) unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings] + for a in global_out_avals: + if (a is not core.abstract_token and not a.sharding.mesh.empty and + a.sharding.mesh._are_all_axes_explicit and + len(device_assignment) != a.sharding.mesh.size): + raise ValueError( + f"Length of device assignment {len(device_assignment)} is not equal" + f" to the size of the mesh {a.sharding.mesh.size} of aval" + f" {a.str_short(True, True)}. 
Please enter your `jit` into a mesh" + " context via `jax.sharding.use_mesh`.") + # TODO(parkers): One _raw_platform has been unified with platform, # change this back to just read platform. platforms = lowering_platforms or ( @@ -2419,7 +2433,7 @@ def _to_logical_sharding( raise TypeError(aval) -class MeshComputation(stages.XlaLowering): +class MeshComputation(stages.Lowering): _hlo: ir.Module _executable: MeshExecutable | None @@ -2435,7 +2449,7 @@ def __init__(self, name: str, hlo: ir.Module, self.compile_args = compile_args self._executable = None - # -- stages.XlaLowering overrides + # -- stages.Lowering overrides def stablehlo(self) -> ir.Module: return self._hlo @@ -2463,14 +2477,41 @@ def cost_analysis(self) -> dict[str, float]: return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module()) +def get_op_sharding_from_executable( + executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: + in_op_shardings: list[xc.OpSharding] = [] + parameter_shardings_from_xla = executable.get_parameter_shardings() + if parameter_shardings_from_xla is not None: + in_op_shardings = parameter_shardings_from_xla + + out_op_shardings: list[xc.OpSharding] = [] + output_shardings_from_xla = executable.get_output_shardings() + if output_shardings_from_xla is not None: + out_op_shardings = output_shardings_from_xla + + return in_op_shardings, out_op_shardings + + +def get_pspec_from_executable( + executable, mesh: Mesh +) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: + input_op_s, output_op_s = get_op_sharding_from_executable(executable) + in_pspec: list[PartitionSpec] = [] + for s in input_op_s: + in_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + + out_pspec: list[PartitionSpec] = [] + for s in output_op_s: + out_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + return tuple(in_pspec), tuple(out_pspec) + + def get_out_shardings_from_executable( xla_executable, device_assignment: Sequence[xc.Device], num_out_avals: int, num_ordered_effects: int, ) -> Sequence[sharding_impls.GSPMDSharding] | None: - from jax._src import pjit - try: omk = xla_executable.get_output_memory_kinds()[0] if num_ordered_effects > 0: @@ -2486,7 +2527,7 @@ def get_out_shardings_from_executable( return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk) for mk in omk] - _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable) + _, out_op_shardings = get_op_sharding_from_executable(xla_executable) if not out_op_shardings: return None @@ -2517,14 +2558,12 @@ def _get_in_shardings_from_xla( num_ordered_effects: int ) -> Sequence[GSPMDSharding] | None: """Returns input shardings from XLA.""" - from jax._src import pjit - # When the device assignment only has 1 device, SPMD partitioner will not run. # Hence the op shardings will not be set on the `hlo_module`. 
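# A minimal sketch of the `jax.sharding.use_mesh` advice in the new
# lower_sharding_computation error above; `f` and `x` are placeholders:
#
#   import jax
#
#   mesh = jax.make_mesh((jax.device_count(),), ('data',))
#   with jax.sharding.use_mesh(mesh):
#     y = jax.jit(f)(x)
#
# Entering the mesh context ties the jit device assignment to the mesh that
# the explicitly-sharded avals refer to, so the two sizes agree.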
if len(device_assignment) == 1: return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals - in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable) + in_op_shardings, _ = get_op_sharding_from_executable(xla_executable) if not in_op_shardings: return None @@ -2543,9 +2582,7 @@ def _get_in_shardings_from_xla( def _get_mesh_pspec_shardings_from_executable( xla_executable, mesh: Mesh ) -> tuple[Sequence[NamedSharding], Sequence[NamedSharding]]: - from jax._src import pjit - - in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh) + in_pspec, out_pspec = get_pspec_from_executable(xla_executable, mesh) return ([NamedSharding(mesh, i) for i in in_pspec], [NamedSharding(mesh, o) for o in out_pspec]) @@ -3085,7 +3122,7 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat): return tree_util.dispatch_registry.flatten(out_unflat, None) -class MeshExecutable(stages.XlaExecutable): +class MeshExecutable(stages.Executable): __slots__ = [ "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals", "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering", @@ -3121,7 +3158,7 @@ def unsafe_call(self) -> Callable[..., Any]: self._unsafe_call = self.build_unsafe_call() return self._unsafe_call # type: ignore - # -- stages.XlaExecutable overrides + # -- stages.Executable overrides def xla_extension_executable(self): return self.xla_executable @@ -3149,20 +3186,6 @@ def call(self, *args): self._kept_var_idx) return self.unsafe_call(*args) # pylint: disable=not-callable - def input_shardings(self) -> Sequence[JSharding]: - return self._in_shardings - - def output_shardings(self) -> Sequence[JSharding]: - return self._out_shardings - - def input_layouts(self): - return [Layout(l, s) - for l, s in safe_zip(self._xla_in_layouts, self._in_shardings)] - - def output_layouts(self): - return [Layout(l, s) - for l, s in safe_zip(self._xla_out_layouts, self._out_shardings)] - def create_cpp_call(self, no_kwargs, in_tree, out_tree): if not (isinstance(self.unsafe_call, ExecuteReplicated) and not self.unsafe_call.has_unordered_effects and diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 33a8992a8be4..7fbb22923e0f 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -16,7 +16,7 @@ from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable from functools import partial from typing import Any, Union @@ -25,7 +25,6 @@ from jax._src import core from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ShapedArray from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape @@ -41,11 +40,6 @@ def identity(x): return x _scalar_types = dtypes.python_scalar_dtypes.keys() -def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]: - aval = core.physical_aval(aval) - dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype - return (xc.Shape.array_shape(dtype, aval.shape),) - # Utilities # HLO instructions optionally can be annotated to say how the output should be @@ -90,20 +84,6 @@ def tuple_sharding_proto(elems): ### handlers -# JAX abstract values -> XLA shapes - -def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: - try: - return _xla_shape_handlers[type(aval)](aval) - except KeyError as err: - raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err - -_xla_shape_handlers: dict[type[core.AbstractValue], - 
Callable[[Any], Sequence[xc.Shape]]] = { - ShapedArray: _make_array_shape, -} -_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) - # IR constants diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 0e037ec774b5..c9a68d84b024 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -240,8 +240,8 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, f"either the `k` ({k}) or the " f" reduction dimension size ({reduction_input_size}) are symbolic") return (operand.update(shape=dims, dtype=operand.dtype, - weak_type=operand.weak_type), - operand.update(shape=dims, dtype=np.dtype(np.int32))) + weak_type=operand.weak_type, vma=operand.vma), + operand.update(shape=dims, dtype=np.dtype(np.int32), vma=operand.vma)) def _get_init_val_literal(op_type, is_max_k): return np.array(-np.inf if is_max_k else np.inf, dtype=op_type) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index b75cbf6ac708..87dbcd8d3f32 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -260,3 +260,23 @@ def _show_diff(array1, array2): def _avals_short(avals): to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))() return ' '.join(map(to_str, avals)) + +def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str: + assert not core.typematch(a1, a2) + if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray): + mismatches = [] + if a1.dtype != a2.dtype: + mismatches.append('the dtypes do not match') + if a1.shape != a2.shape: + mismatches.append('the shapes do not match') + if a1.vma != a2.vma: + mismatches.append('the varying manual axes do not match') + # TODO(yashkatariya,mattjj): add check for sharding-in-types mismatch + + if len(mismatches) == 0: + return '' + elif len(mismatches) == 1: + return ', so ' + mismatches[0] + else: + return ', so ' + ', '.join(mismatches[:-1]) + ', and ' + mismatches[-1] + return '' diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 63896cc2a0bf..b6263c427a00 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -23,7 +23,9 @@ import operator from typing import Any, TypeVar -from jax.tree_util import tree_flatten, tree_unflatten +from jax._src.tree_util import ( + tree_flatten, tree_unflatten, tree_flatten_with_path, keystr, + equality_errors_pytreedef) from jax._src import ad_util from jax._src import api_util from jax._src import config @@ -44,19 +46,14 @@ from jax._src.interpreters import xla from jax._src.lax import lax from jax._src.traceback_util import api_boundary -from jax._src.util import (safe_map, split_list, partition_list) +from jax._src.util import safe_map, split_list, partition_list, unzip2 from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo import numpy as np from jax._src.lax.control_flow.common import ( - _avals_short, - _check_tree_and_avals, - _initial_style_jaxprs_with_common_consts, - _make_closed_jaxpr, - _prune_zeros, - _typecheck_param, - ) + _avals_short, _typecheck_param, _aval_mismatch_extra, + _initial_style_jaxprs_with_common_consts, _make_closed_jaxpr, _prune_zeros) map, unsafe_map = safe_map, map @@ -147,16 +144,30 @@ def switch(index, branches, *operands): if config.mutable_array_checks.value: api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops) for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): 
- _check_tree_and_avals("branch 0 output", - out_trees[0], jaxprs[0].out_avals, - f"branch {i + 1} output", - out_tree, jaxpr.out_avals) + _check_branch_outputs( + "switch", "branch 0", f"branch{i+1}", branches[0], branches[i+1], + out_trees[0], out_tree, jaxprs[0].out_avals, jaxpr.out_avals) + # prune passthrough outputs + fwds = [pe._jaxpr_forwarding(jaxpr.jaxpr) for jaxpr in jaxprs] + in_fwd = [xs[0] if len(set(xs)) == 1 else None for xs in zip(*fwds)] + keep = [f is None for f in in_fwd] + jaxprs = [pe.prune_closed_jaxpr_outputs(jaxpr, keep) for jaxpr in jaxprs] + joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') + jaxprs = [replace_jaxpr_effects(jaxpr, joined_effects) for jaxpr in jaxprs] out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) + out_ = iter(out) + + all_inputs = [*consts, *ops] + out = [ + next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) + for fwd in in_fwd + ] + assert next(out_, None) is None return tree_unflatten(out_trees[0], out) @@ -255,11 +266,11 @@ def cond(pred, true_fun, false_fun, *operands): true_jaxpr.out_avals + false_jaxpr.out_avals): raise ValueError("Cannot return `Ref`s from `cond`.") - _check_tree_and_avals("true_fun output", - out_tree, true_jaxpr.out_avals, - "false_fun output", - false_out_tree, false_jaxpr.out_avals) - # prune passhtrough outputs + _check_branch_outputs( + 'cond', 'true_fun', 'false_fun', true_fun, false_fun, out_tree, + false_out_tree, true_jaxpr.out_avals, false_jaxpr.out_avals) + + # prune passthrough outputs true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr) in_fwd = [i if i == j else None for i, j in zip(true_fwds, false_fwds)] @@ -278,7 +289,6 @@ def cond(pred, true_fun, false_fun, *operands): true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) out = cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr)) - num_consts = len(consts) out_ = iter(out) all_inputs = [*consts, *ops] @@ -289,6 +299,90 @@ def cond(pred, true_fun, false_fun, *operands): assert next(out_, None) is None return tree_unflatten(out_tree, out) +def _check_branch_outputs( + api_name, name1, name2, f1, f2, out_tree1, out_tree2, out_avals1, + out_avals2) -> None: + info1 = api_util.fun_sourceinfo(f1) + info2 = api_util.fun_sourceinfo(f2) + try: + outs1 = tree_unflatten(out_tree1, out_avals1) + except: + paths = [None] * len(out_avals1) + component = lambda _: '' + else: + leaves_and_paths, _ = tree_flatten_with_path(outs1) + paths, _ = unzip2(leaves_and_paths) # type: ignore + component = lambda p: f' at path {keystr(p)}' if p else '' + + if out_tree1 != out_tree2: + diffs = [f'{name1} output{component(p)} is a {thing1} but ' + f'{name2} output{component(p)} is a {thing2}, so {expl}' + for p, thing1, thing2, expl + in equality_errors_pytreedef(out_tree1, out_tree2)] + + if len(diffs) == 0: + return # the trees may have different aux data, but structures are same + elif len(diffs) == 1: + differences = f'{diffs[0]}.\n' + else: + differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) + + f' * {diffs[-1]}.\n') + + raise TypeError( + f'{api_name} branch outputs must have the same pytree structure, but ' + 'they differ:\n\n' + f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n' + f'{differences}\n' + f'Revise {name1} and/or {name2} so that they 
have the same pytree ' + 'structure.') + + if not all(map(core.typematch, out_avals1, out_avals2)): + diffs = [f'the output of {name1}{component(p)} has type {a1.str_short()}' + f' but the corresponding output of {name2} has type ' + f'{a2.str_short()}{_aval_mismatch_extra(a1, a2)}' + for p, a1, a2 in zip(paths, out_avals1, out_avals2) + if not core.typematch(a1, a2)] + if len(diffs) == 0: + return # seems unreachable but in any case we don't have a good error msg + elif len(diffs) == 1: + differences = f'{_capitalize(diffs[0])}.\n' + else: + differences = ('\n'.join(f' * {d};' for d in diffs[:-1]) + + f'\n * {diffs[-1]}.\n') + + pvary_applications = [ + f'applying `jax.lax.pvary(..., {tuple(a1.vma - a2.vma)})` ' + f'to the output of {n}{component(p)}' + for p, aval1, aval2 in zip(paths, out_avals1, out_avals2) + for n, a1, a2 in [(name1, aval2, aval1), (name2, aval1, aval2)] + if not core.typematch(a1, a2) and + isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray) + and a1.vma != a2.vma and a2.vma - a1.vma] + + if not pvary_applications: + pvary_msg = '' + elif len(pvary_applications) == 1: + pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n' + else: + pvary_msg = ('This might be fixed by:\n' + + '\n'.join(f' * {d};' for d in pvary_applications[:-1]) + + f'\n * {pvary_applications[-1]}.\n') + if pvary_msg: + pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma " + "for more information.\n\n") + + raise TypeError( + f'{api_name} branches must have equal output types but they differ.\n\n' + f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n' + f'{differences}\n' + f'{pvary_msg}' + f'Revise {name1} and/or {name2} so that all output types match.') + + +def _capitalize(s): + # s.capitalize() converts s[1:] to lowercase which we don't want. + return s[0].capitalize() + s[1:] + @api_boundary @functools.wraps(_cond) def cond(*args, **kwargs): @@ -347,6 +441,15 @@ def _cond_abstract_eval(*avals: core.AbstractValue, if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') + b0_vma = [o.vma for o in branches[0].out_avals] + for branch in branches[1:]: + b_vma = [o.vma for o in branch.out_avals] + if b0_vma != b_vma: + raise Exception("The branches of cond produced mismatched varying manual " + f"axes. Got {b0_vma} and {b_vma}. 
Please open an issue " + "at https://github.com/jax-ml/jax/issues, and as a " + "temporary workaround pass the check_rep=False argument " + "to shard_map") return branches[0].out_avals, joined_effects def _bcast_select(pred, on_true, on_false): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 3084fa722977..de53ee14ca0d 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -51,7 +51,7 @@ from jax._src.lax.control_flow.common import ( _avals_short, _initial_style_jaxpr, _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, - _typecheck_param) + _typecheck_param, _aval_mismatch_extra) from jax._src.lax.other import logaddexp from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -60,24 +60,12 @@ from jax._src.tree_util import equality_errors from jax._src.typing import Array from jax._src.util import ( - merge_lists, - partition_list, - safe_map, - safe_zip, - split_list, - split_list_checked, - unzip2, - weakref_lru_cache, -) + merge_lists, partition_list, safe_map, safe_zip, split_list, + split_list_checked, unzip2, weakref_lru_cache,) from jax._src import xla_bridge as xb from jax.tree_util import ( - keystr, - tree_flatten, - tree_flatten_with_path, - tree_map, - tree_unflatten, - treedef_is_leaf, -) + keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten, + treedef_is_leaf) import numpy as np _map = safe_map @@ -178,6 +166,11 @@ def scan(f, init, xs, length=None): :py:func:`scan` compiles ``f``, so while it can be combined with :py:func:`jit`, it's usually unnecessary. + .. note:: + :func:`scan` is designed for iterating with a static number of iterations. + For iteration with a dynamic number of iterations, use :func:`fori_loop` + or :func:`while_loop`. + Args: f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning that ``f`` accepts two arguments where the first is a value of the loop @@ -239,7 +232,9 @@ def scan(f, init, xs, length=None): try: length = int(length) except core.ConcretizationTypeError as err: - msg = 'The `length` argument to `scan` expects a concrete `int` value.' + msg = ('The `length` argument to `scan` expects a concrete `int` value.' + ' For scan-like iteration with a dynamic length, use `while_loop`' + ' or `fori_loop`.') raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type] if not all(length == l for l in lengths): msg = ("scan got `length` argument of {} which disagrees with " @@ -291,8 +286,17 @@ def _create_jaxpr(init): if len(out_tree_children) != 2: msg = "scan body output must be a pair, got {}." 
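# --- Editor's note (not part of the patch): a sketch of the behavior behind the
# docstring note and the improved `length` error above: `length` must be a
# concrete Python int, so a traced value (e.g. a jit argument) triggers the
# ConcretizationTypeError that now points at while_loop / fori_loop.
import jax

def step(carry, _):
  return carry + 1, None

def count(n):
  carry, _ = jax.lax.scan(step, 0, None, length=n)
  return carry

count(5)                 # fine: `length` is a concrete Python int
# jax.jit(count)(5)      # fails: `length` is a tracer; for a dynamic bound use:
jax.lax.fori_loop(0, 5, lambda i, c: c + 1, 0)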
raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) - _, carry_avals_out, _ = split_list( - jaxpr.out_avals, [len(attrs_tracked), out_tree_children[0].num_leaves]) + + if attrs_tracked: + appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked] + jaxpr = pe.move_outvars_to_back( + jaxpr, appends_out + [False] * (len(jaxpr.out_avals) - len(appends_out))) + num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) + in attrs_tracked if kind is pe.ReadWrite) + _, carry_avals_out, _ = split_list( + jaxpr.out_avals, [num_attr_carry, out_tree_children[0].num_leaves]) + else: + carry_avals_out, _ = split_list(jaxpr.out_avals, [out_tree_children[0].num_leaves]) return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked) @@ -325,9 +329,8 @@ def _create_jaxpr(init): raise ValueError("`unroll` must be a `bool` or a positive `int`.") if attrs_tracked: in_state = _get_states(attrs_tracked) - in_carry, in_ext = split_list(in_flat, [num_carry]) - in_flat = [*in_state, *in_carry, *in_ext] - num_carry += len(attrs_tracked) + in_flat = [*in_state, *in_flat] + num_carry += len(in_state) out = scan_p.bind(*consts, *in_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, @@ -335,27 +338,50 @@ def _create_jaxpr(init): unroll=unroll, _split_transpose=_split_transpose) if attrs_tracked: - out_state, out = split_list(out, [len(attrs_tracked)]) - _set_states(attrs_tracked, out_state) + num_ext = (len(out) - len(in_state) + - sum(k is pe.Append for *_, (_, _, k) in attrs_tracked)) + out_state, out, out_append = split_list(out, [len(in_state), num_ext]) + out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) + _set_states(attrs_tracked, out_attrs) return tree_unflatten(out_tree, out) def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr + from jax.experimental.attrs import jax_setattr, jax_extendattr valss = split_list_checked(vals, [td.num_leaves for _, td, _ in attrs_tracked]) - for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): - val = tree_unflatten(treedef, leaves) - jax_setattr(obj, attr, val) + for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): + if kind is pe.ReadWrite: + val = tree_unflatten(treedef, leaves) + jax_setattr(obj, attr, val) + elif kind is pe.Append: + val, = leaves + jax_extendattr(obj, attr, val.reshape(-1, *val.shape[2:])) + else: + assert False def _get_states(attrs_tracked): from jax.experimental.attrs import jax_getattr vals = [] - for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) - leaves, treedef_ = tree_flatten(tree) - assert treedef == treedef_ - vals.extend(leaves) + for treedef, _, (obj, attr, kind) in attrs_tracked: + if kind is pe.ReadWrite: + tree = jax_getattr(obj, attr) + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.Append: + pass + else: + assert False return vals +def _merge_attrs_out(attrs_tracked, out_state, out_append): + out_state_, out_append_ = iter(out_state), iter(out_append) + out_attrs = [item for _, out_tree, (_, _, k) in attrs_tracked for item in + (itertools.islice(out_state_, out_tree.num_leaves) + if k is pe.ReadWrite else [next(out_append_)])] + assert next(out_state_, None) is next(out_append_, None) is None + return out_attrs + + def _capitalize(s): # s.capitalize() converts s[1:] to lowercase which we don't want. 
return s[0].capitalize() + s[1:] @@ -390,9 +416,8 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): for path, thing1, thing2, explanation in equality_errors(in_carry, out_carry)] if len(diffs) == 0: - # The trees may have different aux data but structures are the same. - return - if len(diffs) == 1: + return # the trees may have different aux data, but structures are same + elif len(diffs) == 1: differences = f'{_capitalize(diffs[0])}.\n' else: differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) @@ -409,32 +434,42 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}' for path, in_aval, out_aval in zip(paths, in_avals, out_avals) if not core.typematch(in_aval, out_aval)] + if len(diffs) == 0: - # The trees may have different aux data but structures are the same. - return + return # seems unreachable but in any case we don't have a good error msg if len(diffs) == 1: differences = f'{_capitalize(diffs[0])}.\n' else: differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) + f' * {diffs[-1]}.\n') + + pvary_applications = [ + f'applying `jax.lax.pvary(..., {tuple(out_aval.vma - in_aval.vma)})` ' + f'to the initial carry value corresponding to {component(path)}' + for path, in_aval, out_aval in zip(paths, in_avals, out_avals) + if not core.typematch(in_aval, out_aval) and + isinstance(in_aval, ShapedArray) and isinstance(out_aval, ShapedArray) + and in_aval.vma != out_aval.vma and out_aval.vma - in_aval.vma] + + if not pvary_applications: + pvary_msg = '' + elif len(pvary_applications) == 1: + pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n' + else: + pvary_msg = ('This might be fixed by:\n' + + '\n'.join(f' * {d};\n' for d in pvary_applications[:-1]) + + f' * {pvary_applications[-1]}.\n') + if pvary_msg: + pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma " + "for more information.\n\n") + raise TypeError( - f"{name} function carry input and carry output must have equal types " - "(e.g. shapes and dtypes of arrays), " + f"{name} function carry input and carry output must have equal types, " "but they differ:\n\n" f"{differences}\n" - "Revise the function so that all output types (e.g. shapes " - "and dtypes) match the corresponding input types.") - -def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str: - assert not core.typematch(a1, a2) - if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray): - dtype_mismatch = a1.dtype != a2.dtype - shape_mismatch = a1.shape != a2.shape - return (', so ' * (dtype_mismatch or shape_mismatch) + - 'the dtypes do not match' * dtype_mismatch + - ' and also ' * (dtype_mismatch and shape_mismatch) + - 'the shapes do not match' * shape_mismatch) - return '' + f"{pvary_msg}" + "Revise the function so that all output types match the corresponding " + "input types.") # TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression. 
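# --- Editor's note (not part of the patch): a small sketch that trips the carry
# type check rewritten above; the carry changes dtype across iterations, so the
# resulting error should report that the carry input/output dtypes do not match.
import jax
import jax.numpy as jnp

def body(carry, x):
  new_carry = (carry + x).astype(jnp.float32)   # int32 carry becomes float32
  return new_carry, x

try:
  jax.lax.scan(body, jnp.int32(0), jnp.arange(5))
except TypeError as e:
  print(e)  # expected to say carry input and carry output types differ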
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, @@ -513,8 +548,8 @@ def _concat(a, b): return lax.concatenate([a, b], 0) def _empty_array(prefix, length_spec, aval): sharding = aval.sharding.with_spec((*length_spec, *aval.sharding.spec)) - return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), - out_sharding=sharding) + empty = core.pvary(lax.empty(aval.dtype), tuple(aval.vma)) + return lax.broadcast(empty, (*prefix, *aval.shape), out_sharding=sharding) eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True @@ -532,9 +567,17 @@ def _prepend_dim_to_aval(sz, aval): def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): - carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + out_carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + _, in_carry_avals, _ = split_list(args, [num_consts, num_carry]) + if [i.vma for i in in_carry_avals] != [o.vma for o in out_carry_avals]: + raise ValueError( + 'Scan carry input and output got mismatched varying manual axes ' + f'{in_carry_avals} and {out_carry_avals}. Please open an ' + 'issue at https://github.com/jax-ml/jax/issues, and as a ' + 'temporary workaround pass the check_rep=False argument to ' + 'shard_map') ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) - return carry_avals + ys_avals, jaxpr.effects + return out_carry_avals + ys_avals, jaxpr.effects def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): @@ -655,7 +698,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, # The above trace_to_jaxpr_nounits call computed loop-invariant residuals # (known values in invar_pvals_out) and also computed loop-invariant values # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the - # previous consts). We need to collect the computed inteisive residuals, and + # previous consts). We need to collect the computed intensive residuals, and # move corresponding intensive residual binders in jaxpr_unknown to the front. 
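# --- Editor's note (not part of the patch): a hedged sketch of the
# `jax.lax.pvary` fix that the new varying-manual-axes (vma) error messages
# suggest. It assumes `jax.lax.pvary` is public (as the messages imply) and uses
# the jax.experimental.shard_map entry point; illustrative only, not a
# definitive recipe.
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices()[:1], ('i',))

@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(xs):
  def body(carry, x):
    return carry + x, x               # carry becomes varying over 'i' after one step
  init = jnp.zeros(())                 # an unvaried init would mismatch the carry output...
  init = jax.lax.pvary(init, ('i',))   # ...so mark it as varying over 'i', as the error suggests
  total, _ = jax.lax.scan(body, init, xs)
  return total[None]

f(jnp.arange(4.0))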
res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] @@ -778,16 +821,21 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts]) # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) - # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) + # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a, e]) jaxpr_trans, attrs_tracked = _transpose_scan_jaxpr( jaxpr, num_ires, num_consts - num_ires, num_eres, ct_ys_is_zeros) - linear_trans = ([False] * num_ires + [False] * len(attrs_tracked) + + appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked] + jaxpr_trans = pe.move_outvars_to_back( + jaxpr_trans, appends_out + [False] * (len(jaxpr_trans.out_avals) - len(appends_out))) + num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) + in attrs_tracked if kind is pe.ReadWrite) + linear_trans = ([False] * num_ires + [False] * num_attr_carry + [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + [False] * num_eres) in_state = _get_states(attrs_tracked) transpose_inputs = *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres - transpose_num_out_carry = num_consts-num_ires+num_carry+len(attrs_tracked) + transpose_num_out_carry = num_consts-num_ires+num_carry+num_attr_carry if not _split_transpose: outs = scan_p.bind( @@ -882,8 +930,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, for mask in outs_mask ] - out_state, outs = split_list(outs, [len(attrs_tracked)]) - _set_states(attrs_tracked, out_state) + num_outs = len(outs) - num_attr_carry - sum(appends_out) + out_state, outs, out_append = split_list(outs, [num_attr_carry, num_outs]) + out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) + _set_states(attrs_tracked, out_attrs) ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres @@ -928,12 +978,10 @@ def transposed(*res1_cbar_bbar_res2): return c_bar + a_bar # TODO(necula): fix arg names and results for transposed - transposed_wrapped = lu.wrap_init(transposed, - debug_info=jaxpr.jaxpr.debug_info) - return _make_closed_jaxpr_attrs( - transposed_wrapped, - tuple(res1_avals + c_avals + b_carry_avals + - b_ys_avals_stripped + res2_avals)) + transposed_wrapped = lu.wrap_init(transposed, debug_info=jaxpr.jaxpr.debug_info) + trans_avals = (*res1_avals, *c_avals, *b_carry_avals, *b_ys_avals_stripped, *res2_avals) + trans_jaxpr, attrs_tracked = _make_closed_jaxpr_attrs(transposed_wrapped, trans_avals) + return trans_jaxpr, attrs_tracked def _scan_batching_rule(axis_data, args, @@ -1413,9 +1461,34 @@ def _create_jaxpr(init_val): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') + + # If the body forwards an input carry to an output carry, *and* it's not used + # by the cond fun, it can be moved to be a body const. Doing so can lead to + # efficiency wins: if e.g. we vmap the loop with a batched predicate, we batch + # the carry too, but not the body consts. 
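# --- Editor's note (not part of the patch): a sketch of the pattern the comment
# above targets. `scale` is forwarded through the body unchanged and never read
# by the cond, so the new code can hoist it into the body constants instead of
# carrying it; when the predicate ends up batched, only the genuinely-carried
# values then need to be batched along with it.
import jax
import jax.numpy as jnp

def loop(x, scale):
  def cond(state):
    i, _, _ = state
    return i < 10

  def body(state):
    i, x, scale = state
    return i + 1, x * scale, scale   # `scale` passes through untouched

  _, out, _ = jax.lax.while_loop(cond, body, (0, x, scale))
  return out

print(jax.vmap(loop, in_axes=(0, None))(jnp.ones(4), 2.0))   # [1024. 1024. 1024. 1024.]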
+ body_fwd = pe._jaxpr_forwarding(body_jaxpr.jaxpr) + carry_nofwd = [len(body_consts) + i != f for i, f in enumerate(body_fwd)] + cond_jaxpr_, keep_cond = pe.dce_jaxpr( + cond_jaxpr.jaxpr, [True], [True] * len(cond_consts) + carry_nofwd) + _, keep_cond_carry = split_list(keep_cond, [len(cond_consts)]) + move_to_const = _map(operator.not_, keep_cond_carry) + + if any(move_to_const): + cond_jaxpr = pe.close_jaxpr(cond_jaxpr_) + body_jaxpr = pe.prune_closed_jaxpr_outputs( + body_jaxpr, [not m for m in move_to_const]) + body_jaxpr = pe.move_binders_to_front( + body_jaxpr, [False] * len(body_consts) + move_to_const) + init_vals, new_body_consts = partition_list(move_to_const, init_vals) + body_consts = [*new_body_consts, *body_consts] + outs = while_p.bind(*cond_consts, *body_consts, *init_vals, cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr, body_nconsts=len(body_consts), body_jaxpr=body_jaxpr) + + if any(move_to_const): + outs = pe.merge_lists(move_to_const, outs, new_body_consts) + return tree_unflatten(body_tree, outs) @@ -1438,7 +1511,29 @@ def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, cond_nconsts): - del avals + cond_consts_avals, body_consts_avals, in_avals = \ + util.split_list(avals, [cond_nconsts, body_nconsts]) + + if len(cond_jaxpr.in_avals) != len(cond_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(cond_jaxpr.in_avals)=} but {len(cond_consts_avals) + len(in_avals)=}") + if len(body_jaxpr.in_avals) != len(body_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(body_jaxpr.in_avals)=} but {len(body_consts_avals) + len(in_avals)=}") + # TODO(mattjj): check body carry type + # TODO(mattjj): make these typecompat checks work with bints + # if not all(_map(core.typecompat, [*cond_consts_avals, *in_avals], cond_jaxpr.in_avals)): # type: ignore + # cond_avals = [*cond_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(cond_avals, cond_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop cond function input type error: {a1} != {a2}") + # if not all(_map(core.typecompat, [*body_consts_avals, *in_avals], body_jaxpr.in_avals)): # type: ignore + # body_avals = [*body_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(body_avals, body_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop body function input type error: {a1} != {a2}") + + joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) @@ -1679,7 +1774,7 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): assert False, "Fixpoint not reached" assert not num_res body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts) - del jaxpr_known_, carry_uk_out, num_res + del jaxpr_known_, carry_uk_out, num_res, unks_in # Instantiate all inputs (b/c jaxpr_staged will take all inputs). new_inst = [x for x, inst in zip(eqn.invars, inst_in) @@ -1701,6 +1796,7 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): del cond_uk # Build the known eqn. 
+ unks_in = [*cond_consts_uk, *body_consts_uk, *carry_uk] # fixpoint carry_uk ins_known, _ = partition_list(unks_in, eqn.invars) out_binders_known, _ = partition_list(carry_uk, eqn.outvars) params_known = dict(cond_jaxpr=cond_jaxpr_known, body_jaxpr=body_jaxpr_known, @@ -1711,6 +1807,11 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): eqn_known = pe.new_jaxpr_eqn(ins_known, out_binders_known, while_p, params_known, effects_known, eqn.source_info, eqn.ctx) + # Typecheck known eqn. + _while_loop_abstract_eval( + *[v.aval for v in eqn_known.invars], cond_jaxpr=cond_jaxpr_known, + body_jaxpr=body_jaxpr_known, body_nconsts=params_known['body_nconsts'], + cond_nconsts=params_known['cond_nconsts']) # Staged eqn is same as input eqn. eqn_staged = eqn @@ -1763,18 +1864,19 @@ def cond(args): pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape)))) return pred def body(args): - return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)) + return core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args) def new_cond(pred_args): - pred, _ = pred_args + pred, *_ = pred_args return pred def new_body(pred_args): - _, args = pred_args - args = body(args) - pred = cond(args) - return pred, args + _, cond_consts, body_consts, carry = pred_args + carry = body((*body_consts, *carry)) + pred = cond((*cond_consts, *carry)) + return pred, cond_consts, body_consts, carry def fun(*args): - pred = cond(args) - _, out = while_loop(new_cond, new_body, (pred, args)) + cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) + pred = cond((*cond_consts, *carry)) + *_, out = while_loop(new_cond, new_body, (pred, cond_consts, body_consts, carry)) return out return mlir.lower_fun(fun)(ctx, *args) @@ -1798,8 +1900,7 @@ def fun(*args): cond_block.arguments[i] for i in range(len(flat_loop_carry_types)) ] cond_args = mlir.unflatten_ir_values_like_types(flat_cond_args, loop_carry_types) - # Remove tokens from cond args - cond_args = cond_args[num_tokens:] + cond_args = cond_args[num_tokens:] # Remove tokens from cond args x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts]) cond_consts = [ mlir.ir_constant(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts @@ -1861,8 +1962,9 @@ def fun(*args): partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z, body_jaxpr.out_avals) - hlo.return_([*mlir.flatten_ir_values(out_tokens), *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), - *mlir.flatten_ir_values(new_z)]) + hlo.return_([*mlir.flatten_ir_values(out_tokens), + *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), + *mlir.flatten_ir_values(new_z)]) outputs = mlir.unflatten_ir_values_like_types(while_op.results, loop_carry_types) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) @@ -1976,7 +2078,6 @@ def new_cond(*consts_refs_carry): batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom mlir.register_lowering(while_p, _while_lowering) -core.custom_typechecks[while_p] = _while_typecheck state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule) @@ -2137,7 +2238,7 @@ def fori_loop(lower, upper, body_fun, init_val): unroll=unroll, ) return result - if unroll is not None: + if unroll is not None and unroll is not False and unroll != 1: raise ValueError("Can only use `unroll` in `fori_loop` if the loop bounds " "are statically known.") @@ -2170,12 +2271,7 @@ def 
_batch_and_remainder(x, batch_size: int): return scan_tree, remainder_tree @api_boundary -def map( - f, - xs, - *, - batch_size: int | None = None, -): +def map(f, xs, *, batch_size: int | None = None): """Map a function over leading array axes. Like Python's builtin map, except inputs and outputs are in the form of @@ -2237,17 +2333,22 @@ def map(f, xs): _, ys = scan(g, (), xs) return ys -def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm): +def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, + algorithm, out_sharding): keys, = batched_args bd, = batch_dims if bd is batching.not_mapped: - return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype, - algorithm=algorithm), (None, None) + return lax.rng_bit_generator_p.bind( + keys, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), (None, None) keys = batching.moveaxis(keys, bd, 0) batch_size = keys.shape[0] + out_s = (out_sharding.with_spec((keys.aval.sharding.spec[0], *out_sharding.spec)) + if out_sharding is not None else None) key = keys[0] - new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), - dtype=dtype, algorithm=algorithm) + new_key, bits = lax.rng_bit_generator_p.bind( + key, shape=(batch_size, *shape), dtype=dtype, algorithm=algorithm, + out_sharding=out_s) new_keys = slicing.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) return (new_keys, bits), (0, 0) @@ -2488,7 +2589,8 @@ def _cumred_dtype_rule(name, operand, *args, **kw): def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn): reducer_p = lax.standard_primitive( _cumred_shape_rule, partial(_cumred_dtype_rule, name), - name, sharding_rule=_cumred_sharding_rule) + name, sharding_rule=_cumred_sharding_rule, + vma_rule=partial(core.standard_vma_rule, name)) batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p) diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index acfcfd7ff3d3..f34c98c6aaae 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -309,24 +309,24 @@ def f_aux(x): jaxprs = _LinearSolveTuple( matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) - out_flat = linear_solve_p.bind( - *(_flatten(all_consts) + b_flat), - const_lengths=const_lengths, jaxprs=jaxprs) + args = _flatten(all_consts) + b_flat + args = core.standard_insert_pvary(*args) + out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs) return tree_unflatten(out_tree, out_flat) def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): args_to_raise = args[sum(const_lengths):] - # raise aux_args to shaped arrays as well if present # number of aux args is the difference in out_avals # of solve and matvec (since they map to the same vector space) - num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return args_to_raise, jaxprs.solve.effects + out_vma = core.standard_vma_rule('linear_solve', *args_to_raise) + return (tuple(a.update(vma=out_vma) for a in args_to_raise), + jaxprs.solve.effects) def _custom_linear_solve_impl(*args, const_lengths, jaxprs): diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 290d027cc6bc..28d67adb6413 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -158,6 +158,7 @@ def conv_general_dilated( preferred_element_type = ( None if preferred_element_type is None 
else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + lhs, rhs = core.standard_insert_pvary(lhs, rhs) return conv_general_dilated_p.bind( lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding), lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation), @@ -633,7 +634,8 @@ def _conv_general_dilated_batch_rule( conv_general_dilated_p = lax.standard_primitive( _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule, - 'conv_general_dilated') + 'conv_general_dilated', + vma_rule=partial(core.standard_vma_rule, 'conv_general_dilated')) ad.defbilinear(conv_general_dilated_p, _conv_general_dilated_transpose_lhs, diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 6ca1a4abd193..2eebe6d91f22 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -124,7 +124,7 @@ def fft_abstract_eval(x, fft_type, fft_lengths): f"be equal to fft_lengths {fft_lengths}") shape = x.shape dtype = x.dtype - return x.update(shape=shape, dtype=dtype) + return x.update(shape=shape, dtype=dtype, vma=x.vma) def _fft_lowering(ctx, x, *, fft_type, fft_lengths): if not is_constant_shape(fft_lengths): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 86a75ada63ad..79bd42607290 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -66,8 +66,8 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.lib import xla_extension_version from jax._src.sharding_impls import (PmapSharding, NamedSharding, + ShardingContext, SPMDAxisContext, PartitionSpec as P, canonicalize_sharding) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis, @@ -369,6 +369,7 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array: For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``. """ + x1, x2 = core.standard_insert_pvary(x1, x2) return nextafter_p.bind(x1, x2) @export @@ -483,14 +484,41 @@ def is_finite(x: ArrayLike) -> Array: """ return is_finite_p.bind(x) +class Tolerance: + """Specify the tolerances used for computing unary functions. + + Maximum two tolerances can be specified: (atol and rtol) or (atol and ulps). + """ + + def __init__(self, atol: float = 0.0, rtol: float = 0.0, ulps: int = 0): + if atol < 0.0 or rtol < 0.0 or ulps < 0.0: + raise ValueError('Tolerances must be non-negative.') + if atol == 0.0 and rtol == 0.0 and ulps == 0: + raise ValueError('At least one of atol, rtol, or ulps must be set.') + + self.atol = atol + self.rtol = rtol + self.ulps = ulps + + +class AccuracyMode(enum.Enum): + HIGHEST = 1 + DEFAULT = 2 + @export -def exp(x: ArrayLike) -> Array: +def exp(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise exponential: :math:`e^x`. This function lowers directly to the `stablehlo.exponential`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -502,10 +530,10 @@ def exp(x: ArrayLike) -> Array: .. 
_stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential """ - return exp_p.bind(x) + return exp_p.bind(x, accuracy=accuracy) -@export -def exp2(x: ArrayLike) -> Array: + +def exp2(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise base-2 exponential: :math:`2^x`. This function is implemented in terms of the `stablehlo.exponential`_ @@ -513,6 +541,12 @@ def exp2(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -525,10 +559,10 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - return exp2_p.bind(x) + return exp2_p.bind(x, accuracy=accuracy) @export -def expm1(x: ArrayLike) -> Array: +def expm1(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`e^{x} - 1`. This function lowers directly to the `stablehlo.exponential_minus_one`_ @@ -537,6 +571,12 @@ def expm1(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -548,16 +588,22 @@ def expm1(x: ArrayLike) -> Array: .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ - return expm1_p.bind(x) + return expm1_p.bind(x, accuracy=accuracy) @export -def log(x: ArrayLike) -> Array: +def log(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`. This function lowers directly to the `stablehlo.log`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -568,10 +614,10 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ - return log_p.bind(x) + return log_p.bind(x, accuracy=accuracy) @export -def log1p(x: ArrayLike) -> Array: +def log1p(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`\mathrm{log}(1 + x)`. This function lowers directly to the `stablehlo.log_plus_one`_ operation. @@ -580,6 +626,12 @@ def log1p(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. 
If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -591,16 +643,22 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ - return log1p_p.bind(x) + return log1p_p.bind(x, accuracy=accuracy) @export -def tanh(x: ArrayLike) -> Array: +def tanh(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`. This function lowers directly to the `stablehlo.tanh`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -613,10 +671,11 @@ def tanh(x: ArrayLike) -> Array: .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh """ - return tanh_p.bind(x) + return tanh_p.bind(x, accuracy=accuracy) @export -def logistic(x: ArrayLike) -> Array: + +def logistic(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`. There is no HLO logistic/sigmoid primitive, so this lowers to a sequence @@ -632,10 +691,10 @@ def logistic(x: ArrayLike) -> Array: See also: - :func:`jax.nn.sigmoid`: an alternative API for this functionality. """ - return logistic_p.bind(x) + return logistic_p.bind(x, accuracy=accuracy) @export -def sin(x: ArrayLike) -> Array: +def sin(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise sine: :math:`\mathrm{sin}(x)`. For floating-point inputs, this function lowers directly to the @@ -644,6 +703,12 @@ def sin(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -656,10 +721,10 @@ def sin(x: ArrayLike) -> Array: .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine """ - return sin_p.bind(x) + return sin_p.bind(x, accuracy=accuracy) @export -def cos(x: ArrayLike) -> Array: +def cos(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cosine: :math:`\mathrm{cos}(x)`. For floating-point inputs, this function lowers directly to the @@ -668,6 +733,12 @@ def cos(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. 
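# --- Editor's note (not part of the patch): a sketch of the new `accuracy`
# argument threaded through the unary ops above. It assumes Tolerance and
# AccuracyMode end up re-exported as lax.Tolerance / lax.AccuracyMode (only
# their definitions appear in this hunk) and that the backend understands
# result accuracy; illustrative only.
import jax.numpy as jnp
from jax import lax

x = jnp.linspace(0.0, 1.0, 8)

y0 = lax.exp(x)                                     # accuracy=None: previous behavior
y1 = lax.exp(x, accuracy=lax.AccuracyMode.HIGHEST)  # request the most accurate implementation
y2 = lax.exp(x, accuracy=lax.Tolerance(rtol=1e-5))  # or bound the relative error instead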
Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -680,7 +751,7 @@ def cos(x: ArrayLike) -> Array: .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine """ - return cos_p.bind(x) + return cos_p.bind(x, accuracy=accuracy) @export def atan2(x: ArrayLike, y: ArrayLike) -> Array: @@ -704,6 +775,7 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2 """ + x, y = core.standard_insert_pvary(x, y) return atan2_p.bind(x, y) @export @@ -773,6 +845,7 @@ def complex(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ + x, y = core.standard_insert_pvary(x, y) return complex_p.bind(x, y) @export @@ -844,6 +917,7 @@ def pow(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert .. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow """ + x, y = core.standard_insert_pvary(x, y) return pow_p.bind(x, y) @export @@ -867,14 +941,21 @@ def integer_pow(x: ArrayLike, y: int) -> Array: """ return integer_pow_p.bind(x, y=y) + @export -def sqrt(x: ArrayLike) -> Array: +def sqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise square root: :math:`\sqrt{x}`. This function lowers directly to the `stablehlo.sqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the square root. @@ -886,16 +967,22 @@ def sqrt(x: ArrayLike) -> Array: .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt """ - return sqrt_p.bind(x) + return sqrt_p.bind(x, accuracy=accuracy) @export -def rsqrt(x: ArrayLike) -> Array: +def rsqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`. This function lowers directly to the `stablehlo.rsqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the @@ -908,16 +995,22 @@ def rsqrt(x: ArrayLike) -> Array: .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt """ - return rsqrt_p.bind(x) + return rsqrt_p.bind(x, accuracy=accuracy) @export -def cbrt(x: ArrayLike) -> Array: +def cbrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cube root: :math:`\sqrt[3]{x}`. This function lowers directly to the `stablehlo.cbrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. 
If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the cube root. @@ -929,7 +1022,7 @@ def cbrt(x: ArrayLike) -> Array: .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt """ - return cbrt_p.bind(x) + return cbrt_p.bind(x, accuracy=accuracy) @export def bitwise_not(x: ArrayLike) -> Array: @@ -979,6 +1072,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.and: https://openxla.org/stablehlo/spec#and """ + x, y = core.standard_insert_pvary(x, y) return and_p.bind(x, y) @export @@ -1005,6 +1099,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.or: https://openxla.org/stablehlo/spec#or """ + x, y = core.standard_insert_pvary(x, y) return or_p.bind(x, y) @export @@ -1031,6 +1126,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.xor: https://openxla.org/stablehlo/spec#xor """ + x, y = core.standard_insert_pvary(x, y) return xor_p.bind(x, y) @export @@ -1095,6 +1191,7 @@ def add(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.add: https://openxla.org/stablehlo/spec#add """ + x, y = core.standard_insert_pvary(x, y) return add_p.bind(x, y) @export @@ -1118,6 +1215,7 @@ def sub(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract """ + x, y = core.standard_insert_pvary(x, y) return sub_p.bind(x, y) @export @@ -1141,6 +1239,7 @@ def mul(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ + x, y = core.standard_insert_pvary(x, y) return mul_p.bind(x, y) @export @@ -1170,6 +1269,7 @@ def div(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide """ + x, y = core.standard_insert_pvary(x, y) return div_p.bind(x, y) @export @@ -1197,6 +1297,7 @@ def rem(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder """ + x, y = core.standard_insert_pvary(x, y) return rem_p.bind(x, y) @export @@ -1222,6 +1323,7 @@ def max(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum """ + x, y = core.standard_insert_pvary(x, y) return max_p.bind(x, y) @export @@ -1247,6 +1349,7 @@ def min(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum """ + x, y = core.standard_insert_pvary(x, y) return min_p.bind(x, y) @export @@ -1272,6 +1375,7 @@ def shift_left(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left """ + x, y = core.standard_insert_pvary(x, y) return shift_left_p.bind(x, y) @export @@ -1298,6 +1402,7 @@ def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic """ + x, y = core.standard_insert_pvary(x, y) return shift_right_arithmetic_p.bind(x, y) @export @@ -1324,6 +1429,7 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical """ + x, y = core.standard_insert_pvary(x, y) return shift_right_logical_p.bind(x, y) @export @@ -1354,6 +1460,7 @@ def eq(x: ArrayLike, y: ArrayLike) -> Array: .. 
_stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return eq_p.bind(x, y) @export @@ -1384,6 +1491,7 @@ def ne(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return ne_p.bind(x, y) @export @@ -1414,6 +1522,7 @@ def ge(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return ge_p.bind(x, y) @export @@ -1444,6 +1553,7 @@ def gt(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return gt_p.bind(x, y) @export @@ -1474,6 +1584,7 @@ def le(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return le_p.bind(x, y) @export @@ -1504,6 +1615,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return lt_p.bind(x, y) @export @@ -1573,6 +1685,7 @@ def _convert_element_type( "Instead, convert to and from their representation dtypes, e.g.:\n" f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} " f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}") + if isinstance(new_dtype, dtypes.ExtendedDType): return to_edtype_p.bind(operand, edtype=new_dtype) return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) @@ -1658,6 +1771,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: x & \text{otherwise} \end{cases}`. """ + min, x, max = core.standard_insert_pvary(min, x, max) return clamp_p.bind(min, x, max) @@ -1764,6 +1878,7 @@ def _decorator(*args, **kwargs): closed_jaxpr, out_tree = _trace_composite_to_jaxpr( partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info ) + flat_args = core.standard_insert_pvary(*flat_args) out_flat = composite_p.bind( *flat_args, name=name, @@ -1838,7 +1953,7 @@ def composite_jvp(*args, **_): raise ValueError( "JVP rule for composite not implemented. You can use `jax.custom_jvp` to " "add support. See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ) @@ -1847,7 +1962,7 @@ def composite_transpose(*args, **_): raise ValueError( "Transpose rule for composite not implemented. You can use" "`jax.custom_jvp` or `jax.custom_vjp` to add support. 
See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ) @@ -1881,6 +1996,7 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: op, = operands if isinstance(op, Array): return op + operands = core.standard_insert_pvary(*operands) return concatenate_p.bind(*operands, dimension=dimension) @@ -2230,13 +2346,10 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz), + np.dtype(dtypes.float8_e3m4), + np.dtype(dtypes.float8_e4m3), + np.dtype(dtypes.float8_e8m0fnu), ] - if dtypes.float8_e3m4 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] - if dtypes.float8_e4m3 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] - if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -2267,11 +2380,6 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, case DotAlgorithmPreset.BF16_BF16_F32_X6: return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False) case DotAlgorithmPreset.BF16_BF16_F32_X9: - if xla_extension_version < 320: - raise ValueError( - "The dot algorithm BF16_BF16_F32_X9 requires XLA extension " - "version >= 320." - ) return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 9, False) case DotAlgorithmPreset.TF32_TF32_F32: return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False) @@ -2352,6 +2460,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, + *, out_sharding=None) -> Array: """General dot product/contraction operator. @@ -2411,6 +2520,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN preferred_element_type = ( None if preferred_element_type is None else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + lhs, rhs = core.standard_insert_pvary(lhs, rhs) return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), @@ -2546,6 +2656,7 @@ def ragged_dot_general( extra leading dimension of size `g` in the case where the lhs ragged dimension is a contracting dimension. 
""" + lhs, rhs, group_sizes = core.standard_insert_pvary(lhs, rhs, group_sizes) return ragged_dot_general_p.bind( lhs, rhs, @@ -2557,7 +2668,7 @@ def ragged_dot_general( ) -def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None +def broadcast(operand: ArrayLike, sizes: Sequence[int], *, out_sharding=None ) -> Array: """Broadcasts an array, adding new leading dimensions @@ -2579,7 +2690,7 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None out_sharding=out_sharding) def broadcast_in_dim(operand: ArrayLike, shape: Shape, - broadcast_dimensions: Sequence[int], out_sharding=None + broadcast_dimensions: Sequence[int], *, out_sharding=None ) -> Array: """Wraps XLA's `BroadcastInDim `_ @@ -2622,7 +2733,7 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: def reshape(operand: ArrayLike, new_sizes: Shape, dimensions: Sequence[int] | None = None, - out_sharding: NamedSharding | P | None = None) -> Array: + *, out_sharding: NamedSharding | P | None = None) -> Array: """Wraps XLA's `Reshape `_ operator. @@ -2729,6 +2840,7 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32) """ + operand, padding_value = core.standard_insert_pvary(operand, padding_value) return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: @@ -2761,6 +2873,8 @@ def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: """ # Caution! The select_n_p primitive has the *opposite* order of arguments to # select(). This is because it implements `select_n`. + pred, on_false, on_true = core.standard_insert_pvary( + pred, on_false, on_true) return select_n_p.bind(pred, on_false, on_true) def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: @@ -2786,6 +2900,7 @@ def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: """ if len(cases) == 0: raise ValueError("select_n() must have at least one case") + which, *cases = core.standard_insert_pvary(which, *cases) return select_n_p.bind(which, *cases) @@ -2799,6 +2914,7 @@ def transpose(operand: ArrayLike, if permutation == tuple(range(np.ndim(operand))) and isinstance(operand, Array): return operand else: + return transpose_p.bind(operand, permutation=permutation) def argmin(operand: ArrayLike, axis: int, @@ -3146,6 +3262,7 @@ def sort(operand: Array | Sequence[Array], dimension: int = -1, if not (1 <= num_keys <= len(operand)): raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}") dimension = canonicalize_axis(dimension, len(operand[0].shape)) + operand = core.standard_insert_pvary(*operand) return tuple(sort_p.bind(*operand, dimension=dimension, is_stable=is_stable, num_keys=num_keys)) @@ -3241,7 +3358,9 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: scalar_zero = np.zeros((), dtype=aval.dtype) else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) - return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) + out = broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) + out = core.pvary(out, tuple(aval.vma)) + return out ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array @@ -3262,7 +3381,7 @@ def iota(dtype: DTypeLike, size: int) -> Array: return broadcasted_iota(dtype, (size,), 0) def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int, - out_sharding=None) -> Array: + *, out_sharding=None) -> Array: """Convenience wrapper around ``iota``.""" dtype = 
dtypes.canonicalize_dtype(dtype) shape = canonicalize_shape(shape) @@ -3378,7 +3497,8 @@ def reduce_precision(operand: float | ArrayLike, operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision") mantissa_bits = core.concrete_or_error( operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision") - return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits) + return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, + mantissa_bits=mantissa_bits) def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array: """Squeeze any number of size 1 dimensions from an array.""" @@ -3513,13 +3633,19 @@ def reciprocal(x: ArrayLike) -> Array: return integer_pow(x, -1) @export -def tan(x: ArrayLike) -> Array: +def tan(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise tangent: :math:`\mathrm{tan}(x)`. This function lowers directly to the `stablehlo.tangent`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -3533,7 +3659,7 @@ def tan(x: ArrayLike) -> Array: .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent """ - return tan_p.bind(x) + return tan_p.bind(x, accuracy=accuracy) @export def asin(x: ArrayLike) -> Array: @@ -3762,7 +3888,8 @@ def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs): def unop(result_dtype, accepted_dtypes, name): dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name) prim = standard_primitive(_attrgetter('shape'), dtype_rule, name, - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), + vma_rule=_attrgetter('vma')) batching.defvectorized(prim) pe.def_trivial_padding(prim) return prim @@ -3811,7 +3938,7 @@ def broadcasting_sharding_rule(name, *avals): for a in avals: if a.sharding is not None and not a.sharding.mesh.empty: if mesh is not None and mesh != a.sharding.mesh: - raise ValueError( + raise core.ShardingTypeError( f'Mesh for all inputs should be equal. 
Got one mesh: {mesh} and' f' another mesh: {a.sharding.mesh}') mesh = a.sharding.mesh @@ -3845,12 +3972,11 @@ def broadcasting_sharding_rule(name, *avals): result_specs[i] = s elif (result_specs[i] is not None and s is not None and result_specs[i] != s): - raise TypeError( + raise core.ShardingTypeError( f'{name} got incompatible shardings for broadcasting: ' f'{", ".join(map(str, map(tuple, specs)))}.') return NamedSharding(mesh, P(*result_specs)) - def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same_dtypes=True): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, @@ -3858,8 +3984,9 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same=require_same_dtypes) shape_rule = partial(broadcasting_shape_rule, name) sharding_rule = partial(broadcasting_sharding_rule, name) - prim = standard_primitive(shape_rule, dtype_rule, name, - sharding_rule=sharding_rule) + prim = standard_primitive( + shape_rule, dtype_rule, name, sharding_rule=sharding_rule, + vma_rule=partial(core.standard_vma_rule, name)) batching.defbroadcasting(prim) pe.def_trivial_padding(prim) return prim @@ -3926,8 +4053,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): return out -def _nary_lower_hlo(op: Callable, ctx, - *args: ir.Value, **params) -> Sequence[ir.Value]: +def _nary_lower_hlo( + op: Callable, ctx, *args: ir.Value, accuracy=None, **params +) -> Sequence[ir.Value]: """Lowers an elementwise operator to its MLIR equivalent. """ del params @@ -3936,8 +4064,15 @@ def _nary_lower_hlo(op: Callable, ctx, args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) out = op(*args) + if accuracy: + out = op(*args, result_accuracy=accuracy_attr(accuracy)) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] +def _unary_with_accuracy_pp_rule(eqn, context, settings): + params = dict(eqn.params) + if 'accuracy' in params and params['accuracy'] is None: + del params['accuracy'] + return core._pp_eqn(eqn.replace(params=params), context, settings) _float = {np.floating} _complex = {np.complexfloating} @@ -3997,48 +4132,68 @@ def _round_lower(ctx, x, *, rounding_method): mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite)) exp_p = standard_unop(_float | _complex, 'exp') -ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans)) +ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule +core.pp_eqn_rules[exp_p] = _unary_with_accuracy_pp_rule exp2_p = standard_unop(_float | _complex, 'exp2') -ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans))) -def _exp2_lower(ctx, x): +ad.defjvp2( + exp2_p, lambda g, ans, x, **kwargs: mul(log(_const(x, 2)), mul(g, ans)) +) + +def _exp2_lower(ctx, x, accuracy): x_aval, = ctx.avals_in log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype)) log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=()) - return [hlo.exponential(hlo.multiply(log2, x))] + return [ + hlo.exponential( + hlo.multiply(log2, x), result_accuracy=accuracy_attr(accuracy) + ) + ] + mlir.register_lowering(exp2_p, _exp2_lower) +core.pp_eqn_rules[exp2_p] = _unary_with_accuracy_pp_rule log_p = standard_unop(_float | _complex, 'log') -ad.defjvp(log_p, lambda g, x: div(g, x)) +ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x)) mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log)) +core.pp_eqn_rules[log_p] = 
_unary_with_accuracy_pp_rule expm1_p = standard_unop(_float | _complex, 'expm1') -ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans)))) +ad.defjvp2(expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans)))) mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.exponential_minus_one)) +core.pp_eqn_rules[expm1_p] = _unary_with_accuracy_pp_rule log1p_p = standard_unop(_float | _complex, 'log1p') -ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x)))) +ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x)))) mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one)) +core.pp_eqn_rules[log1p_p] = _unary_with_accuracy_pp_rule tanh_p = standard_unop(_float | _complex, 'tanh') -ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)), - sub(_one(x), ans))) +ad.defjvp2( + tanh_p, + lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)), +) mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh)) +core.pp_eqn_rules[tanh_p] = _unary_with_accuracy_pp_rule logistic_p = standard_unop(_float | _complex, 'logistic') -ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans)))) +ad.defjvp2( + logistic_p, + lambda g, ans, x, **kwargs: mul(g, mul(ans, sub(_one(ans), ans))), +) # TODO(phawkins): switch to LogisticOp lowering; debug numerical problems. # mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic)) -def logistic_impl(x): +def logistic_impl(x, accuracy): one = _const(x, 1) return div(one, add(one, exp(neg(x)))) mlir.register_lowering(logistic_p, mlir.lower_fun(logistic_impl, multiple_results=False)) +core.pp_eqn_rules[logistic_p] = _unary_with_accuracy_pp_rule def _sin_complex(x): # use expm1 instead of exp to avoid cancellation when abs(x) is small @@ -4056,21 +4211,28 @@ def _sin_complex(x): # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf return select(a_is_zero, complex(_const(a, 0), im), complex(re, im)) -def _sin_lowering(ctx, x): +def _sin_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): sine = mlir.lower_fun(_sin_complex, multiple_results=False) return sine(ctx, x) - return _nary_lower_hlo(hlo.sine, ctx, x) + return _nary_lower_hlo(hlo.sine, ctx, x, accuracy=accuracy) -def _sin_lin(nzs, x): + +def _sin_p_lin(nzs, x, accuracy): nz, = nzs cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) - return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + return ( + sin_p.bind(x, accuracy=accuracy), + nz, + cos_x, + lambda cos_x_, t: mul(t, cos_x_), + ) sin_p = standard_unop(_float | _complex, 'sin') -ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -ad.primitive_linearizations[sin_p] = _sin_lin +ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy))) +ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) +core.pp_eqn_rules[sin_p] = _unary_with_accuracy_pp_rule batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule def _cos_complex(x): @@ -4085,19 +4247,23 @@ def _cos_complex(x): re, im = mul(cs, csh), mul(neg(sn), snh) return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im)) -def _cos_lowering(ctx, x): +def _cos_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): cosine = mlir.lower_fun(_cos_complex, multiple_results=False) return cosine(ctx, x) - return _nary_lower_hlo(hlo.cosine, ctx, x) + return 
_nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) cos_p = standard_unop(_float | _complex, 'cos') -ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) +ad.defjvp( + cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy))) +) mlir.register_lowering(cos_p, _cos_lowering) +core.pp_eqn_rules[cos_p] = _unary_with_accuracy_pp_rule tan_p = standard_unop(_float | _complex, 'tan') -ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans)))) +ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) +core.pp_eqn_rules[tan_p] = _unary_with_accuracy_pp_rule asin_p = standard_unop(_float | _complex, 'asin') ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x))))) @@ -4213,19 +4379,27 @@ def _abs_jvp_rule(g, ans, x): _maybe_real = lambda x: real(x) if _iscomplex(x) else x sqrt_p = standard_unop(_float | _complex, 'sqrt') -ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans))) +ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans))) mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt)) +core.pp_eqn_rules[sqrt_p] = _unary_with_accuracy_pp_rule rsqrt_p = standard_unop(_float | _complex, 'rsqrt') -ad.defjvp2(rsqrt_p, - lambda g, ans, x: - mul(g, mul(_const(x, -0.5), div(ans, x)))) +ad.defjvp2( + rsqrt_p, + lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))), +) mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt)) +core.pp_eqn_rules[rsqrt_p] = _unary_with_accuracy_pp_rule cbrt_p = standard_unop(_float, 'cbrt') -ad.defjvp2(cbrt_p, - lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) +ad.defjvp2( + cbrt_p, + lambda g, ans, x, **kwargs: mul( + g, mul(_const(x, 1 / 3), integer_pow(ans, -2)) + ), +) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) +core.pp_eqn_rules[cbrt_p] = _unary_with_accuracy_pp_rule square_p = standard_unop(_int | _float | _complex, 'square') @@ -4307,7 +4481,7 @@ def _integer_pow_jvp(g, x, *, y): integer_pow_p = standard_primitive( _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow', - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), vma_rule=_attrgetter('vma')) batching.defvectorized(integer_pow_p) ad.defjvp(integer_pow_p, _integer_pow_jvp) pe.def_trivial_padding(integer_pow_p) @@ -4710,7 +4884,8 @@ def _convert_element_type_bind_with_trace(trace, args, params): partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, _convert_element_type_weak_type_rule, - _convert_element_type_sharding_rule)) + _convert_element_type_sharding_rule, + partial(core.standard_vma_rule, convert_element_type_p.name))) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule batching.defvectorized(convert_element_type_p) @@ -4875,7 +5050,8 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): bitcast_convert_type_p = standard_primitive( _bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule, 'bitcast_convert_type', weak_type_rule=_strip_weak_type, - sharding_rule=_bitcast_convert_type_sharding_rule) + sharding_rule=_bitcast_convert_type_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'bitcast_convert_type')) ad.defjvp_zero(bitcast_convert_type_p) batching.defvectorized(bitcast_convert_type_p) @@ -4996,13 +5172,13 @@ 
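Editor's illustrative aside (not part of the patch): a minimal sketch of how the `accuracy` argument threaded through the unary primitives above could be used from user code. It assumes `Tolerance` is exposed as `jax.lax.Tolerance` with `atol`/`rtol`/`ulps` fields, as suggested by the `lax.tan` docstring and `accuracy_attr`; the exact public constructor may differ.

import jax.numpy as jnp
from jax import lax

x = jnp.linspace(0.1, 1.0, 8, dtype=jnp.float32)

# Default implementation, no accuracy constraint.
y_default = lax.tan(x)

# Request an explicit tolerance; per the docstring above, compilation fails
# if no implementation can satisfy it.
y_tol = lax.tan(x, accuracy=lax.Tolerance(atol=1e-6, rtol=1e-6, ulps=1))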
def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): def _check_specs_match(lhs_spec, rhs_spec, msg): for l, r in zip(lhs_spec, rhs_spec): if l is not None and r is not None and l != r: - raise TypeError(msg) + raise core.ShardingTypeError(msg) def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, out_sharding): if lhs.sharding.mesh != rhs.sharding.mesh: - raise ValueError( + raise core.ShardingTypeError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') @@ -5026,7 +5202,7 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, for l, r in zip(lhs_contracting_spec, rhs_contracting_spec): if l is not None and r is not None: - raise ValueError( + raise core.ShardingTypeError( 'Contracting dimensions are sharded and it is ambiguous how the' ' output should be sharded. Please specify the output sharding via' ' the `out_sharding` parameter of einsum. Or reshard your input via' @@ -5344,6 +5520,7 @@ def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): _dot_general_dtype_rule, 'dot_general', sharding_rule=_dot_general_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dot_general') ) @@ -5371,15 +5548,26 @@ def _dot_general_batch_unpack_dims(batch_dims): core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule -def precision_attr(precision: Precision) -> ir.ArrayAttr: + +def _full_precision(precision: Precision) -> tuple[Precision, Precision]: if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): - full_precision = (Precision.DEFAULT, Precision.DEFAULT) + return (Precision.DEFAULT, Precision.DEFAULT) elif not isinstance(precision, tuple): - full_precision = (precision, precision) + return (precision, precision) else: - full_precision = precision + return precision + + +def precision_attr(precision: Precision) -> ir.ArrayAttr: + return ir.ArrayAttr.get( + [hlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) + + +def chlo_precision_attr(precision: Precision) -> ir.ArrayAttr: return ir.ArrayAttr.get( - [hlo.PrecisionAttr.get(str(p)) for p in full_precision]) + [chlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, @@ -5417,32 +5605,30 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return lhs_dtype, rhs_dtype, out_type -def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, - precision, preferred_element_type: np.dtype | None, - out_sharding, platform: str = "default"): +def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr: + if isinstance(accuracy, AccuracyMode): + return hlo.ResultAccuracyAttr.get(0.0, 0.0, int(0), str(accuracy.name)) + elif isinstance(accuracy, Tolerance): + return hlo.ResultAccuracyAttr.get( + atol=accuracy.atol, + rtol=accuracy.rtol, + ulps=accuracy.ulps, + mode='TOLERANCE', + ) + +def _handle_dot_precision(ctx, lhs, rhs, precision, platform): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, - dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) - if dtypes.float8_e3m4 is not None: - fp8_dtypes += (dtypes.float8_e3m4,) - if dtypes.float8_e4m3 is not None: - fp8_dtypes += (dtypes.float8_e4m3,) - if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += (dtypes.float8_e8m0fnu,) + dtypes.float8_e5m2fnuz, 
dtypes.float8_e4m3fnuz, + dtypes.float8_e3m4, dtypes.float8_e4m3, + dtypes.float8_e8m0fnu) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes - del preferred_element_type # Implied by the output aval - lhs_aval, rhs_aval = ctx.avals_in + + # The *_ lets us reuse this for ragged_dot_general, which has group_sizes. + lhs_aval, rhs_aval, *_ = ctx.avals_in lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype aval_out, = ctx.avals_out accumulation_aval = aval_out - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers - - dot_dnums = hlo.DotDimensionNumbers.get( - lhs_batching_dimensions=list(lhs_batch), - rhs_batching_dimensions=list(rhs_batch), - lhs_contracting_dimensions=list(lhs_contracting), - rhs_contracting_dimensions=list(rhs_contracting)) - algorithm_kwarg = {} if isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): # The CPU backend silently ignores the algorithm spec, so we check here to @@ -5500,7 +5686,22 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): core.ShapedArray(lhs_aval.shape, aval_out.dtype)) rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) + return lhs, rhs, accumulation_aval, algorithm_kwarg + +def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, + precision, preferred_element_type: np.dtype | None, + out_sharding, platform: str = "default"): + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, algorithm_kwarg = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + dot_dnums = hlo.DotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting)) result = hlo.dot_general( mlir.aval_to_ir_type(accumulation_aval), lhs, @@ -5509,7 +5710,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): precision_config=precision_attr(precision), **algorithm_kwarg, ) - + aval_out, = ctx.avals_out result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) @@ -5918,6 +6119,7 @@ def _ragged_dot_general_batch_rule( _ragged_dot_general_shape_rule, _ragged_dot_general_dtype_rule, 'ragged_dot_general', + vma_rule=partial(core.standard_vma_rule, 'ragged_dot') ) ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule @@ -6028,10 +6230,85 @@ def expand(x, dim, gs, *axes): ) +def _ragged_dot_general_lower( + ctx, + lhs, + rhs, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: np.dtype | None, + group_offset: Array | None = None, + platform: str = 'default', +): + if group_offset is not None: + raise NotImplementedError('Unimplemented group_offset support.') + + # TODO(pravnar): Remove this once we have sharding support. 
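Editor's illustrative aside (not part of the patch): the hunk above adds a direct `chlo.ragged_dot` lowering for the ragged dot primitive. A small usage sketch of the operation itself, assuming the public `jax.lax.ragged_dot` wrapper with the usual `(m, k) x (g, k, n) -> (m, n)` convention; the wrapper name and shapes here are assumptions, not taken from this diff.

import jax.numpy as jnp
from jax import lax

m, k, n, g = 8, 4, 3, 2
lhs = jnp.arange(m * k, dtype=jnp.float32).reshape(m, k)
rhs = jnp.ones((g, k, n), dtype=jnp.float32)
# Rows 0..4 of lhs contract against rhs[0], rows 5..7 against rhs[1];
# the group sizes must sum to m.
group_sizes = jnp.array([5, 3], dtype=jnp.int32)

out = lax.ragged_dot(lhs, rhs, group_sizes)  # shape (m, n)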
+ def use_default_lowering(): + axis_context = ctx.module_context.axis_context + return ( + isinstance(axis_context, SPMDAxisContext) + or isinstance(axis_context, ShardingContext) + and axis_context.num_devices > 1 + ) + if use_default_lowering(): + result = mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)( + ctx, lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset + ) + (aval_out,) = ctx.avals_out + return mlir.lower_with_sharding_in_types(ctx, result, aval_out) + + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, _ = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + ragged_dot_dnums = chlo.RaggedDotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting), + lhs_ragged_dimensions=list( + ragged_dot_dimension_numbers.lhs_ragged_dimensions + ), + rhs_group_dimensions=list( + ragged_dot_dimension_numbers.rhs_group_dimensions + ), + ) + result = chlo.ragged_dot( + mlir.aval_to_ir_type(accumulation_aval), + lhs, + rhs, + group_sizes, + ragged_dot_dnums, + precision_config=chlo_precision_attr(precision), + ) + (aval_out,) = ctx.avals_out + result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) + if accumulation_aval.dtype != aval_out.dtype: + result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) + return [result] + + mlir.register_lowering(ragged_dot_general_p, mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)) +for platform in ['tpu']: + mlir.register_lowering( + ragged_dot_general_p, + partial(_ragged_dot_general_lower, platform=platform), + platform=platform, + ) + def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, sharding): @@ -6252,7 +6529,9 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, new_sharding = _broadcast_in_dim_sharding_rule( x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) - return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding) + new_vma = core.standard_vma_rule('broadcast_in_dim', x) + return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, + vma=new_vma) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray # (even if x is a ShapedArray) # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code @@ -6336,7 +6615,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): return clamp_p.bind(min, x, max), 0 clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp', - sharding_rule=_clamp_sharding_rule) + sharding_rule=_clamp_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'clamp')) ad.defjvp(clamp_p, lambda g, min, operand, max: select(bitwise_and(gt(min, operand), lt(min, max)), @@ -6384,7 +6664,7 @@ def _concatenate_sharding_rule(*operands, **kwargs): return core.get_cur_mesh_sharding() if not all(s == non_empty_s[0] for s in non_empty_s): ss = ", ".join(str(o.sharding) for o in operands) - raise TypeError( + raise core.ShardingTypeError( f"All operands should have the same sharding. 
Got shardings {ss}") return non_empty_s[0] @@ -6423,7 +6703,8 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): concatenate_p = standard_primitive( _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', - sharding_rule=_concatenate_sharding_rule) + sharding_rule=_concatenate_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'concatenate')) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule @@ -6495,11 +6776,17 @@ def _split_sharding_rule(operand, *, sizes, axis): return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split') for out_sh in out_shapes] +def _split_vma_rule(operand, *, sizes, axis): + out_vma = core.standard_vma_rule('split', operand) + out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis) + return [out_vma] * len(out_shapes) + split_p = core.Primitive('split') split_p.multiple_results = True split_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, - _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule)) + _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule, + _split_vma_rule)) split_p.def_impl(partial(dispatch.apply_primitive, split_p)) ad.deflinear2(split_p, _split_transpose_rule) batching.primitive_batchers[split_p] = _split_batch_rule @@ -6581,7 +6868,8 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): return select(mask, x, broadcasted_padding), operand_bdim pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', - sharding_rule=_pad_sharding_rule) + sharding_rule=_pad_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'pad')) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule @@ -6645,7 +6933,8 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, - 'squeeze', sharding_rule=_squeeze_sharding_rule) + 'squeeze', sharding_rule=_squeeze_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'squeeze')) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) @@ -6703,16 +6992,21 @@ def _split_on_one_axis(op_shape, new_sizes, name): else: count += 1 if count > 1: - raise ValueError( + raise core.ShardingTypeError( f'{name} on more than 1 axis is not supported. Please specify' ' the sharding of the output via the `sharding` argument of' f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}') temp = [new_sizes[j]] - while math.prod(temp) != op_shape[i]: + next_j = j + 1 + while (math.prod(temp) != op_shape[i] or + (next_j < len(new_sizes) and new_sizes[next_j] == 1)): if math.prod(temp) > op_shape[i]: return False, [] j += 1 + if j >= len(new_sizes): + return False, [] temp.append(new_sizes[j]) + next_j += 1 out.append(temp) i += 1 j += 1 @@ -6744,7 +7038,7 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions) - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of' ' the output via the `out_sharding` argument of jax.lax.reshape. 
Got' f' operand shape: {operand.shape}, new sizes: {new_sizes} and' @@ -6777,7 +7071,7 @@ def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions): elif dimensions is None and out[0] % _get_spec_size(sp, mesh) == 0: new_spec.extend([sp] + [None] * (len(out) - 1)) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. Got' f' operand shape: {operand.shape}, new sizes: {new_sizes} and' @@ -6802,7 +7096,7 @@ def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): assert new_size % _get_spec_size(sp[0], mesh) == 0 new_spec.append(sp[0]) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. Got' f' operand shape: {operand.shape}, new sizes: {new_sizes} and' @@ -6879,7 +7173,8 @@ def _reshape_staging_rule( return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params) reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, - 'reshape', sharding_rule=_reshape_sharding_rule) + 'reshape', sharding_rule=_reshape_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reshape')) ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule batching.skippable_batchers[reshape_p] = lambda _: () @@ -6911,7 +7206,8 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions): return rev(operand, new_dimensions), bdim rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev', - sharding_rule=_rev_sharding_rule) + sharding_rule=_rev_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'rev')) ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)]) batching.primitive_batchers[rev_p] = _rev_batch_rule @@ -6959,7 +7255,8 @@ def _transpose_lower(ctx, x, *, permutation): transpose_p = standard_primitive( _transpose_shape_rule, _input_dtype, 'transpose', - sharding_rule=_transpose_sharding_rule) + sharding_rule=_transpose_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'transpose')) ad.deflinear2(transpose_p, lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule @@ -6985,10 +7282,11 @@ def _select_sharding_rule(which, *cases): return core.get_cur_mesh_sharding() if any(s != non_empty_s[0] for s in non_empty_s[1:]): msg = "select cases must have the same shardings, got [{}]." 
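Editor's illustrative aside (not part of the patch): the sharding rule being tightened here belongs to `lax.select_n`, which picks elementwise among several cases using an integer (or, for two cases, boolean) selector. A minimal sketch of the user-facing call:

import jax.numpy as jnp
from jax import lax

which = jnp.array([0, 2, 1, 0])   # per-element case index
case0 = jnp.zeros(4)
case1 = jnp.ones(4)
case2 = jnp.full((4,), 2.0)

out = lax.select_n(which, case0, case1, case2)  # -> [0., 2., 1., 0.]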
- raise TypeError(msg.format(", ".join([str(c.sharding) for c in cases]))) + raise core.ShardingTypeError( + msg.format(", ".join([str(c.sharding) for c in cases]))) if (which.shape and not which.sharding.mesh.empty and which.sharding != non_empty_s[0]): - raise TypeError( + raise core.ShardingTypeError( 'select `which` must be scalar or have the same sharding as cases, got' f' `which` sharding {which.sharding} but case sharding' f' {cases[0].sharding}.') @@ -7079,7 +7377,11 @@ def _select_jvp(primals, tangents): def _select_hlo_lowering_opaque(ctx, which, *cases): avals_in = ctx.avals_in aval_out, = ctx.avals_out - assert all(aval_case == aval_out for aval_case in avals_in[1:]) + assert all((aval_case.shape, aval_case.dtype) == (aval_out.shape, aval_out.dtype) + for aval_case in avals_in[1:]) + assert all( + aval_case == aval_out for aval_case in avals_in[1:] + if not aval_case.sharding.mesh.empty and not aval_out.sharding.mesh.empty) select_lower = _select_hlo_lowering physical_aval_out = core.physical_aval(aval_out) @@ -7134,7 +7436,8 @@ def _select(offset, cases): select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', - weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule) + weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_n')) ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule batching.primitive_batchers[select_n_p] = _select_batch_rule @@ -7154,6 +7457,11 @@ def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions): return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions)) for op in operand_avals] +def _reduce_vma_rule(*avals, computation, jaxpr, dimensions): + operand_avals, _ = split_list(avals, [len(avals) // 2]) + out_vma = core.standard_vma_rule('reduce', *operand_avals) + return [out_vma] * len(operand_avals) + def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions): operand_avals, init_val_avals = split_list(avals, [len(avals) // 2]) operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals] @@ -7240,7 +7548,8 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions): reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p)) reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, - _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule)) + _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule, + _reduce_vma_rule)) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule @@ -7314,7 +7623,8 @@ def _reduce_op_sharding_rule(operand, *, axes): reduce_sum_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'), - 'reduce_sum', sharding_rule=_reduce_op_sharding_rule) + 'reduce_sum', sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_sum')) ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum, @@ -7329,7 +7639,8 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes): reduce_prod_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'), - 'reduce_prod', sharding_rule=_reduce_op_sharding_rule) + 'reduce_prod', sharding_rule=_reduce_op_sharding_rule, + 
vma_rule=partial(core.standard_vma_rule, 'reduce_prod')) ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p, _get_prod_identity) pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod, @@ -7349,7 +7660,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_max_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_max', - sharding_rule=_reduce_op_sharding_rule) + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_max')) ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max, @@ -7359,7 +7671,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_min_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_min', - sharding_rule=_reduce_op_sharding_rule) + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_min')) ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p, _get_min_identity) pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min, @@ -7426,13 +7739,15 @@ def _compute_argminmax(value_comparator, get_identity, argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmin', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'argmin')) batching.defreducer(argmin_p, _get_min_identity) ad.defjvp_zero(argmin_p) argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmax', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'argmax')) batching.defreducer(argmax_p, _get_max_identity) ad.defjvp_zero(argmax_p) @@ -7455,20 +7770,23 @@ def _reduce_logical_sharding_rule(operand, *, axes): reduce_or_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_or', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_or')) batching.defreducer(reduce_or_p, _get_bitwise_or_identity) reduce_and_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_and', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_and')) batching.defreducer(reduce_and_p, _get_bitwise_and_identity) batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule reduce_xor_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_xor', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_xor')) batching.defreducer(reduce_xor_p, _get_bitwise_or_identity) @@ -7515,7 +7833,8 @@ def _reduce_precision_sharding_rule(operand, *, exponent_bits, mantissa_bits): reduce_precision_p = standard_primitive( _reduce_precision_shape_rule, partial(unop_dtype_rule, _identity, _float, 'reduce_precision'), - name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule) + name='reduce_precision', 
sharding_rule=_reduce_precision_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_precision')) ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)]) batching.defvectorized(reduce_precision_p) @@ -7792,6 +8111,7 @@ def after_all(*operands): """Merges one or more XLA token values. Experimental. Wraps the XLA AfterAll operator.""" + operands = core.standard_insert_pvary(*operands) return after_all_p.bind(*operands) def _after_all_abstract_eval(*operands): @@ -7926,6 +8246,7 @@ def rng_uniform(a, b, shape): This API may be removed at any time. """ + a, b = core.standard_insert_pvary(a, b) return rng_uniform_p.bind(a, b, shape=tuple(shape)) def _rng_uniform_abstract_eval(a, b, *, shape): @@ -7952,15 +8273,24 @@ def _rng_uniform_lowering(ctx, a, b, *, shape): mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering) -def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm, out_sharding): del dtype, algorithm return (key.shape, tuple(shape)) -def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_sharding_rule(key, *, shape, dtype, algorithm, + out_sharding): + return (key.sharding, out_sharding) + +def _rng_bit_generator_vma_rule(key, *, shape, dtype, algorithm, out_sharding): + assert key.vma == frozenset() + return (key.vma, frozenset()) + +def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm, out_sharding): del shape, algorithm return (key.dtype, dtype) -def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm, + out_sharding): del shape, dtype, algorithm return (key.weak_type, False) @@ -7991,7 +8321,7 @@ def _rng_algorithm(algorithm: RandomAlgorithm): assert False def _rng_bit_generator_lowering( - ctx, key, *, shape, dtype, algorithm): + ctx, key, *, shape, dtype, algorithm, out_sharding): key_type = ir.RankedTensorType(key.type) key_shape, key_etype = key_type.shape, key_type.element_type # While the RngBitGenerator HLO accepts a u64[2] key on all backends, we @@ -8020,7 +8350,7 @@ def _rng_bit_generator_lowering( ir.RankedTensorType.get([2], u64_type), hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key)) algorithm_attr = _rng_algorithm(algorithm) - _, out_vals_aval = ctx.avals_out + out_key_aval, out_vals_aval = ctx.avals_out if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out): output_shape = mlir.shape_tensor( mlir.eval_dynamic_shape(ctx, out_vals_aval.shape)) @@ -8044,7 +8374,8 @@ def _rng_bit_generator_lowering( out_vals = hlo.convert( ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype), out_vals) - return [out_key, out_vals] + return [mlir.lower_with_sharding_in_types(ctx, out_key, out_key_aval), + mlir.lower_with_sharding_in_types(ctx, out_vals, out_vals_aval)] rng_bit_generator_p = Primitive("rng_bit_generator") @@ -8054,7 +8385,8 @@ def _rng_bit_generator_lowering( rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule, None)) + _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule, + _rng_bit_generator_vma_rule)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) @@ -8118,7 +8450,8 @@ def _propagate_mem_kind_copy(in_mem_kind): pxla.memory_kind_propagate_rule[copy_p] = 
_propagate_mem_kind_copy def rng_bit_generator(key, shape, dtype=np.uint32, - algorithm=RandomAlgorithm.RNG_DEFAULT): + algorithm=RandomAlgorithm.RNG_DEFAULT, + *, out_sharding=None): """Stateless PRNG bit generator. Experimental and its use is discouraged. Returns uniformly distributed random bits with the specified shape and dtype @@ -8134,12 +8467,14 @@ def rng_bit_generator(key, shape, dtype=np.uint32, """ shape = core.canonicalize_shape(shape) dtype = dtypes.canonicalize_dtype(dtype) + out_sharding = canonicalize_sharding(out_sharding, 'rng_bit_generator') if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64')}: raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}') return tuple( rng_bit_generator_p.bind( - key, shape=shape, dtype=dtype, algorithm=algorithm)) + key, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding)) def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): @@ -8594,11 +8929,13 @@ def optimization_barrier(operand, /): Array(0., dtype=float32, weak_type=True) """ flat_args, treedef = tree_util.tree_flatten(operand) - return tree_util.tree_unflatten( - treedef, optimization_barrier_p.bind(*flat_args)) + flat_args = core.standard_insert_pvary(*flat_args) + out = optimization_barrier_p.bind(*flat_args) + return tree_util.tree_unflatten(treedef, out) def _optimization_barrier_abstract_eval(*args): + core.standard_vma_rule('optimization_barrier', *args) return args def _optimization_barrier_lowering_rule(ctx, *args): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index c674401fb80d..a49936373835 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -121,6 +121,7 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: A new upper-triangular matrix :math:`R` defining the Cholesky decomposition of :math:`A + w \, w^T`. """ + r_matrix, w_vector = core.standard_insert_pvary(r_matrix, w_vector) return cholesky_update_p.bind(r_matrix, w_vector) @@ -268,6 +269,7 @@ def householder_product(a: ArrayLike, taus: ArrayLike) -> Array: A batch of orthogonal (unitary) matrices with the same shape as ``a``, containing the products of the elementary Householder reflectors. """ + a, taus = core.standard_insert_pvary(a, taus) return householder_product_p.bind(a, taus) @@ -526,7 +528,7 @@ def symmetric_product( Computes the symmetric product - ..math:: + .. math:: \alpha \, A \, A^T + \beta \, C where :math:`A` is a rectangular matrix and :math:`C` is a symmetric matrix. @@ -545,6 +547,7 @@ def symmetric_product( ``symmetrize_output`` is ``True``, the upper triangle is filled with the transpose of the lower triangle, and the whole matrix is valid. """ + a_matrix, c_matrix = core.standard_insert_pvary(a_matrix, c_matrix) result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta) if symmetrize_output: upper_half = lax.transpose( @@ -602,6 +605,7 @@ def triangular_solve( singleton = np.ndim(b) == np.ndim(a) - 1 if singleton: b = lax.expand_dims(b, (-1 if left_side else -2,)) + a, b = core.standard_insert_pvary(a, b) out = triangular_solve_p.bind( a, b, left_side=left_side, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) @@ -661,6 +665,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: Returns: Solution ``X`` of tridiagonal system. 
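Editor's illustrative aside (not part of the patch): a small usage sketch for `tridiagonal_solve`, whose docstring and GPU lowering are touched in this file. It assumes the customary convention that the leading entry of `dl` and the trailing entry of `du` are padding; check the full docstring for the exact contract.

import jax.numpy as jnp
from jax import lax

# Solve A x = b for a 4x4 tridiagonal A with 2 on the diagonal and -1 off it.
m = 4
dl = jnp.array([0.0, -1.0, -1.0, -1.0])   # sub-diagonal; leading entry is padding
d  = jnp.full((m,), 2.0)                  # main diagonal
du = jnp.array([-1.0, -1.0, -1.0, 0.0])   # super-diagonal; trailing entry is padding
b  = jnp.ones((m, 1))                     # right-hand side, shape (m, n)

x = lax.linalg.tridiagonal_solve(dl, d, du, b)  # solution, shape (m, 1)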
""" + dl, d, du, b = core.standard_insert_pvary(dl, d, du, b) return tridiagonal_solve_p.bind(dl, d, du, b) @@ -717,14 +722,14 @@ def linalg_sharding_rule( spec = aval.sharding.spec batch_spec, rest_spec = spec[:len(spec) - rank], spec[len(spec) - rank:] if not all(s is None for s in rest_spec): - raise ValueError( + raise core.ShardingTypeError( f"Input {i} to {name} must be unsharded on non-batch dimensions, " f"but got {spec}." ) batch_specs.append(batch_spec) batch_spec = batch_specs[0] if any(b != batch_spec for b in batch_specs[1:]): - raise ValueError( + raise core.ShardingTypeError( f"All inputs to {name} must have the same batch sharding, but got " f"{batch_specs}." ) @@ -740,6 +745,14 @@ def linalg_sharding_rule( ndim = len(output_shapes) - len(batch_spec) return sharding.with_spec(P(*(tuple(batch_spec) + (None,) * ndim))) +def linalg_vma_rule(multiple_results, shape_rule, name, *avals, **kwargs): + output_shapes = shape_rule(*avals, **kwargs) + out_vma = core.standard_vma_rule(name, *avals) + if multiple_results: + return [out_vma] * len(output_shapes) + else: + return out_vma + def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, multiple_results=False, supports_batching=True, require_same=True): @@ -754,6 +767,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, linalg_sharding_rule, multiple_results, shape_rule, ranks, name) else: sharding_rule = None + vma_rule = partial(linalg_vma_rule, multiple_results, shape_rule, name) prim = core.Primitive(name) prim.multiple_results = multiple_results prim.def_impl(partial(dispatch.apply_primitive, prim)) @@ -761,11 +775,12 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, prim.def_abstract_eval( partial(lax_utils.standard_multi_result_abstract_eval, prim, shape_rule, dtype_rule, lax_utils._standard_weak_type_rule, - sharding_rule)) + sharding_rule, vma_rule)) else: prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, - lax_utils._standard_weak_type_rule, sharding_rule)) + lax_utils._standard_weak_type_rule, sharding_rule, + partial(core.standard_vma_rule, name))) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) @@ -1643,6 +1658,7 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): if m == 0 or k == 0: return permutation upper = np.array(k, np.int32) if is_constant_dim(k) else k + permutation, swaps = core.standard_insert_pvary(permutation, swaps) result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn, (permutation, swaps)) return result @@ -1758,6 +1774,7 @@ def geqp3(a: ArrayLike, jpvt: ArrayLike, *, elementary Householder reflectors, and ``jpvt`` is the column-pivot indices such that ``a[:, jpvt] = q @ r``. 
""" + a, jpvt = core.standard_insert_pvary(a, jpvt) a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma) return a_out, jpvt_out, taus @@ -2511,16 +2528,30 @@ def _tridiagonal_solve_shape_rule(dl_shape, d_shape, du_shape, b_shape, **_): "equal the dimensions of the diagonal arguments.") return b_shape -def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b): +def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, target_name_prefix): _, _, _, b_aval = ctx.avals_in - if b_aval.dtype != np.float32 and b_aval.dtype != np.float64: + *batch_dims, m, n = b_aval.shape + batch_size = math.prod(batch_dims) + + mod = gpu_sparse._cusparse if target_name_prefix == "cu" else gpu_sparse._hipsparse + assert mod is not None + opaque = mod.build_gtsv2_descriptor(batch_size, m, n, m) + if b_aval.dtype == np.float32: + buffer_size = mod.gtsv2_f32_buffer_size(m, n, m) + target_name = "sparse_gtsv2_f32_ffi" + elif b_aval.dtype == np.float64: + buffer_size = mod.gtsv2_f64_buffer_size(m, n, m) + target_name = "sparse_gtsv2_f64_ffi" + else: raise NotImplementedError( "tridiagonal_solve is only implemented for float32 and float64 on GPU.") - m, n = b_aval.shape[-2:] - b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape) - return [lowering( - dl, d, du, b, m=m, n=n, ldb=m, t=b_aval.dtype, - b_shape_vals=b_shape_vals)] + + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = _linalg_ffi_lowering( + f"{target_name_prefix}{target_name}", operand_output_aliases={3: 0}, + batch_partitionable=False) + return rule(sub_ctx, dl, d, du, b, opaque=opaque)[:1] def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs): del kwargs # unused @@ -2628,11 +2659,11 @@ def _tridiagonal_solve_jax(dl, d, du, b, **_): platform='cpu') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='hip'), platform='rocm') mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun( _tridiagonal_solve_jax, multiple_results=False)) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 221fe2a9e87a..1df3de6f1bee 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -25,6 +25,7 @@ import jax from jax import tree_util from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, @@ -34,6 +35,8 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla +from jax._src.mesh import get_abstract_mesh +from jax._src.core import pvary from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lib.mlir import ir @@ -115,6 +118,8 @@ def psum(x, axis_name, *, axis_index_groups=None): """ if not isinstance(axis_name, (tuple, list)): axis_name = (axis_name,) + if not axis_name: + return x if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None: raise ValueError("axis_index_groups only supported for sums over just named axes") _validate_reduce_axis_index_groups(axis_index_groups) @@ -139,10 +144,27 @@ def pos_reduce(x): size = 
math.prod([core.get_axis_env().axis_size(name) for name in named_axes]) out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves) else: - out_flat = psum_p.bind( - *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) + if config._check_rep.value: + out_flat = bind_psum_invariant( + leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) + else: + out_flat = psum_p.bind( + *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) +def bind_psum_invariant(leaves, *, axes, axis_index_groups): + if axis_index_groups is not None: + raise NotImplementedError + axes_ = frozenset(axes) + args_ = [] + for x in leaves: + in_vma = core.get_aval(x).vma + args_.append(pvary(x, tuple(pbroadcast_names)) + if (pbroadcast_names := axes_ - in_vma) else x) + return psum_invariant_p.bind(*args_, axes=axes, + axis_index_groups=axis_index_groups) + + def pmean(x, axis_name, *, axis_index_groups=None): """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``. @@ -202,6 +224,7 @@ def pmax(x, axis_name, *, axis_index_groups=None): _validate_reduce_axis_index_groups(axis_index_groups) leaves, treedef = tree_util.tree_flatten(x) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) + leaves = map(partial(insert_collective_pvary, axis_name), leaves) out_flat = pmax_p.bind(*leaves, axes=axis_name, axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) @@ -232,6 +255,7 @@ def pmin(x, axis_name, *, axis_index_groups=None): _validate_reduce_axis_index_groups(axis_index_groups) leaves, treedef = tree_util.tree_flatten(x) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) + leaves = map(partial(insert_collective_pvary, axis_name), leaves) out_flat = pmin_p.bind(*leaves, axes=axis_name, axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) @@ -325,9 +349,10 @@ def ppermute(x, axis_name, perm): """ if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) - return tree_util.tree_map( - partial(ppermute_p.bind, axis_name=axis_name, - perm=tuple(map(tuple, perm))), x) + def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) + return ppermute_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm))) + return tree_util.tree_map(bind, x) def pshuffle(x, axis_name, perm): """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding @@ -447,6 +472,7 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis): else: # concat_axis < split_axis x = lax.expand_dims(x, (concat_axis,)) # insert the new axis split_axis += 1 # we have a new axis before split_axis now + x = insert_collective_pvary(axis_name, x) result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis, axis_name=axis_name, axis_index_groups=axis_index_groups, @@ -800,6 +826,48 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): ] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} +def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups): + if not config._check_rep.value: + return psum_p.abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + + assert isinstance(axes, tuple) + _check_axis_names(axes) + arg_vma = [a.vma for a in args] + # If intersection between arg_vma and axes is empty, error + if any(not set(axes) & a for a in arg_vma): + raise ValueError( + f"Collective {name} must be applied to a 
device-varying " + f"type, but got {arg_vma} for collective acting " + f"over axis name {axes}. Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) + pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) + if axis_index_groups is not None: + if len(pos_axes) != 0: + raise ValueError( + "axis_index_groups can only be used with reductions over " + f"named axes, but got: {axes}") + core.check_avals_context_mesh(args, 'all_reduce') + out_avals = [ + core.ShapedArray( + lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes), + vma=frozenset(a for a in arg.vma if a not in named_axes)) + for arg in args + ] + return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} + +# TODO(yashkatariya): Replace this with _psum_invariant_abstract_eval +def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups): + if not config._check_rep.value: + return _allreduce_effectful_abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + return _psum_invariant_abstract_eval( + name, *args, axes=axes, axis_index_groups=axis_index_groups) + def _check_axis_names(axes): named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) axis_env = core.get_axis_env() @@ -899,7 +967,7 @@ def broadcast_positional(ct, arg): pmax_p = core.Primitive('pmax') pmax_p.multiple_results = True pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax.reduce_max)) -pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmax_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmax')) mlir.register_lowering( pmax_p, partial(_allreduce_lowering, lax.max_p, lax.reduce_max)) batching.fancy_primitive_batchers[pmax_p] = \ @@ -910,7 +978,7 @@ def broadcast_positional(ct, arg): pmin_p = core.Primitive('pmin') pmin_p.multiple_results = True pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax.reduce_min)) -pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmin_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmin')) mlir.register_lowering( pmin_p, partial(_allreduce_lowering, lax.min_p, lax.reduce_min)) batching.fancy_primitive_batchers[pmin_p] = \ @@ -975,6 +1043,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm): def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): _check_axis_names(axis_name) + collective_vma_rule('ppermute', axis_name, x) return x ppermute_p = core.Primitive('ppermute') @@ -1109,15 +1178,15 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_name, split_axis, concat_axis, axis_index_groups, tiled): - axis_size, frame_name = axis_data.size, axis_data.name if axis_index_groups is not None: raise NotImplementedError("Please open a feature request!") + axis_size, frame_name = axis_data.size, axis_data.name if isinstance(axis_name, (list, tuple)): axes_names = axis_name else: axes_names = [axis_name] - if axis_data.name not in axes_names: + if frame_name not in axes_names: return _all_to_all_batcher( vals_in, dims_in, axis_name=axis_name, split_axis=split_axis, concat_axis=concat_axis, axis_index_groups=axis_index_groups, tiled=tiled) @@ -1157,6 +1226,7 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in, 
axis_index_groups=axis_index_groups, tiled=tiled) # Split out the local part into axis new_d (NOTE: d is already in axis 1) + assert d == 1 x = _splitaxis(split_axis, axis_size, x) new_d = split_axis concat_axis += (split_axis <= concat_axis) # Offset the existing axes by the new batch axis @@ -1188,7 +1258,8 @@ def _all_to_all_effectful_abstract_eval( assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size) shape[split_axis] //= axis_size shape[concat_axis] *= axis_size - out_aval = input_aval.update(shape=tuple(shape), weak_type=False) + vma = collective_vma_rule('all_to_all', axis_name, input_aval) + out_aval = input_aval.update(shape=tuple(shape), weak_type=False, vma=vma) effects = {*map(core.NamedAxisEffect, axis_name)} return out_aval, effects @@ -1305,13 +1376,50 @@ def _ragged_all_to_all_transpose( output_t = jax.numpy.where(mask, 0, t) return [operand_t, output_t] + [None] * 4 +def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in, + axis_name, axis_index_groups): + del axis_data + if axis_index_groups: + raise NotImplementedError("Please open a feature request!") + + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes = vals_in + operand_dim, output_dim, input_offsets_dim, send_sizes_dim, output_offsets_dim, recv_sizes_dim = dims_in + if not (operand.shape[operand_dim] == output.shape[output_dim] == input_offsets.shape[input_offsets_dim] == send_sizes.shape[send_sizes_dim] == output_offsets.shape[output_offsets_dim] == recv_sizes.shape[recv_sizes_dim]): + raise ValueError("all operands must have the same batch sizes") + + sliced_results = [] + for i in range(operand.shape[operand_dim]): + sliced_operand = slicing.slice_in_dim(operand, start_index=i, limit_index=i+1, axis=operand_dim).flatten() + sliced_output = slicing.slice_in_dim(output, start_index=i, limit_index=i+1, axis=output_dim).flatten() + sliced_input_offsets = slicing.slice_in_dim(input_offsets, start_index=i, limit_index=i+1, axis=input_offsets_dim).flatten() + sliced_send_sizes = slicing.slice_in_dim(send_sizes, start_index=i, limit_index=i+1, axis=send_sizes_dim).flatten() + sliced_output_offsets = slicing.slice_in_dim(output_offsets, start_index=i, limit_index=i+1, axis=output_offsets_dim).flatten() + sliced_recv_sizes = slicing.slice_in_dim(recv_sizes, start_index=i, limit_index=i+1, axis=recv_sizes_dim).flatten() + sliced_result = ragged_all_to_all(sliced_operand, sliced_output, sliced_input_offsets, sliced_send_sizes, sliced_output_offsets, sliced_recv_sizes, axis_name=axis_name, axis_index_groups=axis_index_groups) + sliced_result = lax.expand_dims(sliced_result, dimensions=(output_dim,)) + sliced_results.append(sliced_result) + + concat_result = lax.concatenate(sliced_results, dimension=output_dim) + return concat_result, operand_dim + ragged_all_to_all_p = core.Primitive('ragged_all_to_all') ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval) ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) +batching.fancy_primitive_batchers[ragged_all_to_all_p] = _ragged_all_to_all_batched_collective batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') +def insert_collective_pvary(axis_name, x): + if not config._check_rep.value: + return x + + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + aval = 
core.get_aval(x) + names_union = set(axis_name) | aval.vma + x = pvary(x, tuple(n for n in names_union if n not in aval.vma)) + return x + def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): """Gather values of x across all replicas. @@ -1381,6 +1489,7 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) return all_gather_p.bind( leaf, all_gather_dimension=canonicalize_axis( @@ -1433,6 +1542,19 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, **other_args).results +def collective_vma_rule(prim_name, axis_name, x_aval): + if not config._check_rep.value: + return frozenset() + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if any(a not in x_aval.vma for a in axis_name): + raise ValueError( + f"Collective {prim_name} must be applied to a device-varying " + f" type, but got {x_aval.vma} for collective acting " + f"over axis name {axis_name}. Please open an issue at " + "https://github.com/jax-ml/jax/issues and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return x_aval.vma + def _all_gather_effectful_abstract_eval( x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled ): @@ -1444,7 +1566,9 @@ def _all_gather_effectful_abstract_eval( new_shape[all_gather_dimension] *= axis_size else: new_shape.insert(all_gather_dimension, axis_size) - return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} + out_vma = collective_vma_rule('all_gather', axis_name, x_aval) + return (x_aval.update(shape=new_shape, vma=out_vma), + {*map(core.NamedAxisEffect, axis_name)}) def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): return (psum_scatter(cts, axis_name=axis_name, @@ -1581,7 +1705,9 @@ def _reduce_scatter_effectful_abstract_eval( f"{scatter_dim_input_size} must match shard count " f"{axis_size}") del new_shape[scatter_dimension] - return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} + vma = collective_vma_rule('reduce_scatter', axis_name, x_aval) + return (x_aval.update(shape=new_shape, vma=vma), + {*map(core.NamedAxisEffect, axis_name)}) def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension, @@ -1725,13 +1851,11 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, axis_name = axis_name, axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - bind = partial( - reduce_scatter_p.bind, - axis_name=axis_name, - scatter_dimension=scatter_dimension, - axis_index_groups=axis_index_groups, - axis_size=axis_size, - tiled=tiled) + def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) + return reduce_scatter_p.bind( + leaf, axis_name=axis_name, scatter_dimension=scatter_dimension, + axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled) return tree_util.tree_map(bind, x) @@ -1781,8 +1905,14 @@ def _axis_index_lowering(ctx, *, axis_name): ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - _check_axis_names([axis_name]) - return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)} + effect = {core.NamedAxisEffect(axis_name)} + axis_name = 
(axis_name,) if not isinstance(axis_name, tuple) else axis_name + _check_axis_names(axis_name) + mesh = get_abstract_mesh() + sharding = NamedSharding(mesh, P()) + vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset()) + if config._check_rep.value else frozenset()) + return ShapedArray((), np.int32, sharding=sharding, vma=vma), effect def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name): return lax.iota(np.int32, axis_data.size), 0 @@ -1856,3 +1986,19 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a # TODO: Transpose? That requires adding pscatter... batching.fancy_primitive_batchers[pgather_p] = _pgather_collective_batcher batching.skippable_batchers[pgather_p] = partial(_names_in_param, 'axes') + +psum_invariant_p = core.Primitive('psum_invariant') +psum_invariant_p.multiple_results = True +psum_invariant_p.def_impl(psum_p.impl) +psum_invariant_p.def_effectful_abstract_eval( + partial(_psum_invariant_abstract_eval, psum_invariant_p.name)) +mlir.register_lowering(psum_invariant_p, mlir._lowerings[psum_p]) +batching.fancy_primitive_batchers[psum_invariant_p] = partial( + _batched_reduction_collective, psum_invariant_p, + lambda v, axis_size: axis_size * v) +batching.skippable_batchers[psum_invariant_p] = partial(_names_in_param, 'axes') + +def _psum_invariant_transpose_rule(cts, *args, axes, axis_index_groups): + del args + return core.pvary_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) +ad.deflinear2(psum_invariant_p, _psum_invariant_transpose_rule) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index c26de99c7374..ed7c2f9f2777 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -173,6 +173,8 @@ def dynamic_slice( else: dynamic_sizes = [] static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore + operand, *start_indices = core.standard_insert_pvary( + operand, *start_indices) return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes, slice_sizes=tuple(static_sizes)) @@ -234,6 +236,8 @@ def dynamic_update_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) + operand, update, *start_indices = core.standard_insert_pvary( + operand, update, *start_indices) return dynamic_update_slice_p.bind(operand, update, *start_indices) @@ -416,6 +420,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, raise ValueError(f"Unsupported dtype for gather fill_value {dtype}") else: fill_value = None + operand, start_indices = core.standard_insert_pvary(operand, start_indices) return gather_p.bind( operand, start_indices, dimension_numbers=dimension_numbers, slice_sizes=core.canonicalize_shape(slice_sizes), @@ -505,6 +510,8 @@ def scatter_add( """ jaxpr, consts = lax._reduction_jaxpr(lax.add, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_add_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -559,6 +566,8 @@ def scatter_sub( jaxpr, consts = lax._reduction_jaxpr( lax.sub, core.get_aval(lax._const(operand, 0)) ) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_sub_p.bind( operand, scatter_indices, @@ -613,6 +622,8 @@ def scatter_mul( """ jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(lax._const(operand, 1))) + operand, scatter_indices, updates = 
core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_mul_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -660,6 +671,8 @@ def scatter_min( """ jaxpr, consts = lax._reduction_jaxpr(lax.min, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_min_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -707,6 +720,8 @@ def scatter_max( """ jaxpr, consts = lax._reduction_jaxpr(lax.max, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_max_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -771,6 +786,8 @@ def scatter_apply( pass jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand))) # TODO: implement this via its own primitive so we can define appropriate autodiff rules. + operand, scatter_indices, unused = core.standard_insert_pvary( + operand, scatter_indices, unused) return scatter_p.bind( operand, scatter_indices, unused, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -854,6 +871,8 @@ def scatter( ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) Array([0., 2., 3., 0., 4.], dtype=float32) """ + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_p.bind( operand, scatter_indices, updates, update_jaxpr=None, update_consts=(), dimension_numbers=dimension_numbers, @@ -1333,7 +1352,7 @@ def _get_sharding_for_varying_out_shape(out_shape, operand, name): operand.shape, out_shape, operand.sharding.spec): if (op_sh != out_sh and op_spec is not None and out_sh % _get_sub_spec_size(mesh, op_spec) != 0): - raise NotImplementedError( + raise core.ShardingTypeError( f"{name} on sharded dims where out dim ({out_sh}) is not divisble by" f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" f" ({op_spec}) is not implemented.") @@ -1393,7 +1412,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, return out, bdim slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice', - sharding_rule=_slice_sharding_rule) + sharding_rule=_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'slice')) ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule # TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries @@ -1472,11 +1492,12 @@ def _dynamic_slice_jvp(primals, tangents, *, slice_sizes): def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes): assert ad.is_undefined_primal(operand) assert all(not ad.is_undefined_primal(s) for s in start_indices) - operand_shape, operand_dtype = operand.aval.shape, operand.aval.dtype if type(t) is ad_util.Zero: return [ad_util.Zero(operand.aval)] + [None] * len(start_indices) else: - zeros = lax.full(operand_shape, 0, operand_dtype) + zeros = lax.full(operand.aval.shape, 0, operand.aval.dtype, + sharding=operand.aval.sharding) + zeros = core.pvary(zeros, tuple(operand.aval.vma)) return ([dynamic_update_slice_p.bind(zeros, t, *start_indices)] + [None] * len(start_indices)) @@ -1558,7 +1579,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, 
dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', weak_type_rule=_argnum_weak_type(0), - sharding_rule=_dynamic_slice_sharding_rule) + sharding_rule=_dynamic_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_slice')) ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule @@ -1606,7 +1628,7 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): def _dynamic_update_slice_sharding_rule(operand, update, *start_indices): if operand.sharding != update.sharding: - raise TypeError( + raise core.ShardingTypeError( "dynamic_update_slice update sharding must be equal to operand" " sharding, got update sharding" f" {update.str_short(mesh_axis_types=True)} for operand sharding" @@ -1678,7 +1700,8 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): dynamic_update_slice_p = standard_primitive( _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule, - 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule) + 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_update_slice')) ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp ad.primitive_transposes[dynamic_update_slice_p] = \ _dynamic_update_slice_transpose_rule @@ -1921,9 +1944,6 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes): else next(indices_shape_gen) for i in range(output_shape_rank)) return ans -class GatherShardingError(Exception): - pass - def _gather_sharding_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1935,7 +1955,7 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers, all(s is None for s in operand.sharding.spec) and all(s is None for s in indices.sharding.spec)): return core.get_cur_mesh_sharding() - raise GatherShardingError( + raise core.ShardingTypeError( "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for" " the gather indexing.") @@ -2119,7 +2139,8 @@ def _gather_pad_rule(in_avals, out_avals, operand, indices, *, gather_p = standard_primitive( _gather_shape_rule, _gather_dtype_rule, 'gather', - weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule) + weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'gather')) ad.defjvp(gather_p, _gather_jvp_rule, None) ad.primitive_transposes[gather_p] = _gather_transpose_rule batching.primitive_batchers[gather_p] = _gather_batching_rule @@ -2601,7 +2622,8 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_add_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_add')) ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p) ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p) batching.primitive_batchers[scatter_add_p] = ( @@ -2612,6 +2634,7 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, _scatter_dtype_rule, "scatter-sub", weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_sub') ) 
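The registrations around this hunk all gain a `vma_rule=partial(core.standard_vma_rule, ...)` entry, and the corresponding Python wrappers now route their operands through `core.standard_insert_pvary` before binding. `standard_vma_rule` itself is not shown in this patch; the sketch below is a hypothetical, self-contained mock (names `pvary_sketch`, `standard_insert_pvary_sketch`, `standard_vma_rule_sketch` are invented) of the behaviour suggested by `insert_collective_pvary` and `collective_vma_rule` elsewhere in the diff, namely that operands are promoted to a common set of varying manual axes and the abstract-eval rule propagates that set. It is not the real `jax._src.core` implementation.

def pvary_sketch(vma: frozenset, axes: tuple) -> frozenset:
    # pvary marks a value as varying over additional manual mesh axes.
    return vma | frozenset(axes)

def standard_insert_pvary_sketch(*vmas: frozenset) -> list[frozenset]:
    # Before binding a multi-operand primitive, promote every operand to the
    # union of the operands' varying-manual-axes (vma) sets.
    union = frozenset().union(*vmas)
    return [pvary_sketch(v, tuple(union - v)) for v in vmas]

def standard_vma_rule_sketch(name: str, *vmas: frozenset) -> frozenset:
    # Abstract-eval side: operands must agree once pvary has run; the output
    # aval inherits the common set.
    if len(set(vmas)) > 1:
        raise TypeError(f"{name}: mismatched varying manual axes {vmas}")
    return vmas[0] if vmas else frozenset()

operand, indices, updates = map(frozenset, [{"x"}, set(), {"x"}])
operand, indices, updates = standard_insert_pvary_sketch(operand, indices, updates)
assert standard_vma_rule_sketch("scatter_add", operand, indices, updates) == {"x"}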
ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p) ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p) @@ -2621,7 +2644,8 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_mul_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_mul')) def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, indices_are_sorted, unique_indices, mode, **kw): @@ -2750,14 +2774,16 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, scatter_min_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_min')) batching.primitive_batchers[scatter_min_p] = ( partial(_scatter_batching_rule, scatter_min_p)) ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p) scatter_max_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_max')) batching.primitive_batchers[scatter_max_p] = ( partial(_scatter_batching_rule, scatter_max_p)) ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p) @@ -2915,7 +2941,8 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, scatter_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter')) ad.primitive_jvps[scatter_p] = _scatter_jvp ad.primitive_transposes[scatter_p] = _scatter_transpose_rule batching.primitive_batchers[scatter_p] = ( diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index b70513bc2d20..a486bda28486 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -21,6 +21,7 @@ import numpy as np from functools import partial +from jax._src import core from jax._src.lax.lax import (add, bitwise_and, bitwise_not, bitwise_or, broadcast_in_dim, broadcast_shapes, convert_element_type, div, eq, exp, full_like, ge, @@ -37,8 +38,28 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.typing import Array, ArrayLike +# TODO(mattjj): this function sucks, delete it +def _up_and_broadcast(doit): + def up_and_broadcast(*args): + broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) + args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] + + a_dtype = args[0].dtype + needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 + if needs_upcast: + args = [convert_element_type(a, np.float32) for a in args] + a_x_type = np.float32 + else: + a_x_type = a_dtype + result = doit(*args, dtype=a_x_type) + if needs_upcast: + result = convert_element_type(result, a_dtype) + return result + return up_and_broadcast + def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete beta integral.""" + a, b, x = core.standard_insert_pvary(a, b, x) return regularized_incomplete_beta_p.bind(a, b, x) def lgamma(x: ArrayLike) -> Array: @@ -51,26 +72,33 @@ def digamma(x: ArrayLike) -> Array: def polygamma(m: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise polygamma: :math:`\psi^{(m)}(x)`.""" + m, x = 
core.standard_insert_pvary(m, x) return polygamma_p.bind(m, x) def igamma(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete gamma function.""" + a, x = core.standard_insert_pvary(a, x) return igamma_p.bind(a, x) def igammac(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise complementary regularized incomplete gamma function.""" + a, x = core.standard_insert_pvary(a, x) return igammac_p.bind(a, x) def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise derivative of the regularized incomplete gamma function.""" + a, x = core.standard_insert_pvary(a, x) return igamma_grad_a_p.bind(a, x) -def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array: +@_up_and_broadcast +def random_gamma_grad(a: ArrayLike, x: ArrayLike, *, dtype) -> Array: r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" - return random_gamma_grad_p.bind(a, x) + a, x = core.standard_insert_pvary(a, x) + return random_gamma_grad_impl(a, x, dtype=dtype) def zeta(x: ArrayLike, q: ArrayLike) -> Array: r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`""" + x, q = core.standard_insert_pvary(x, q) return zeta_p.bind(x, q) def bessel_i0e(x: ArrayLike) -> Array: @@ -194,12 +222,18 @@ def nth_partial_betainc_numerator(iteration, a, b, x): iteration_is_one = eq(iteration_bcast, full_like(iteration_bcast, 1)) iteration_minus_one = iteration_bcast - full_like(iteration_bcast, 1) m = iteration_minus_one // full_like(iteration_minus_one, 2) + m_is_zero = eq(m, full_like(m, 0)) m = convert_element_type(m, dtype) one = full_like(a, 1) two = full_like(a, 2.0) # Partial numerator terms - even_numerator = -(a + m) * (a + b + m) * x / ( - (a + two * m) * (a + two * m + one)) + + # When a is close to zero and m == 0, using zero_numerator avoids + # inaccuracies when FTZ or DAZ is enabled: + zero_numerator = -(a + b) * x / (a + one) + even_numerator = select(m_is_zero, zero_numerator, + -(a + m) * (a + b + m) * x / ( + (a + two * m) * (a + two * m + one))) odd_numerator = m * (b - m) * x / ((a + two * m - one) * (a + two * m)) one_numerator = full_like(x, 1.0) numerator = select(iteration_is_even, even_numerator, odd_numerator) @@ -210,12 +244,24 @@ def nth_partial_betainc_denominator(iteration, a, b, x): return select(eq(iteration_bcast, full_like(iteration_bcast, 0)), full_like(x, 0), full_like(x, 1)) + a_is_zero = bitwise_or(eq(a, full_like(a, 0)), eq(b, full_like(b, float('inf')))) + b_is_zero = bitwise_or(eq(b, full_like(b, 0)), eq(a, full_like(a, float('inf')))) + x_is_zero = eq(x, full_like(x, 0)) + x_is_one = eq(x, full_like(x, 1)) + x_is_not_zero = bitwise_not(x_is_zero) + x_is_not_one = bitwise_not(x_is_one) + is_nan = bitwise_or(bitwise_or(_isnan(a), _isnan(b)), _isnan(x)) + + result_is_zero = bitwise_or(bitwise_and(b_is_zero, x_is_not_one), bitwise_and(a_is_zero, x_is_zero)) + result_is_one = bitwise_or(bitwise_and(a_is_zero, x_is_not_zero), bitwise_and(b_is_zero, x_is_one)) + result_is_nan = bitwise_or(bitwise_or(bitwise_or( - le(a, full_like(a, 0)), le(b, full_like(b, 0))), + lt(a, full_like(a, 0)), lt(b, full_like(b, 0))), lt(x, full_like(x, 0))), gt(x, full_like(x, 1))) + result_is_nan = bitwise_or(result_is_nan, bitwise_or(bitwise_and(a_is_zero, b_is_zero), is_nan)) - # The continued fraction will converge rapidly when x < (a+1)/(a+b+2) - # as per: http://dlmf.nist.gov/8.17.E23 + # The continued fraction will converge rapidly when x < + # (a+1)/(a+b+2) as per: http://dlmf.nist.gov/8.17.E23. 
# # Otherwise, we can rewrite using the symmetry relation as per: # http://dlmf.nist.gov/8.17.E4 @@ -234,10 +280,21 @@ def nth_partial_betainc_denominator(iteration, a, b, x): inputs=[a, b, x] ) - lbeta_ab = lgamma(a) + lgamma(b) - lgamma(a + b) - result = continued_fraction * exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a + # For very small a and to avoid division by zero, we'll use + # a * gamma(a) = gamma(a + 1) -> 1 as a -> 0+. + very_small = (dtypes.finfo(dtype).tiny * 2).astype(dtype) + lbeta_ab_small_a = lgamma(b) - lgamma(a + b) + lbeta_ab = lgamma(a) + lbeta_ab_small_a + factor = select(lt(a, full_like(a, very_small)), + exp(log1p(-x) * b - lbeta_ab_small_a), + exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a) + result = continued_fraction * factor + result = select(converges_rapidly, result, sub(full_like(result, 1), result)) + + result = select(result_is_zero, full_like(a, 0), result) + result = select(result_is_one, full_like(a, 1), result) result = select(result_is_nan, full_like(a, float('nan')), result) - return select(converges_rapidly, result, sub(full_like(result, 1), result)) + return result class IgammaMode(Enum): VALUE = 1 @@ -494,24 +551,6 @@ def random_gamma_grad_impl(a, x, *, dtype): full_like(a, float('nan')), output) return output -def _up_and_broadcast(doit): - def up_and_broadcast(*args): - broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) - args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] - - a_dtype = args[0].dtype - needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 - if needs_upcast: - args = [convert_element_type(a, np.float32) for a in args] - a_x_type = np.float32 - else: - a_x_type = a_dtype - result = doit(*args, dtype=a_x_type) - if needs_upcast: - result = convert_element_type(result, a_dtype) - return result - return up_and_broadcast - def evaluate_chebyshev_polynomial(x, coefficients): b0 = full_like(x,0) @@ -657,11 +696,6 @@ def bessel_i0e_impl(x): ad.defjvp(igammac_p, igammac_grada, igammac_gradx) -random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad') -mlir.register_lowering(random_gamma_grad_p, - mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl), - multiple_results=False)) - zeta_p = standard_naryop([_float, _float], 'zeta') mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta)) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index f39d925ac2ad..8e97621912f1 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -37,13 +37,13 @@ def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) def standard_primitive(shape_rule, dtype_rule, name, - weak_type_rule=None, sharding_rule=None): + weak_type_rule=None, sharding_rule=None, vma_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule, sharding_rule)) + weak_type_rule, sharding_rule, vma_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level @@ -74,7 +74,7 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs): s = NamedSharding(aval_mesh, P()) return s if num_out is None else [s] * num_out if rule is None: - raise ValueError( + raise core.ShardingTypeError( f'sharding rule for {prim.name} is not implemented. Please file a' ' bug at https://github.com/jax-ml/jax/issues. 
You can work around' ' this error by dropping that operation into full auto sharding' @@ -95,14 +95,14 @@ def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule, avals_str = ', '.join(i.str_short(short_dtypes=True) for i in avals) mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec, - short_dtypes=True) - raise TypeError( + frozenset(), short_dtypes=True) + raise core.ShardingTypeError( f'{prim} operation with inputs: {avals_str} produces an illegally' f' sharded result: {out_aval_str}') from e return out_shapes, out_dtypes, out_shardings def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, - sharding_rule, *avals, **kwargs): + sharding_rule, vma_rule, *avals, **kwargs): assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) @@ -112,8 +112,10 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, False, *avals, **kwargs) + out_vma = vma_rule(*avals, **kwargs) out_aval = core.ShapedArray( - out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding) + out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, + vma=out_vma) core.check_avals_context_mesh([out_aval], prim.name) return out_aval elif least_specialized is core.DShapedArray: @@ -127,7 +129,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, raise TypeError(avals, least_specialized) def standard_multi_result_abstract_eval( - prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, + prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, vma_rule, *avals, **kwargs): assert prim.multiple_results assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals @@ -137,11 +139,12 @@ def standard_multi_result_abstract_eval( core.check_avals_context_mesh(avals, prim.name) out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs) + out_vmas = vma_rule(*avals, **kwargs) if isinstance(weak_types, bool): weak_types = (weak_types,) * len(out_shapes) - out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh) - for s, d, weak_type, sh in zip(out_shapes, out_dtypes, - weak_types, out_shardings)] + out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh, vma=vma) + for s, d, weak_type, sh, vma in zip( + out_shapes, out_dtypes, weak_types, out_shardings, out_vmas)] core.check_avals_context_mesh(out_avals, prim.name) return out_avals elif least_specialized is core.UnshapedArray: diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 400646f6238f..acbfae37eaf5 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -97,6 +97,7 @@ def _reduce_window( raise ValueError( 'reduce_window output must have the same tree structure as the operands' f' {operand_tree} vs. 
{out_tree}') + flat_operands = core.standard_insert_pvary(*flat_operands) out_flat = reduce_window_p.bind( *flat_operands, *flat_init_values, @@ -250,6 +251,8 @@ def _select_and_scatter(operand: Array, select: Callable, select, core.get_aval(init_value)) scatter_jaxpr, scatter_consts = lax._reduction_jaxpr( scatter, core.get_aval(init_value)) + operand, source, init_value = core.standard_insert_pvary( + operand, source, init_value) return select_and_scatter_p.bind( operand, source, init_value, select_jaxpr=select_jaxpr, select_consts=select_consts, scatter_jaxpr=scatter_jaxpr, @@ -261,6 +264,7 @@ def _select_and_scatter_add(source: Array, operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]]) -> Array: + source, operand = core.standard_insert_pvary(source, operand) return select_and_scatter_add_p.bind( source, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), @@ -296,6 +300,7 @@ def _select_and_gather_add(tangents: Array, operand: Array, An array containing the elements in `tangents` corresponding to the output of the reduction of `operand` fin each window. """ + tangents, operand = core.standard_insert_pvary(tangents, operand) return select_and_gather_add_p.bind( tangents, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), @@ -332,7 +337,8 @@ def _reduce_window_abstract_eval_rule( out_sharding = reduce_window_sharding_rule( operand_avals[0], window_dimensions, window_strides, padding, base_dilation, window_dilation) - return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding) + vma = core.standard_vma_rule('reduce_window', *operand_avals) + return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, vma=vma) for op in operand_avals) @@ -524,15 +530,16 @@ def reduce_window_sharding_rule(operand, window_dimensions, window_strides, base_dilation, window_dilation): if spec is None: continue - if not (wdim == 1 and ws == 1 and pd == 1 and bd == 1 and wdil == 1): - raise NotImplementedError( + if not (wdim == 1 and ws == 1 and pd == (0, 0) and bd == 1 and wdil == 1): + raise core.ShardingTypeError( "Only trivial windowing is supported along non-replicated" f" dimensions. 
Got {operand.sharding.spec=}") return operand.sharding reduce_window_sum_p = lax.standard_primitive( _reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_sum')) ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule) batching.primitive_batchers[reduce_window_sum_p] = partial( _reduce_window_batch_rule, _reduce_window_sum) @@ -598,7 +605,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, reduce_window_max_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_max')) ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, lax.max_p)) batching.primitive_batchers[reduce_window_max_p] = partial( @@ -606,7 +614,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, reduce_window_min_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_min', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_min')) ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, lax.min_p)) @@ -630,7 +639,8 @@ def _reduce_window_lower( ): operand_aval, = ctx.avals_in - scalar_aval = operand_aval.update(shape=()) + scalar_aval = operand_aval.update( + shape=(), sharding=operand_aval.sharding.with_spec(())) return mlir.reduce_window( ctx, @@ -671,7 +681,8 @@ def _select_and_scatter_shape_rule( return operand.shape select_and_scatter_p = lax.standard_primitive( - _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter') + _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter', + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter')) def _select_and_scatter_lower( ctx, operand, source, init_value, *, select_jaxpr, @@ -766,7 +777,8 @@ def _select_and_scatter_add_batch_rule( select_and_scatter_add_p = lax.standard_primitive( _select_and_scatter_add_shape_rule, lax._input_dtype, - 'select_and_scatter_add') + 'select_and_scatter_add', + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter_add')) ad.primitive_transposes[select_and_scatter_add_p] = \ _select_and_scatter_add_transpose @@ -826,7 +838,7 @@ def _select_and_gather_add_sharding_rule( tangents, operand, *, select_prim, window_dimensions, window_strides, padding, base_dilation, window_dilation): if tangents.sharding != operand.sharding: - raise TypeError( + raise core.ShardingTypeError( "select_and_gather_add tangents and operand shardings must match, " f"got {tangents.sharding} and {operand.sharding}.") return reduce_window_sharding_rule( @@ -1039,7 +1051,8 @@ def _select_and_gather_add_batching_rule( select_and_gather_add_p = lax.standard_primitive( _select_and_gather_add_shape_rule, lax._input_dtype, - 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule) + 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_and_gather_add')) ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp ad.primitive_transposes[select_and_gather_add_p] = \ _select_and_gather_add_transpose diff --git a/jax/_src/layout.py 
b/jax/_src/layout.py index 5309f0b1fd9c..8d4f8acd5327 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -127,6 +127,10 @@ def __init__(self, device_local_layout: LayoutOptions = None, self.device_local_layout = device_local_layout self.sharding = sharding + @property + def dll(self): + return self.device_local_layout + def __repr__(self): return (f'Layout(device_local_layout={self.device_local_layout},' f' sharding={self.sharding})') diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 1fcbd4b6b7ef..78ddad29306b 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -44,6 +44,9 @@ py_library_providing_imports_info( "//jaxlib/mosaic/python:tpu_dialect", "//jaxlib:cpu_feature_guard", "//jaxlib:utils", + "//jaxlib:weakref_lru_cache", + "//jaxlib/xla:xla_client", + "//jaxlib/xla:xla_extension", "//jaxlib/triton", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mlir:arithmetic_dialect", @@ -60,6 +63,5 @@ py_library_providing_imports_info( "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", - # xla_client ]), ) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 7933bb769733..b011bf0084d4 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -40,7 +40,7 @@ raise ImportError(msg) from err -# Checks the jaxlib version before importing anything else from jaxlib. +# Checks the jaxlib version before importing anything else. # Returns the jaxlib version string. def check_jaxlib_version(jax_version: str, jaxlib_version: str, minimum_jaxlib_version: str) -> tuple[int, ...]: @@ -77,20 +77,36 @@ def _parse_version(v: str) -> tuple[int, ...]: jaxlib_version=jaxlib.version.__version__, minimum_jaxlib_version=jax.version._minimum_jaxlib_version) -# Before importing any C compiled modules from jaxlib, first import the CPU +# Before importing any C compiled modules, first import the CPU # feature guard module to verify that jaxlib was compiled in a way that only # uses instructions that are present on this machine. import jaxlib.cpu_feature_guard as cpu_feature_guard cpu_feature_guard.check_cpu_features() -import jaxlib.utils as utils # noqa: F401 -import jaxlib.xla_client as xla_client import jaxlib.lapack as lapack # noqa: F401 +import jaxlib.utils as utils # noqa: F401 +import jaxlib.xla_extension as xla_extension # noqa: F401 +from jaxlib.xla_extension import guard_lib as guard_lib # noqa: F401 +from jaxlib.xla_extension import jax_jit as jax_jit # noqa: F401 +from jaxlib.xla_extension import pmap_lib as pmap_lib # noqa: F401 +from jaxlib.xla_extension import pytree as pytree # noqa: F401 + +from jaxlib.xla_extension import Device as Device # noqa: F401 + +import jaxlib.xla_client as xla_client # noqa: F401 -xla_extension = xla_client._xla -pytree = xla_client._xla.pytree -jax_jit = xla_client._xla.jax_jit -pmap_lib = xla_client._xla.pmap_lib +# Jaxlib code is split between the Jax and the XLA repositories. +# Only for the internal usage of the JAX developers, we expose a version +# number that can be used to perform changes without breaking the main +# branch on the Jax github. +jaxlib_extension_version: int = getattr(xla_client, '_version', 0) +ifrt_version: int = getattr(xla_client, '_ifrt_version', 0) + +# TODO(phawkins): remove type: ignore once the minimum jaxlib is bumped. 
+if jaxlib_extension_version >= 328: + import jaxlib.weakref_lru_cache as weakref_lru_cache # type: ignore # noqa: F401 +else: + weakref_lru_cache = xla_extension # type: ignore # noqa: F401 # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): @@ -109,13 +125,6 @@ def _xla_gc_callback(*args): import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error # noqa: F401 -import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error # noqa: F401 - -# Jaxlib code is split between the Jax and the Tensorflow repositories. -# Only for the internal usage of the JAX developers, we expose a version -# number that can be used to perform changes without breaking the main -# branch on the Jax github. -xla_extension_version: int = getattr(xla_client, '_version', 0) import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 @@ -168,6 +177,3 @@ def _try_cuda_nvcc_import() -> str | None: return None cuda_path = _cuda_path() - -guard_lib = xla_client._xla.guard_lib -Device = xla_client._xla.Device diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index a9bae8821db5..be5317824c36 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -51,11 +51,7 @@ ]) del _lazy -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects import sdy as sdy -except ImportError: - sdy: Any = None # type: ignore[no-redef] +from jaxlib.mlir.dialects import sdy # Alias that is set up to abstract away the transition from MHLO to StableHLO. 
from jaxlib.mlir.dialects import stablehlo as hlo diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 1497597ebd62..bfe87430554e 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -67,6 +67,7 @@ def trans1(static_arg, *dynamic_args, **kwargs): from collections.abc import Callable, Sequence from functools import partial import re +import time from typing import Any, Hashable, NamedTuple import warnings import weakref @@ -326,6 +327,18 @@ def replace_func_name(self, name: str) -> DebugInfo: func_src_comps[0] = name return self._replace(func_src_info=" ".join(func_src_comps)) + @property + def func_filename(self) -> str | None: + m = _re_func_src_info.match(self.func_src_info) + if not m: return None + return m.group(3) + + @property + def func_lineno(self) -> int | None: + m = _re_func_src_info.match(self.func_src_info) + if not m or m.group(4) is None: return None + return int(m.group(4)) + def safe_arg_names(self, expected: int) -> tuple[str, ...]: """Get the arg_names with a safety check.""" if len(self.arg_names) == expected: @@ -352,6 +365,7 @@ def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]: assert self.result_paths is not None and not callable(self.result_paths), self return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b) +_re_func_src_info = re.compile(r"([^ ]+)( at (.+):(\d+))?$") def _missing_debug_info(for_what: str) -> DebugInfo: warnings.warn( @@ -433,7 +447,7 @@ def valid_size(d) -> bool: def cache(call: Callable, *, - explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None): + explain: Callable[[WrappedFun, bool, dict, tuple, float], None] | None = None): """Memoization decorator for functions taking a WrappedFun as first argument. Args: @@ -442,7 +456,8 @@ def cache(call: Callable, *, memoization cache key. explain: a function that is invoked upon cache misses to log an explanation - of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`. + of the miss. + Invoked with `(fun, is_cache_first_use, cache, key, elapsed_sec)`. Returns: A memoized version of ``call``. @@ -457,9 +472,11 @@ def memoized_fun(fun: WrappedFun, *args): ans, stores = result fun.populate_stores(stores) else: + if do_explain := explain and config.explain_cache_misses.value: + start = time.time() ans = call(fun, *args) - if explain and config.explain_cache_misses.value: - explain(fun, cache is new_cache, cache, key) + if do_explain: + explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore cache[key] = (ans, fun.stores) return ans diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index b490febf7b0c..aa6c49f0ccdd 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -111,12 +111,16 @@ class AxisType(enum.Enum): def __repr__(self): return self.name -def _normalize_axis_types(axis_names, axis_types): +def _normalize_axis_types(axis_names, axis_types, name): axis_types = ((AxisType.Auto,) * len(axis_names) if axis_types is None else axis_types) if not isinstance(axis_types, tuple): - assert isinstance(axis_types, AxisType), axis_types axis_types = (axis_types,) + + if not all(isinstance(a, AxisType) for a in axis_types): + raise TypeError( + f"axis_types passed to {name} must be of type `jax.sharding.AxisType`." + f" Got {axis_types} of type {tuple(type(a) for a in axis_types)}") if len(axis_names) != len(axis_types): raise ValueError( "Number of axis names should match the number of axis_types. 
Got" @@ -174,6 +178,21 @@ def _any_axis_auto(self) -> bool: def _any_axis_explicit(self) -> bool: return any_axis_types_match(self._axis_types, AxisType.Explicit) + @functools.cached_property + def auto_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Auto) + + @functools.cached_property + def explicit_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Explicit) + + @functools.cached_property + def manual_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Manual) + @functools.cached_property def _axis_types_dict(self): if not self.axis_names: @@ -194,16 +213,9 @@ def _name_to_type(self): class Mesh(_BaseMesh, contextlib.ContextDecorator): """Declare the hardware resources available in the scope of this manager. - In particular, all ``axis_names`` become valid resource names inside the - managed block and can be used e.g. in the ``in_axis_resources`` argument of - :py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming - model (https://jax.readthedocs.io/en/latest/multi_process.html) - and the Distributed arrays and automatic parallelization tutorial - (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) - - If you are compiling in multiple threads, make sure that the - ``with Mesh`` context manager is inside the function that the threads will - execute. + See the Distributed arrays and automatic parallelization tutorial + (https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) + and Explicit sharding tutorial (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) Args: devices: A NumPy ndarray object containing JAX device objects (as @@ -214,32 +226,17 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator): Examples: - >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P + >>> from jax.sharding import PartitionSpec as P, NamedSharding >>> import numpy as np ... - >>> inp = np.arange(16).reshape((8, 2)) - >>> devices = np.array(jax.devices()).reshape(4, 2) - ... >>> # Declare a 2D mesh with axes `x` and `y`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> # Use the mesh object directly as a context manager. - >>> with global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Initialize the Mesh and use the mesh as the context manager. - >>> with Mesh(devices, ('x', 'y')) as global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Also you can use it as `with ... as ...`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> with global_mesh as m: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # You can also use it as `with Mesh(...)`. - >>> with Mesh(devices, ('x', 'y')): - ... 
out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) + >>> devices = np.array(jax.devices()).reshape(4, 2) + >>> mesh = Mesh(devices, ('x', 'y')) + >>> inp = np.arange(16).reshape(8, 2) + >>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y'))) + >>> out = jax.jit(lambda x: x * 2)(arr) + >>> assert out.sharding == NamedSharding(mesh, P('x', 'y')) """ devices: np.ndarray @@ -263,7 +260,7 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], f"devices.ndim == {devices.ndim} and " f"len(axis_names) == {len(axis_names)}.") - axis_types = _normalize_axis_types(axis_names, axis_types) + axis_types = _normalize_axis_types(axis_names, axis_types, 'Mesh') key = (axis_names, devices.shape, tuple(devices.flat), axis_types) val = _mesh_object_dict.get(key, None) @@ -447,7 +444,8 @@ def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...], self.axis_sizes = axis_sizes self.axis_names = axis_names self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0 - self._axis_types = _normalize_axis_types(self.axis_names, axis_types) + self._axis_types = _normalize_axis_types( + self.axis_names, axis_types, 'AbstractMesh') self._hash = hash((self.axis_sizes, self.axis_names, self._axis_types)) def __hash__(self): diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index ccc75af8c84f..c135919b14c5 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py index 99e957733ba2..de706ccbaef5 100644 --- a/jax/_src/monitoring.py +++ b/jax/_src/monitoring.py @@ -46,10 +46,18 @@ def __call__( ) -> None: ... +class ScalarListenerWithMetadata(Protocol): + + def __call__( + self, event: str, value: float | int, **kwargs: str | int, + ) -> None: + ... 
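The new `record_scalar` / `register_scalar_listener` pair added below in `jax/_src/monitoring.py` mirrors the existing event and duration listeners. A minimal usage sketch, assuming this patch is applied and the private `jax._src.monitoring` module is imported directly; the listener name and the event string are illustrative only:

from jax._src import monitoring

collected = []

def my_scalar_listener(event: str, value, **kwargs) -> None:
    # Matches the ScalarListenerWithMetadata protocol: event name, scalar
    # value, plus free-form str/int metadata keyword arguments.
    collected.append((event, value, kwargs))

monitoring.register_scalar_listener(my_scalar_listener)
monitoring.record_scalar("/jax/example/compile_memory_mb", 123.4, backend="cpu")
assert collected == [("/jax/example/compile_memory_mb", 123.4, {"backend": "cpu"})]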
+ _event_listeners: list[EventListenerWithMetadata] = [] _event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = [] _event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = [] +_scalar_listeners: list[ScalarListenerWithMetadata] = [] def record_event(event: str, **kwargs: str | int) -> None: @@ -81,6 +89,14 @@ def record_event_time_span( callback(event, start_time, end_time, **kwargs) +def record_scalar( + event: str, value: float | int, **kwargs: str | int +) -> None: + """Record a scalar summary value.""" + for callback in _scalar_listeners: + callback(event, value, **kwargs) + + def register_event_listener( callback: EventListenerWithMetadata, ) -> None: @@ -100,6 +116,14 @@ def register_event_duration_secs_listener( """Register a callback to be invoked during record_event_duration_secs().""" _event_duration_secs_listeners.append(callback) + +def register_scalar_listener( + callback : ScalarListenerWithMetadata, +) -> None: + """Register a callback to be invoked during record_scalar().""" + _scalar_listeners.append(callback) + + def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]: """Get event duration listeners.""" return list(_event_duration_secs_listeners) @@ -114,12 +138,20 @@ def get_event_listeners() -> list[EventListenerWithMetadata]: """Get event listeners.""" return list(_event_listeners) + +def get_scalar_listeners() -> list[ScalarListenerWithMetadata]: + """Get scalar event listeners.""" + return list(_scalar_listeners) + + def clear_event_listeners(): """Clear event listeners.""" global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners _event_listeners = [] _event_duration_secs_listeners = [] _event_time_span_listeners = [] + _scalar_listeners = [] + def _unregister_event_duration_listener_by_callback( callback: EventDurationListenerWithMetadata) -> None: @@ -159,3 +191,14 @@ def _unregister_event_listener_by_callback( """ assert callback in _event_listeners _event_listeners.remove(callback) + + +def _unregister_scalar_listener_by_callback( + callback: ScalarListenerWithMetadata, +) -> None: + """Unregister a scalar event listener by callback. + + This function is supposed to be called for testing only. + """ + assert callback in _scalar_listeners + _scalar_listeners.remove(callback) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 5accdd880a79..2c6741ab4c9a 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -21,11 +21,11 @@ from typing import Any, Union from jax._src import config -from jax._src.util import use_cpp_class, cache, use_cpp_method, tuple_insert +from jax._src.util import use_cpp_class, cache, use_cpp_method from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib -from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton +from jax._src.partition_spec import PartitionSpec from jax._src import sharding as JSharding from jax._src import xla_bridge as xb import numpy as np @@ -42,7 +42,7 @@ def __init__(self, mesh: mesh_lib.Mesh): self.mesh = mesh def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_closed=False) + dim_shardings = [SdyDimSharding(axes=[], is_open=True) for _ in range(ndim)] return SdyArraySharding(self.mesh.shape_tuple, dim_shardings) @@ -92,7 +92,7 @@ class NamedSharding(JSharding.Sharding): across ``y`` axis of the mesh. 
The Distributed arrays and automatic parallelization - (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) + (https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) tutorial has more details and diagrams that explain how :class:`Mesh` and :class:`PartitionSpec` are used. @@ -112,20 +112,17 @@ class NamedSharding(JSharding.Sharding): mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh spec: PartitionSpec _memory_kind: str | None - _manual_axes: frozenset[MeshAxisName] _logical_device_ids: tuple[int, ...] | None @use_cpp_method() def __init__( self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *, - memory_kind: str | None = None, _manual_axes=frozenset(), - _logical_device_ids=None): + memory_kind: str | None = None, _logical_device_ids=None): self.mesh = mesh self.spec = spec self._memory_kind = memory_kind - self._manual_axes = _manual_axes self._logical_device_ids = _logical_device_ids - check_pspec(self.mesh, self.spec, self._manual_axes) + check_pspec(self.mesh, self.spec) def __repr__(self): mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}' @@ -137,7 +134,6 @@ def __repr__(self): def __reduce__(self): return (type(self), (self.mesh, self.spec), {'memory_kind': self.memory_kind, - '_manual_axes': self._manual_axes, '_logical_device_ids': self._logical_device_ids}) @property @@ -147,8 +143,7 @@ def memory_kind(self) -> str | None: def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( - (self.mesh, self.memory_kind, self.spec, self._manual_axes, - self._logical_device_ids)) + (self.mesh, self.memory_kind, self.spec, self._logical_device_ids)) return self._hash def __eq__(self, other): @@ -158,7 +153,6 @@ def __eq__(self, other): return True if (self.spec != other.spec or self.memory_kind != other.memory_kind - or self._manual_axes != other._manual_axes or self._logical_device_ids != other._logical_device_ids): return False return self.mesh is other.mesh or self.mesh == other.mesh @@ -198,7 +192,7 @@ def is_fully_addressable(self) -> bool: # Speed up `is_fully_addressable` since there is a high chance that the # mesh across multiple NamedSharding objects will be the same. if config.enable_empty_arrays.value: - client = self._internal_device_list[0].client + client = self._internal_device_list[0].client # type: ignore return (len(self.mesh._process_indices) == 1 and next(iter(self.mesh._process_indices)) == xb.process_index(client)) @@ -242,11 +236,11 @@ def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_closed=True) + dim_shardings = [SdyDimSharding(axes=[], is_open=False) for _ in range(num_dimensions)] for i, dim_spec in enumerate(self.spec): if dim_spec is PartitionSpec.UNCONSTRAINED: - dim_shardings[i].is_closed = False + dim_shardings[i].is_open = True elif dim_spec is None: # Already empty and closed sharding. 
pass @@ -274,14 +268,13 @@ def get_array_mapping( @dataclasses.dataclass class SdyDimSharding: axes: Sequence[str] - is_closed: bool + is_open: bool priority: int | None = None def build(self) -> sdy.DimensionShardingAttr: return sdy.DimensionShardingAttr.get( [sdy.AxisRefAttr.get(axis) for axis in self.axes], - is_closed=self.is_closed, - priority=self.priority) + is_closed=not self.is_open, priority=self.priority) def __repr__(self): return f'SdyDimSharding({self._custom_repr()})' @@ -289,7 +282,7 @@ def __repr__(self): def _custom_repr(self): axes_repr = ', '.join(f"'{a}'" for a in self.axes) open_repr = '' - if not self.is_closed: + if self.is_open: open_repr = ', ?' if self.axes else '?' priority_repr = '' if self.priority is None else f'p{self.priority}' return f'{{{axes_repr}{open_repr}}}{priority_repr}' @@ -325,80 +318,6 @@ def __repr__(self): if self.replicated_axes else '') return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})" -# TODO(yashkatariya): Remove this after jax 0.5.2 release -class ParsedPartitionSpec: - __slots__ = ('_user_spec', 'partitions') - - _user_spec: PartitionSpec | None - partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...] - - def __init__(self, user_spec, partitions): - self._user_spec = user_spec - assert None not in partitions, partitions - self.partitions = tuple(partitions) - - def get_partition_spec(self) -> PartitionSpec: - if isinstance(self._user_spec, PartitionSpec): - return self._user_spec - else: - return get_single_pspec(self) - - def insert_axis_partitions(self, dim, val): - parts = self.partitions - too_short = dim - len(parts) - if too_short > 0: - parts += ((),) * too_short - new_partitions = tuple_insert(parts, dim, val) - return ParsedPartitionSpec(None, new_partitions) - - @classmethod - def from_user_input( - cls, - entry: PartitionSpec | None, - arg_name: str, - allow_unconstrained_dims: bool = False, - ) -> ParsedPartitionSpec: - if entry is None: - return cls(entry, ()) - if not isinstance(entry, PartitionSpec): - raise TypeError(f"{arg_name} are expected to be " - f"PartitionSpec instances or None, but got {entry}") - axis_specs = [] - for axis_spec in entry: - if axis_spec is None: - axis_spec = () - elif isinstance(axis_spec, (list, tuple)): - axis_spec = tuple(axis_spec) - elif axis_spec is PartitionSpec.UNCONSTRAINED: - if not allow_unconstrained_dims: - raise ValueError(f"Unconstrained dims are not allowed: {entry}") - axis_spec = PartitionSpec.UNCONSTRAINED - else: - axis_spec = (axis_spec,) - axis_specs.append(axis_spec) - new_entry = PartitionSpec( - *[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry]) - return cls(new_entry, axis_specs) - - def __hash__(self): - return hash(self.partitions) - - def __eq__(self, other): - if not isinstance(other, ParsedPartitionSpec): - return False - return self.partitions == other.partitions - - def __len__(self): - return len(self.partitions) - - def __getitem__(self, i): - return self.partitions[i] - - def __iter__(self): - return iter(self.partitions) - - def __repr__(self): - return f"ParsedPartitionSpec(partitions={self.partitions})" @cache(max_size=4096, trace_context_in_key=False) def named_sharding_to_xla_hlo_sharding( @@ -408,9 +327,7 @@ def named_sharding_to_xla_hlo_sharding( mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)} special_axes = {} - mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items() - if t == mesh_lib.AxisType.Manual} - manual_axes = self._manual_axes.union(mesh_manual_axes) + 
manual_axes = frozenset(self.mesh.manual_axes) if manual_axes: axis_names = self.mesh.axis_names for manual_axis in manual_axes: @@ -491,21 +408,11 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): partitions.append(None) return PartitionSpec(*partitions) -get_single_pspec = lambda p: array_mapping_to_axis_resources(get_array_mapping(p)) # type: ignore - -# TODO(yashkatariya): Remove this after jax 0.5.2 release -def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()): - if parsed_pspec is None: - spec = PartitionSpec() if spec is None else spec - parsed_pspec = ParsedPartitionSpec.from_user_input( - spec, "NamedSharding spec", allow_unconstrained_dims=True) - _check_unique_resources(parsed_pspec, "NamedSharding spec", mesh) - _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) - return parsed_pspec +@cache(max_size=128, trace_context_in_key=False) def check_pspec(mesh, spec, _manual_axes=frozenset()): _check_unique_resources(spec, "NamedSharding spec", mesh) - _check_mesh_resource_axis(mesh, spec, _manual_axes) + _check_mesh_resource_axis(mesh, spec) class DuplicateSpecError(Exception): def __init__(self, message, mesh, pspec): @@ -517,13 +424,10 @@ def __init__(self, message, mesh, pspec): def __str__(self): return f"{self.message}" -def _check_unique_resources( - pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None, -) -> None: +def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None + ) -> None: resource_counts: dict[MeshAxisName, int] = {} duplicate = False - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) for d in pspec: if d is PartitionSpec.UNCONSTRAINED or d is None: continue @@ -542,10 +446,8 @@ def _check_unique_resources( f' for {mesh_lib.show_axes(multiple_uses)}'), mesh=mesh, pspec=pspec) -@cache(max_size=128, trace_context_in_key=False) -def _check_mesh_resource_axis(mesh, pspec, _manual_axes): - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) + +def _check_mesh_resource_axis(mesh, pspec): for p in pspec: if p is PartitionSpec.UNCONSTRAINED or p is None: continue @@ -555,10 +457,6 @@ def _check_mesh_resource_axis(mesh, pspec, _manual_axes): raise ValueError( f"Resource axis: {r} of {pspec} " f"is not found in mesh: {tuple(mesh.shape.keys())}.") - if r in _manual_axes: - raise ValueError( - f"Axis: {r} of {pspec} " - f"is also found in manual_axes: {_manual_axes}.") from None if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p): raise ValueError( 'AxisTypes should be the same in a tuple subset of PartitionSpec:' diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 7df0a638e566..640c8b89d001 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -22,6 +22,7 @@ import math import numpy as np from typing import Any, List, Literal +import warnings import jax import jax.numpy as jnp @@ -47,13 +48,26 @@ from jax._src.ops.special import logsumexp as _logsumexp -class Unspecified: - def __repr__(self): - return "_UNSPECIFIED" -_UNSPECIFIED = Unspecified() +# activations +@jax.jit +def identity(x: ArrayLike) -> Array: + r"""Identity activation function. + Returns the argument unmodified. -# activations + Args: + x : input array + + Returns: + The argument `x` unmodified. + + Examples: + >>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) + Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. 
], dtype=float32) + + """ + numpy_util.check_arraylike("identity", x) + return jnp.asarray(x) @custom_jvp @jax.jit @@ -505,8 +519,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array: @partial(jax.jit, static_argnames=("axis",)) def log_softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: Unspecified = _UNSPECIFIED) -> Array: + where: ArrayLike | None = None) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -532,10 +545,6 @@ def log_softmax(x: ArrayLike, See also: :func:`softmax` """ - # TODO(jakevdp): remove the initial argument after JAX v0.4.40. - if initial is not _UNSPECIFIED: - raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.") - del initial numpy_util.check_arraylike("log_softmax", x) x_arr = jnp.asarray(x) x_max = jnp.max(x_arr, axis, where=where, initial=-jnp.inf, keepdims=True) @@ -553,8 +562,7 @@ def log_softmax(x: ArrayLike, # @partial(jax.jit, static_argnames=("axis",)) def softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: Unspecified = _UNSPECIFIED) -> Array: + where: ArrayLike | None = None) -> Array: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -580,10 +588,6 @@ def softmax(x: ArrayLike, See also: :func:`log_softmax` """ - # TODO(jakevdp): remove the initial argument after JAX v0.4.40. - if initial is not _UNSPECIFIED: - raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.") - del initial if config.softmax_custom_jvp.value: # mypy is confused by the `functools.partial` application in the definition # of `_softmax` and incorrectly concludes that `_softmax` returns @@ -1197,47 +1201,139 @@ def scaled_matmul( rhs_scales: Array, preferred_element_type: DTypeLike = jnp.float32, ) -> Array: - r""" - Performs scaled matrix multiplication between two 3D arrays, with scaling - factors applied to the matrices. - .. math:: - \mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs) + r"""Scaled matrix multiplication function. + + Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`. + The last dim is the contracting dim, and block size is inferred. + + Mathematically, this operation is equivalent to:: + + a_block_size = a.shape[-1] // a_scales.shape[-1] + b_block_size = b.shape[-1] // b_scales.shape[-1] + a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1) + b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1) + jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled) + Args: - lhs (Array): A 3D array of shape (B, M, K). - rhs (Array): A 3D array of shape (B, N, K). - lhs_scales (Array): A 3D array of shape (B, M, K_block). - rhs_scales (Array): A 3D array of shape (B, N, K_block). - preferred_element_type (DTypeLike, optional): The preferred data type - for the computation. Defaults to `jnp.float32`. + lhs (Array): Operand a, shape (B, M, K). + rhs (Array): Operand b, shape (B, N, K). + lhs_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`. + rhs_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`. + preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`. + Returns: - Array: A 3D array of shape (B, M, N) representing the scaled matrix - multiplication result. 
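# Editor's sketch (not part of the diff): a plain-JAX reference for the block-scaled
# matmul semantics documented above. The helper name `_reference_scaled_matmul` is
# hypothetical; it upcasts to float32 (matching the fixed compute type noted in the
# docstring) and can be used to sanity-check the cuDNN-backed jax.nn.scaled_matmul.
import jax.numpy as jnp

def _reference_scaled_matmul(a, b, a_scales, b_scales):
  # Block sizes are inferred from the ratio of contracting dims, as described above.
  a_block = a.shape[-1] // a_scales.shape[-1]
  b_block = b.shape[-1] // b_scales.shape[-1]
  a_scaled = a.astype(jnp.float32) * jnp.repeat(
      a_scales.astype(jnp.float32), a_block, axis=-1)
  b_scaled = b.astype(jnp.float32) * jnp.repeat(
      b_scales.astype(jnp.float32), b_block, axis=-1)
  return jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)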
- Raises: - AssertionError: If the number of columns in `lhs` (`lhs_K`) does not - match the number of columns in `rhs` (`rhs_K`). + Array of shape (B, M, N). + Notes: - - The function ensures that the `preferred_element_type` is - danonicalized before passing it to the underlying computation. - - Scaling is applied to the matrices based on the `lhs_scales` and - `rhs_scales` arrays, enabling efficient computations in blocks. + - We currently do not support user-defined `precision` for customizing the + compute data type. It is fixed to `jnp.float32`. + - Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`. + - To use cuDNN with Nvidia Blackwell GPUs, inputs must match:: + + # mxfp8 + a, b: jnp.float8_e4m3fn | jnp.float8_e5m2 + a_scales, b_scales: jnp.float8_e8m0fnu + block_size: 32 + # nvfp4 + a, b: jnp.float4_e2m1fn + a_scales, b_scales: jnp.float8_e4m3fn + block_size: 16 + + Examples: + + Basic case: + + >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) + >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) + >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> scaled_matmul(a, b, a_scales, b_scales) # doctest: +SKIP + Array([[[8.]]], dtype=float32) + + Using fused cuDNN call on Blackwell GPUs: + + >>> dtype = jnp.float8_e4m3fn + >>> a = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64), dtype=dtype) + >>> b = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64), dtype=dtype) + >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> scaled_matmul(a, b, a_scales, b_scales) # doctest: +SKIP """ - B, M, lhs_K = lhs.shape - _, N, rhs_K = rhs.shape - assert lhs_K == rhs_K - _, _, K_block = lhs_scales.shape + a, b, a_scales, b_scales = lhs, rhs, lhs_scales, rhs_scales + if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)): + raise ValueError( + "scaled_matmul requires all inputs to be 3-dimensional arrays" + ) + + B_a, M_a, K_a = a.shape + B_b, N_b, K_b = b.shape + if K_a != K_b or B_a != B_b: + raise ValueError( + "scaled_matmul requires inputs a and b to have matching batch (B) " + f"and contract (K) dimensions, but got shapes {a.shape} and " + f"{b.shape}" + ) + + B_as, M_as, K_as = a_scales.shape + B_bs, N_bs, K_bs = b_scales.shape + if K_as != K_bs or B_as != B_bs: + raise ValueError( + "scaled_matmul requires scales to have matching batch (B) and " + f"contract (K) dimensions, but got shapes {a_scales.shape} and " + f"{b_scales.shape}" + ) + + if M_as != M_a or N_bs != N_b: + raise ValueError( + "scaled_matmul requires scales to match non-contract dimensions of " + f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: " + f"{a_scales.shape}, b_scales: {b_scales.shape}" + ) preferred_element_type = dtypes.canonicalize_dtype( np.dtype(preferred_element_type) ) out = cudnn_scaled_matmul( - lhs, - rhs, - lhs_scales, - rhs_scales, + a, + b, + a_scales, + b_scales, preferred_element_type=preferred_element_type, ) return out +def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'], + global_scale: Array | None = None): + r"""Get quantization configs for scaled_dot_general. + + Create quantization configs for the `jax.nn.scaled_dot_general`. + + See Also: + - :func:`jax.nn.scaled_dot_general`: Scaled dot general function. 
+ """ + + if mode == 'nvfp4': + one = jnp.ones((1,), dtype=jnp.float32) + return BlockScaleConfig( + mode='nvfp4', + block_size=16, + data_type=jnp.float4_e2m1fn, + scale_type=jnp.float8_e4m3fn, + global_scale=one if global_scale is None else global_scale, + infer_only=False + ) + elif mode == 'mxfp8': + return BlockScaleConfig( + mode='mxfp8', + block_size=32, + data_type=jnp.float8_e4m3fn, + scale_type=jnp.float8_e8m0fnu, + global_scale=None, + infer_only=False + ) + else: + raise ValueError(f"Unsupported mode: {mode}") + def scaled_dot_general( lhs, rhs, dimension_numbers, @@ -1246,52 +1342,78 @@ def scaled_dot_general( implementation: Literal['cudnn'] | None = None, ): r"""Scaled dot general operation. - Computes the scaled dot general on lhs, rhs with quanitzation specified by configs: - .. math:: - \widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\ - \widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\ - \mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs}) + + Performs a generalized dot product with block-scaled quantization on the + lhs and rhs inputs. This operation extends `lax.dot_general` to support + user-defined scaling configurations. + + Essentially, the operation follows:: + + a, a_scales = quantize(lhs, configs[0]) + b, b_scales = quantize(rhs, configs[1]) + c = jax.nn.scaled_matmul(a, b, a_scales, b_scales) + Args: - lhs: Left-hand side input tensor. - rhs: Right-hand side input tensor. - dimension_numbers: A tuple specifying the contraction and batch dimensions - for the dot general operation. Must follow the format: - `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. - preferred_element_type: The preferred output data type. Supported types are - `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`. - configs: A list of `BlockScaleConfig` specifying the scaling - configurations for the operation. Defaults to `mxfp8`. - implementation: A string to control which implementation backend to use. - Supported strings are `cudnn` (cuDNN block scaled dot). It defaults - to `None`, which will automatically select the best available backend. + lhs (ArrayLike): Input array. + rhs (ArrayLike): Input array. + dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying + the contraction and batch dimensions: + `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. + preferred_element_type (DTypeLike, optional): Output data type of the dot + product. Defaults to `jnp.float32`. Other valid types include + `jnp.bfloat16` and `jnp.float16`. + configs (list of BlockScaleConfig, optional): Scaling configurations for + lhs, rhs, and gradients. Users can obtain valid configurations via + `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8` + are supported. If `None`, falls back to `lax.dot_general`. + implementation: str + (Deprecated) Backend selector, now ignored. The system chooses the backend + automatically. Scheduled for removal in future releases. + Returns: - The result of the scaled dot general operation. + Array: The resulting tensor, with batch dimensions first, followed by + non-contracting/non-batch dimensions of lhs, and then those of rhs. + + See Also: + - :func:`jax.nn.scaled_matmul`: Scaled matmul function. + - :func:`jax.lax.dot_general`: General dot product operator. 
+ + Notes: + - Unlike `nn.scaled_matmul`, which assumes quantized low-precision + inputs with explicit scaling factors, this operator takes high-precision + inputs, applies quantization internally, and handles the backward pass. + + Examples: + + Creating config for mxfp8: + + >>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3 + + Creating config for nvfp4: + + >>> global_scale = jnp.array([0.5], jnp.float32) + >>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3 + + Using scaled_dot_general with the configs: + + >>> import functools + >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs) + >>> lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64)) + >>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64)) + >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,)))) # doctest: +SKIP """ - # Create configs if not provided - if configs is None: - if dtypes.float8_e8m0fnu is None: - raise ValueError("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu") - mxfp8_config = BlockScaleConfig( - mode='mxfp8', - block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, - global_scale=None, - infer_only=False - ) - configs = [mxfp8_config for _ in range(3)] + if implementation is not None: + warnings.warn("Backend selector, now ignored. The system chooses the " + "backend automatically.", DeprecationWarning) - if implementation is None: - implementation = 'cudnn' + if configs is None: + return lax.dot_general(lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type) - match implementation: - case 'cudnn': - out = cudnn_scaled_dot_general( - lhs, rhs, dimension_numbers, - preferred_element_type=preferred_element_type, - configs=configs - ) - case _: - raise ValueError(f"Unsupported implementation option: {implementation}") + out = cudnn_scaled_dot_general( + lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type, + configs=configs + ) return out diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index 4a01f579a67e..d634a2856a1b 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -24,10 +24,12 @@ import jax from jax._src.sharding import Sharding from jax._src.lib import xla_client as xc -from jax._src import dtypes as _dtypes, config +from jax._src import config +from jax._src import dtypes as _dtypes +from jax._src import xla_bridge as xb -__array_api_version__ = '2023.12' +__array_api_version__ = '2024.12' def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType: @@ -51,8 +53,9 @@ class ArrayNamespaceInfo: .. 
_Python array API: https://data-apis.org/array-api/ """ _capabilities = { - "boolean indexing": True, - "data-dependent shapes": False, + "boolean indexing": False, # within transformations + "data-dependent shapes": False, # within transformations + "max dimensions": 64, # XLA limitation } def _build_dtype_dict(self): @@ -72,7 +75,10 @@ def default_device(self): return None def devices(self): - return jax.devices() + out = [None] # None indicates "uncommitted" + for backend in xb.backends(): + out.extend(jax.devices(backend)) + return out def capabilities(self): return self._capabilities diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py index 67418e7322c9..4f07f94fe8b4 100644 --- a/jax/_src/numpy/array_creation.py +++ b/jax/_src/numpy/array_creation.py @@ -50,7 +50,8 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. @@ -87,7 +88,8 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. @@ -126,7 +128,8 @@ def empty(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. 
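# Editor's illustration (not part of the diff): the default-dtype wording added to the
# zeros/ones/empty docstrings above. These constructors default to float32; if 64-bit
# mode is enabled (e.g. jax.config.update("jax_enable_x64", True) before any arrays are
# created), the default becomes float64. Pass dtype=... to override in either mode.
import jax.numpy as jnp

print(jnp.zeros(3).dtype)                    # float32 under the default configuration
print(jnp.ones((2, 2)).dtype)                # float32 as well
print(jnp.empty(4, dtype=jnp.int32).dtype)   # an explicit dtype is respected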
@@ -244,6 +247,8 @@ def zeros_like(a: ArrayLike | DuckTypedArray, [0, 0, 0]], dtype=int32) """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() util.check_arraylike("zeros_like", a) dtypes.check_user_dtype_supported(dtype, "zeros_like") if shape is not None: @@ -287,6 +292,8 @@ def ones_like(a: ArrayLike | DuckTypedArray, [1, 1, 1]], dtype=int32) """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() util.check_arraylike("ones_like", a) dtypes.check_user_dtype_supported(dtype, "ones_like") if shape is not None: @@ -332,9 +339,13 @@ def empty_like(prototype: ArrayLike | DuckTypedArray, [0, 0, 0]], dtype=int32) """ if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing - util.check_arraylike("empty_like", prototype) - dtypes.check_user_dtype_supported(dtype, "empty_like") - return zeros_like(prototype, dtype=dtype, shape=shape, device=device) + if hasattr(prototype, '__jax_array__'): + prototype = prototype.__jax_array__() + util.check_arraylike("ones_like", prototype) + dtypes.check_user_dtype_supported(dtype, "ones_like") + if shape is not None: + shape = canonicalize_shape(shape) + return lax.full_like(prototype, 0, dtype, shape, sharding=util.normalize_device_to_sharding(device)) @export @@ -382,6 +393,8 @@ def full_like(a: ArrayLike | DuckTypedArray, util.check_arraylike("full_like", 0, fill_value) else: util.check_arraylike("full_like", a, fill_value) + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() dtypes.check_user_dtype_supported(dtype, "full_like") if shape is not None: shape = canonicalize_shape(shape) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index e9e097c85aff..0d7c50ee3358 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -588,7 +588,7 @@ def deferring_binary_op(self, other): def _unimplemented_setitem(self, i, x): msg = ("JAX arrays are immutable and do not support in-place item assignment." " Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method:" - " https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html") + " https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html") raise TypeError(msg.format(type(self))) def _operator_round(number: ArrayLike, ndigits: int | None = None) -> Array: diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 9d745643b596..21333a9e7a0d 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -288,6 +288,10 @@ def einsum( spec = operands[0] if isinstance(operands[0], str) else None path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize + # Extract __jax_array__ before passing to contract_path() + operands = tuple(op.__jax_array__() if hasattr(op, "__jax_array__") else op + for op in operands) + # Allow handling of shape polymorphism non_constant_dim_types = { type(d) for op in operands if not isinstance(op, str) diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py new file mode 100644 index 000000000000..e2c23b43bdf8 --- /dev/null +++ b/jax/_src/numpy/error.py @@ -0,0 +1,203 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +from typing import Literal, Sequence + +import jax +from jax._src import config +from jax._src.typing import ArrayLike + +Category = Literal["nan", "divide", "oob"] + + +def _is_category_disabled( + category: Category | None, +) -> bool: + """Check if the error checking behavior for the given category is disabled.""" + if category is None: + return False + if category == "nan": + raise ValueError("nan is deprecated. Use `_set_error_if_nan` instead.") + if category == "divide": + raise ValueError( + "divide is deprecated. Use `_set_error_if_divide_by_zero` instead." + ) + if category == "oob": + return config.error_checking_behavior_oob.value == "ignore" + raise ValueError(f"Invalid category: {category}") + + +def _set_error_if_with_category( + pred: jax.Array, + /, + msg: str, + category: Category | None = None, +) -> None: + """Set the internal error state if any element of `pred` is `True`. + + This function is similar to :func:`set_error_if`, but it also takes a category + argument. The category can be "nan", "divide", or "oob". The error checking + behavior for each category can be configured using + :func:`set_error_checking_behavior`. If not provided, there will be no + category. + + This function is intended for use in JAX internal APIs (e.g., `jax.numpy`) + to perform category-specific runtime checks tied to the operation being + performed. + """ + if _is_category_disabled(category): + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + error_check_lib.set_error_if(pred, msg) + + +def _set_error_if_nan(pred: jax.Array, /): + """Set the internal error state if any element of `pred` is `NaN`. + + This function is disabled if the `jax_error_checking_behavior_nan` flag is + set to "ignore". + """ + if config.error_checking_behavior_nan.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. + import jax.numpy as jnp + if not jnp.issubdtype(pred.dtype, jnp.floating): # only check floats + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + error_check_lib.set_error_if(jnp.isnan(pred), "NaN encountered") + + +def _set_error_if_divide_by_zero(pred: jax.Array, /): + """Set the internal error state if any element of `pred` is zero. + + This function is intended for checking if the denominator of a division is + zero. + + This function is disabled if the `jax_error_checking_behavior_divide` flag is + set to "ignore". + """ + if config.error_checking_behavior_divide.value == "ignore": + return + + # TODO(ayx): fix the circular import issue. 
+ from jax._src import error_check as error_check_lib + import jax.numpy as jnp + zero = jnp.zeros_like(pred, shape=()) + error_check_lib.set_error_if(pred == zero, "Division by zero encountered") + + +def _check_precondition_oob_gather( + shape: tuple[int, ...], gather_indices: ArrayLike +) -> None: + """Check for out of bounds errors before calling `lax.gather`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + + shape = jnp.array(shape, dtype=jnp.int32) + error_check_lib.set_error_if( + jnp.logical_or( + jnp.min(gather_indices) < -shape, + jnp.max(gather_indices) >= shape, + ), + "Out of bounds encountered before calling `lax.gather`", + ) + + +def _check_precondition_oob_dynamic_slice( + shape: tuple[int, ...], + start_indices: Sequence[ArrayLike], + slice_sizes: list[int], + allow_negative_indices: list[bool], +) -> None: + """Check for out of bounds errors before calling `lax.dynamic_slice`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + + shape = jnp.array(shape, dtype=jnp.int32) + start_indices = jnp.array(start_indices, dtype=jnp.int32) + slice_sizes = jnp.array(slice_sizes, dtype=jnp.int32) + allow_negative_indices = jnp.array(allow_negative_indices, dtype=jnp.bool_) + + lower_bound = jnp.where(allow_negative_indices, -shape, 0) + error_check_lib.set_error_if( + jnp.logical_or( + jnp.minimum(start_indices, start_indices + slice_sizes) < lower_bound, + jnp.maximum(start_indices, start_indices + slice_sizes) >= shape, + ), + "Out of bounds encountered before calling `lax.dynamic_slice`", + ) + + +Behavior = Literal["ignore", "raise"] + + +class error_checking_behavior: + """A context manager to set the error checking behavior. + + If both `all` and a category are provided, the category will override the + `all` setting. + + When the error checking behavior is set to "ignore", all errors will be + ignored. When set to "raise", errors will be detected and recorded, but an + exception will not be raised immediately. Users must call + :func:`raise_if_error` to at the end of the computation to raise the + exception. 
+ """ + + def __init__( + self, + *, + all: Behavior | None = None, + nan: Behavior | None = None, + divide: Behavior | None = None, + oob: Behavior | None = None, + ) -> None: + new_settings = {} + if all is not None: + new_settings["nan"] = new_settings["divide"] = new_settings["oob"] = all + if nan is not None: + new_settings["nan"] = nan + if divide is not None: + new_settings["divide"] = divide + if oob is not None: + new_settings["oob"] = oob + self.new_settings = new_settings + self.stack = contextlib.ExitStack() + + def __enter__(self): + config_flags = { + "nan": config.error_checking_behavior_nan, + "divide": config.error_checking_behavior_divide, + "oob": config.error_checking_behavior_oob, + } + for key, value in self.new_settings.items(): + self.stack.enter_context(config_flags[key](value)) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stack.close() diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 5d59bb53b457..6982cc4080e6 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -20,8 +20,6 @@ import string from typing import Any, NamedTuple, Sequence -import numpy as np - import jax from jax import lax from jax._src import array @@ -30,17 +28,19 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import errors +from jax._src import mesh as mesh_lib from jax._src.api import jit from jax._src.lax import lax as lax_internal from jax._src.numpy import einsum -from jax._src import mesh as mesh_lib -from jax._src.pjit import auto_axes +from jax._src.numpy import error as jnp_error from jax._src.numpy import lax_numpy from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.pjit import auto_axes from jax._src.tree_util import tree_flatten from jax._src.typing import Array, ArrayLike, StaticScalar -from jax._src.util import canonicalize_axis, set_module, tuple_replace, safe_zip +from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update +import numpy as np export = set_module('jax.numpy') @@ -397,7 +397,9 @@ def replace(tup, val): def _make_along_axis_idx(shape, indices, axis): - return tuple_replace(lax_numpy.indices(shape, sparse=True), axis, indices) + if axis < 0: + axis += len(shape) + return tuple_update(lax_numpy.indices(shape, sparse=True), axis, indices) @export @@ -526,8 +528,6 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> if not all(isinstance(i, int) for i in arr.shape): return None - if len(idx) > arr.ndim: - return None if any(i is None for i in idx): return None # TODO(jakevdp): handle newaxis case # For symbolic dimensions fallback to gather @@ -535,10 +535,13 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> for i in idx if isinstance(i, slice) for elt in (i.start, i.stop, i.step)): return None - if any(i is Ellipsis for i in idx): - # Remove ellipses and add trailing `slice(None)`. + # Remove ellipses and pad with trailing `slice(None)` if necessary. + # Do this before checking against rank of `arr` so that `...` can + # count as no dimensions at all (e.g. 
`my_1d_array[:, ...]` succeeds) idx = _canonicalize_tuple_index(arr.ndim, idx=idx) + if len(idx) > arr.ndim: + return None simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)} int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape)) @@ -570,7 +573,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> idx += (arr.ndim - len(idx)) * (slice(None),) start_indices: Sequence[ArrayLike] = [] - slice_sizes: Sequence[int] = [] + slice_sizes: list[int] = [] allow_negative_indices: list[bool] = [] for ind, size in safe_zip(idx, arr.shape): @@ -587,6 +590,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> slice_sizes.append(1) allow_negative_indices.append( not isinstance(ind, (int, np.integer)) or bool(ind < 0)) + # Try to use static slicing when possible. if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices): int_start_indices = [int(i) for i in start_indices] # type: ignore @@ -598,6 +602,9 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> # start indices to have matching types. if len(start_indices) > 1: start_indices = util.promote_dtypes(*start_indices) + jnp_error._check_precondition_oob_dynamic_slice( + arr.shape, start_indices, slice_sizes, allow_negative_indices + ) arr = lax.dynamic_slice( arr, start_indices=start_indices, slice_sizes=slice_sizes, allow_negative_indices=allow_negative_indices) @@ -640,6 +647,7 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value, out_sharding): idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) indexer = index_to_gather(np.shape(arr), idx) # shared with _scatter_update + jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) y = arr if fill_value is not None: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 96efc48062e1..641422fceef3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -49,24 +49,24 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc -from jax._src.numpy.array_creation import (empty, empty_like, full, - ones, ones_like, zeros, zeros_like) from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.numpy.array_creation import (empty, empty_like, full, + ones, ones_like, zeros, zeros_like) from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize +from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( - Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape + Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape, SupportsShape ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding -from jax._src.sharding_impls import SingleDeviceSharding -from jax.tree_util import tree_leaves, tree_map +from jax.tree_util import tree_flatten, tree_map import numpy as np export = set_module('jax.numpy') @@ -552,7 +552,7 @@ def result_type(*args: Any) -> DType: For details on 64-bit values, refer to `Sharp bits - double precision`_: - .. 
_Sharp bits - double precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + .. _Sharp bits - double precision: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision """ return dtypes.result_type(*args) @@ -911,11 +911,11 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogram", a, bins) + a, _ = util.ensure_arraylike("histogram", a, bins) a, = util.promote_dtypes_inexact(a) weights = ones_like(a) else: - util.check_arraylike("histogram", a, bins, weights) + a, _, weights = util.ensure_arraylike("histogram", a, bins, weights) if np.shape(a) != np.shape(weights): raise ValueError("weights should have the same shape as a.") a, weights = util.promote_dtypes_inexact(a, weights) @@ -1005,7 +1005,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool) """ - util.check_arraylike("histogram2d", x, y) + x, y = util.ensure_arraylike("histogram2d", x, y) try: N = len(bins) # type: ignore[arg-type] except TypeError: @@ -1077,10 +1077,10 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogramdd", sample) + sample = util.ensure_arraylike("histogramdd", sample) sample, = util.promote_dtypes_inexact(sample) else: - util.check_arraylike("histogramdd", sample, weights) + sample, weights = util.ensure_arraylike("histogramdd", sample, weights) if np.shape(weights) != np.shape(sample)[:1]: raise ValueError("should have one weight for each sample.") sample, weights = util.promote_dtypes_inexact(sample, weights) @@ -1203,8 +1203,8 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: Array([[1, 3], [2, 4]], dtype=int32) """ - util.check_arraylike("transpose", a) - axes_ = list(range(np.ndim(a))[::-1]) if axes is None else axes + a = util.ensure_arraylike("transpose", a) + axes_ = list(range(a.ndim)[::-1]) if axes is None else axes axes_ = [_canonicalize_axis(i, np.ndim(a)) for i in axes_] return lax.transpose(a, axes_) @@ -1285,8 +1285,8 @@ def matrix_transpose(x: ArrayLike, /) -> Array: [[5, 7], [6, 8]]], dtype=int32) """ - util.check_arraylike("matrix_transpose", x) - ndim = np.ndim(x) + x = util.ensure_arraylike("matrix_transpose", x) + ndim = x.ndim if ndim < 2: raise ValueError(f"x must be at least two-dimensional for matrix_transpose; got {ndim=}") axes = (*range(ndim - 2), ndim - 1, ndim - 2) @@ -1944,8 +1944,7 @@ def isrealobj(x: Any) -> bool: @export def reshape( - a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, - newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), + a: ArrayLike, shape: DimSize | Shape, order: str = "C", *, copy: bool | None = None) -> Array: """Return a reshaped copy of an array. @@ -1962,8 +1961,6 @@ def reshape( JAX does not support ``order="A"``. copy: unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away. - newshape: deprecated alias of the ``shape`` argument. Will result in a - :class:`DeprecationWarning` if used. Returns: reshaped copy of input array with the specified shape. @@ -2021,14 +2018,6 @@ def reshape( __tracebackhide__ = True util.check_arraylike("reshape", a) - # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40. 
- if not isinstance(newshape, DeprecatedArg): - raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36." - " Use shape instead.") - if shape is None: - raise TypeError( - "jnp.shape requires passing a `shape` argument, but none was given." - ) try: # forward to method for ndarrays return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] @@ -2435,7 +2424,7 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: [2], [3]]]], dtype=int32) """ - util.check_arraylike("expand_dims", a) + a = util.ensure_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) return lax.expand_dims(a, axis) @@ -2825,7 +2814,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): (reverse-mode differentiation), a NaN in either ``x`` or ``y`` will propagate into the gradient, regardless of the value of ``condition``. More information on this behavior and workarounds is available in the `JAX FAQ - `_. + `_. Examples: When ``x`` and ``y`` are not provided, ``where`` behaves equivalently to @@ -2918,6 +2907,12 @@ def select( raise ValueError(msg.format(len(condlist), len(choicelist))) if len(condlist) == 0: raise ValueError("condlist must be non-empty") + + util.check_arraylike("select", *condlist, *choicelist, default) + condlist = [asarray(cond) for cond in condlist] + choicelist = [asarray(choice) for choice in choicelist] + default = asarray(default) + # Put the default at front with condition False because # argmax returns zero for an array of False values. choicelist = util.promote_dtypes(default, *choicelist) @@ -2989,7 +2984,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, >>> jnp.bincount(x, length=5) Array([2, 1, 0, 1, 0], dtype=int32) """ - util.check_arraylike("bincount", x) + x = util.ensure_arraylike("bincount", x) if _dtype(x) == bool: x = lax.convert_element_type(x, 'int32') if not issubdtype(_dtype(x), np.integer): @@ -4376,7 +4371,7 @@ def pad_func(row: Array, pad_width: tuple[int, int], Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32) """ - util.check_arraylike("pad", array) + array = util.ensure_arraylike("pad", array) pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width") if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1]) for p in pad_width): @@ -4471,7 +4466,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis = _canonicalize_axis(axis, arrays.ndim) return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype) else: - util.check_arraylike("stack", *arrays) + arrays = util.ensure_arraylike_tuple("stack", arrays) shape0 = np.shape(arrays[0]) axis = _canonicalize_axis(axis, len(shape0) + 1) new_arrays = [] @@ -4560,7 +4555,7 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: [1, 2], [3, 4]], dtype=int32) """ - util.check_arraylike("tile", A) + A = util.ensure_arraylike("tile", A) try: iter(reps) # type: ignore[arg-type] except TypeError: @@ -4633,7 +4628,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], """ if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) - util.check_arraylike("concatenate", *arrays) + arrays = util.ensure_arraylike_tuple("concatenate", arrays) if not len(arrays): raise ValueError("Need at least one array to concatenate.") if axis is None: @@ -4875,6 +4870,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. 
util.check_arraylike("dstack", *tup, emit_warning=True) + tup = util.ensure_arraylike_tuple("dstack", tup) arrs = [atleast_3d(m) for m in tup] return concatenate(arrs, axis=2, dtype=dtype) @@ -5022,7 +5018,7 @@ def choose(a, choices): """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.choose is not supported.") - util.check_arraylike('choose', a, *choices) + a, *choices = util.ensure_arraylike_tuple('choose', (a, *choices)) if not issubdtype(_dtype(a), np.integer): raise ValueError("`a` array must be integer typed") N = len(choices) @@ -5504,18 +5500,15 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, object = xc._xla.cuda_array_interface_to_buffer( cai=cai, gpu_backend=backend, device_id=device_id) - object = tree_map(lambda leaf: leaf.__jax_array__() - if hasattr(leaf, "__jax_array__") else leaf, object) - leaves = tree_leaves(object, is_leaf=lambda x: x is None) + leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): - # Added Nov 16 2023 - if deprecations.is_accelerated("jax-numpy-array-none"): - raise TypeError("None is not a valid value for jnp.array") - warnings.warn( - "None encountered in jnp.array(); this is currently treated as NaN. " - "In the future this will result in an error.", - FutureWarning, stacklevel=2) - leaves = tree_leaves(object) + raise ValueError("None is not a valid value for jnp.array") + leaves = [ + leaf + if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None + else leaf_jax_array() + for leaf in leaves + ] if dtype is None: # Use lattice_result_type rather than result_type to avoid canonicalization. # Otherwise, weakly-typed inputs would have their dtypes canonicalized. @@ -5530,8 +5523,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if not weak_type: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] + object = treedef.unflatten(leaves) out: ArrayLike - if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in @@ -5910,14 +5903,14 @@ def fromfile(*args, **kwargs): ``jnp.asarray(np.fromfile(...))`` instead, although care should be taken if ``np.fromfile`` is used within jax transformations because of its potential side-effect of consuming the file object; for more information see `Common Gotchas: Pure Functions - `_. + `_. """ raise NotImplementedError( "jnp.fromfile() is not implemented because it may be non-pure and thus unsafe for use " "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromfile(...)) " "instead, although care should be taken if np.fromfile is used within a jax transformations " "because of its potential side-effect of consuming the file object; for more information see " - "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") + "https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") @export @@ -5929,14 +5922,14 @@ def fromiter(*args, **kwargs): ``jnp.asarray(np.fromiter(...))`` instead, although care should be taken if ``np.fromiter`` is used within jax transformations because of its potential side-effect of consuming the iterable object; for more information see `Common Gotchas: Pure Functions - `_. + `_. 
""" raise NotImplementedError( "jnp.fromiter() is not implemented because it may be non-pure and thus unsafe for use " "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromiter(...)) " "instead, although care should be taken if np.fromiter is used within a jax transformations " "because of its potential side-effect of consuming the iterable object; for more information see " - "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") + "https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") @export @@ -6995,8 +6988,11 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32) """ - arr = util.ensure_arraylike("repeat", a) - core.is_dim(repeats) or util.check_arraylike("repeat", repeats) + if core.is_dim(repeats): + util.check_arraylike("repeat", a) + else: + util.check_arraylike("repeat", a, repeats) + arr = asarray(a) if axis is None: arr = arr.ravel() @@ -7564,7 +7560,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array @export -def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. JAX implementation of :func:`numpy.triu_indices_from`. @@ -7615,14 +7611,18 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.triu_indices_from(arr, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("triu_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) @export -def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. JAX implementation of :func:`numpy.tril_indices_from`. @@ -7673,7 +7673,11 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.tril_indices_from(arr, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("tril_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) @@ -7827,7 +7831,7 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: Array([0, 1], dtype=int32), Array([0, 1], dtype=int32)) """ - util.check_arraylike("diag_indices_from", arr) + arr = util.ensure_arraylike("diag_indices_from", arr) nd = np.ndim(arr) if not np.ndim(arr) >= 2: raise ValueError("input array must be at least 2-d") @@ -8243,6 +8247,9 @@ def delete( # Case 3: obj is an array # NB: pass both arrays to check for appropriate error message. util.check_arraylike("delete", a, obj) + # Can't use ensure_arraylike here because obj may be static. 
+ if hasattr(obj, "__jax_array__"): + obj = obj.__jax_array__() # Case 3a: unique integer indices; delete in a JIT-compatible way if issubdtype(_dtype(obj), np.integer) and assume_unique_indices: @@ -8773,6 +8780,7 @@ def argwhere( >>> jnp.argwhere(0) Array([], shape=(0, 0), dtype=int32) """ + a = util.ensure_arraylike("argwhere", a) result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value))) if np.ndim(a) == 0: return result[:0].reshape(result.shape[0], 0) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 23f2a58b09f6..146bbbda0213 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -72,8 +72,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 @export -@partial(jit, static_argnames=['upper']) -def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: +@partial(jit, static_argnames=['upper', 'symmetrize_input']) +def cholesky(a: ArrayLike, *, upper: bool = False, symmetrize_input: bool = True) -> Array: """Compute the Cholesky decomposition of a matrix. JAX implementation of :func:`numpy.linalg.cholesky`. @@ -98,6 +98,10 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: Must have shape ``(..., N, N)``. upper: if True, compute the upper Cholesky decomposition `U`. if False (default), compute the lower Cholesky decomposition `L`. + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: array of shape ``(..., N, N)`` representing the Cholesky decomposition @@ -135,7 +139,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: """ a = ensure_arraylike("jnp.linalg.cholesky", a) a, = promote_dtypes_inexact(a) - L = lax_linalg.cholesky(a) + L = lax_linalg.cholesky(a, symmetrize_input=symmetrize_input) return L.mT.conj() if upper else L @@ -821,7 +825,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). symmetrize_input: if True (default) then input is symmetrized, which leads - to better behavior under automatic differentiation. + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: A namedtuple ``(eigenvalues, eigenvectors)`` where @@ -863,8 +869,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, @export -@partial(jit, static_argnames=('UPLO',)) -def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: +@partial(jit, static_argnames=('UPLO', 'symmetrize_input')) +def eigvalsh(a: ArrayLike, UPLO: str | None = 'L', *, + symmetrize_input: bool = True) -> Array: """ Compute the eigenvalues of a Hermitian matrix. @@ -875,6 +882,10 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: or symmetric (if real) matrix. UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. 
Returns: An array of shape ``(..., M)`` containing the eigenvalues, sorted in @@ -894,7 +905,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ a = ensure_arraylike("jnp.linalg.eigvalsh", a) a, = promote_dtypes_inexact(a) - w, _ = eigh(a, UPLO) + w, _ = eigh(a, UPLO, symmetrize_input=symmetrize_input) return w diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 985b296bc06f..96b2782edc13 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -30,7 +30,7 @@ from jax._src import deprecations from jax._src import dtypes from jax._src.numpy.util import ( - _broadcast_to, check_arraylike, _complex_elem_type, + _broadcast_to, check_arraylike, _complex_elem_type, ensure_arraylike, promote_dtypes_inexact, promote_dtypes_numeric, _where) from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg @@ -1992,7 +1992,7 @@ def _cumulative_reduction( fill_nan: bool = False, fill_value: ArrayLike = 0, promote_integers: bool = False) -> Array: """Helper function for implementing cumulative reductions.""" - check_arraylike(name, a) + a = ensure_arraylike(name, a) if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported") dtypes.check_user_dtype_supported(dtype, name) @@ -2242,8 +2242,7 @@ def cumulative_sum( Array([[ 0, 1, 3, 6], [ 0, 4, 9, 15]], dtype=int32) """ - check_arraylike("cumulative_sum", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_sum", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative sum, however a " @@ -2304,8 +2303,7 @@ def cumulative_prod( Array([[ 1, 1, 2, 6], [ 1, 4, 20, 120]], dtype=int32) """ - check_arraylike("cumulative_prod", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_prod", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative product, however a " diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 2f9954488b41..2b0e04adc997 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -68,33 +68,27 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: return meta bool_ = _make_scalar_type(np.bool_) -if dtypes.uint2 is not None: - uint2 = _make_scalar_type(dtypes.uint2) +uint2 = _make_scalar_type(dtypes.uint2) uint4 = _make_scalar_type(dtypes.uint4) uint8 = _make_scalar_type(np.uint8) uint16 = _make_scalar_type(np.uint16) uint32 = _make_scalar_type(np.uint32) uint64 = _make_scalar_type(np.uint64) -if dtypes.int2 is not None: - int2 = _make_scalar_type(dtypes.int2) +int2 = _make_scalar_type(dtypes.int2) int4 = _make_scalar_type(dtypes.int4) int8 = _make_scalar_type(np.int8) int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) -if dtypes.float8_e3m4 is not None: - float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) -if dtypes.float8_e4m3 is not None: - float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) -if dtypes.float8_e8m0fnu is not None: - float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) +float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) +float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) +float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = 
_make_scalar_type(dtypes.float8_e5m2) float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz) float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz) -if dtypes.float4_e2m1fn is not None: - float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) bfloat16 = _make_scalar_type(dtypes.bfloat16) float16 = _make_scalar_type(np.float16) float32 = single = _make_scalar_type(np.float32) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 91191d24a12e..509b046554d3 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -32,12 +32,13 @@ from jax._src.lax import lax from jax._src.lax import other as lax_other from jax._src.typing import Array, ArrayLike +from jax._src.numpy import error as jnp_error +from jax._src.numpy import reductions +from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy.util import ( - check_arraylike, promote_args, promote_args_inexact, + check_arraylike, ensure_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, check_no_float0s) -from jax._src.numpy.ufunc_api import ufunc -from jax._src.numpy import reductions from jax._src.util import set_module @@ -486,7 +487,9 @@ def log(x: ArrayLike, /) -> Array: >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2)) Array(True, dtype=bool) """ - return lax.log(*promote_args_inexact('log', x)) + out = lax.log(*promote_args_inexact('log', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -572,7 +575,9 @@ def log1p(x: ArrayLike, /) -> Array: >>> jnp.expm1(jnp.log(x1+1)) # doctest: +SKIP Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32) """ - return lax.log1p(*promote_args_inexact('log1p', x)) + out = lax.log1p(*promote_args_inexact('log1p', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -604,7 +609,9 @@ def sin(x: ArrayLike, /) -> Array: ... print(jnp.sin(x)) [ 0.707 1. 0.707 -0. ] """ - return lax.sin(*promote_args_inexact('sin', x)) + out = lax.sin(*promote_args_inexact('sin', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -635,7 +642,9 @@ def cos(x: ArrayLike, /) -> Array: ... print(jnp.cos(x)) [ 0.707 -0. -0.707 -0.866] """ - return lax.cos(*promote_args_inexact('cos', x)) + out = lax.cos(*promote_args_inexact('cos', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -666,7 +675,9 @@ def tan(x: ArrayLike, /) -> Array: ... print(jnp.tan(x)) [ 0. 0.577 1. -1. -0.577] """ - return lax.tan(*promote_args_inexact('tan', x)) + out = lax.tan(*promote_args_inexact('tan', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -708,7 +719,9 @@ def arcsin(x: ArrayLike, /) -> Array: ... jnp.arcsin(3+4j) Array(0.634+2.306j, dtype=complex64, weak_type=True) """ - return lax.asin(*promote_args_inexact('arcsin', x)) + out = lax.asin(*promote_args_inexact('arcsin', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -751,7 +764,9 @@ def arccos(x: ArrayLike, /) -> Array: ... jnp.arccos(4-1j) Array(0.252+2.097j, dtype=complex64, weak_type=True) """ - return lax.acos(*promote_args_inexact('arccos', x)) + out = lax.acos(*promote_args_inexact('arccos', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1005,6 +1020,7 @@ def arccosh(x: ArrayLike, /) -> Array: # Note: arccosh is multi-valued for complex input, and lax.acosh # uses a different convention than np.arccosh. 
result = lax.acosh(*promote_args_inexact("arccosh", x)) + jnp_error._set_error_if_nan(result) if dtypes.issubdtype(result.dtype, np.complexfloating): result = _where(real(result) < 0, lax.neg(result), result) return result @@ -1110,7 +1126,9 @@ def arctanh(x: ArrayLike, /) -> Array: ... jnp.arctanh(x1) Array([-0.549+1.571j, 0.347+1.571j, 0.239-1.509j], dtype=complex64) """ - return lax.atanh(*promote_args_inexact('arctanh', x)) + out = lax.atanh(*promote_args_inexact('arctanh', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1143,7 +1161,9 @@ def sqrt(x: ArrayLike, /) -> Array: >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True) """ - return lax.sqrt(*promote_args_inexact('sqrt', x)) + out = lax.sqrt(*promote_args_inexact('sqrt', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1212,7 +1232,11 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array: Array([10, 11, 12, 13], dtype=int32) """ x, y = promote_args("add", x, y) - return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + if x.dtype == bool: + return lax.bitwise_or(x, y) + out = lax.add(x, y) + jnp_error._set_error_if_nan(out) + return out def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: @@ -1541,7 +1565,9 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: >>> x - 10 Array([-10, -9, -8, -7], dtype=int32) """ - return lax.sub(*promote_args("subtract", x, y)) + out = lax.sub(*promote_args("subtract", x, y)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1765,7 +1791,9 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: >>> jnp.float_power(-3, 1.7) Array(nan, dtype=float32, weak_type=True) """ - return lax.pow(*promote_args_inexact("float_power", x, y)) + out = lax.pow(*promote_args_inexact("float_power", x, y)) + jnp_error._set_error_if_nan(out) + return out @export @@ -2443,7 +2471,10 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.numpy.floor_divide` for integer division """ x1, x2 = promote_args_inexact("true_divide", x1, x2) - return lax.div(x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) + out = lax.div(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export @@ -2493,6 +2524,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([3., 2., 2.], dtype=float32) """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) dtype = dtypes.dtype(x1) if dtypes.issubdtype(dtype, np.unsignedinteger): return lax.div(x1, x2) @@ -2547,6 +2579,7 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: if dtypes.issubdtype(dtypes.dtype(x1), np.integer): return floor_divide(x1, x2), remainder(x1, x2) else: + jnp_error._set_error_if_divide_by_zero(x2) return _float_divmod(x1, x2) @@ -2582,8 +2615,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.lax.integer_pow`. - When ``x2`` is a traced scalar or an array, ``jnp.power`` lowers to :func:`jax.lax.pow`. - - ``jnp.power`` raises a ``TypeError`` for integer type raised to negative - integer power. + - ``jnp.power`` raises a ``TypeError`` for integer type raised to a concrete + negative integer power. For a non-concrete power, the operation is invalid + and the returned value is implementation-defined. - ``jnp.power`` returns ``nan`` for negative value raised to the power of non-integer values. 
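# Editor's sketch (not part of the diff): how the NaN / divide-by-zero checks wired into
# the ufuncs above are meant to be consumed, assuming the internal modules introduced by
# this change keep the names shown here (jax._src.numpy.error.error_checking_behavior and
# jax._src.error_check.raise_if_error). Errors are recorded when a check fires and only
# surface once raise_if_error() is called.
import jax.numpy as jnp
from jax._src import error_check as error_check_lib
from jax._src.numpy.error import error_checking_behavior

with error_checking_behavior(divide="raise"):
  out = jnp.true_divide(jnp.arange(1.0, 4.0), 0.0)  # zero denominator records an error
  error_check_lib.raise_if_error()                  # the recorded error is raised here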
@@ -2619,6 +2653,11 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: [nan, 27., 1.]], dtype=float32) """ check_arraylike("power", x1, x2) + + # Must do __jax_array__ conversion prior to dtype check. + x1 = x1.__jax_array__() if hasattr(x1, "__jax_array__") else x1 + x2 = x2.__jax_array__() if hasattr(x2, "__jax_array__") else x2 + check_no_float0s("power", x1, x2) # We apply special cases, both for algorithmic and autodiff reasons: @@ -2645,7 +2684,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.integer_pow(x1, x2) # Handle cases #2 and #3 under a jit: - return _power(x1, x2) + out = _power(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: @@ -2771,7 +2812,9 @@ def log2(x: ArrayLike, /) -> Array: im = lax.imag(r) ln2 = lax.log(_constant_like(re, 2)) return lax.complex(lax.div(re, ln2), lax.div(im, ln2)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + out = lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + jnp_error._set_error_if_nan(out) + return out @export @@ -2801,7 +2844,9 @@ def log10(x: ArrayLike, /) -> Array: im = lax.imag(r) ln10 = lax.log(_constant_like(re, 10)) return lax.complex(lax.div(re, ln10), lax.div(im, ln10)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + out = lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + jnp_error._set_error_if_nan(out) + return out @export @@ -3054,6 +3099,7 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: [ 0., 2., -2.]], dtype=float32) """ x1, x2 = promote_args_numeric("remainder", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) zero = _constant_like(x1, 0) if dtypes.issubdtype(x2.dtype, np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) @@ -3061,7 +3107,9 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: trunc_mod_not_zero = lax.ne(trunc_mod, zero) do_plus = lax.bitwise_and( lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) - return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + out = lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + jnp_error._set_error_if_nan(out) + return out @export @@ -3109,7 +3157,9 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike("fmod", x1, x2) if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) - return lax.rem(*promote_args_numeric("fmod", x1, x2)) + out = lax.rem(*promote_args_numeric("fmod", x1, x2)) + jnp_error._set_error_if_nan(out) + return out @export @@ -3451,7 +3501,7 @@ def isinf(x: ArrayLike, /) -> Array: >>> jnp.isinf(x) Array([False, True, False, True, False], dtype=bool) """ - check_arraylike("isinf", x) + x = ensure_arraylike("isinf", x) dtype = dtypes.dtype(x) if dtypes.issubdtype(dtype, np.floating): return lax.eq(lax.abs(x), _constant_like(x, np.inf)) @@ -3464,7 +3514,7 @@ def isinf(x: ArrayLike, /) -> Array: return lax.full_like(x, False, dtype=np.bool_) -def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: +def _isposneginf(infinity: float, x: Array, out) -> Array: if out is not None: raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") dtype = dtypes.dtype(x) @@ -3507,6 +3557,7 @@ def isposinf(x, /, out=None): >>> jnp.isposinf(x) Array([False, False, True, False, False], dtype=bool) """ + x = ensure_arraylike("isposinf", x) return _isposneginf(np.inf, x, out) @@ -3541,6 +3592,7 @@ def isneginf(x, /, out=None): >>> jnp.isneginf(x) Array([ True, False, False, False, False], dtype=bool) """ + x = 
ensure_arraylike("isneginf", x) return _isposneginf(-np.inf, x, out) @@ -3575,7 +3627,7 @@ def isnan(x: ArrayLike, /) -> Array: >>> jnp.isnan(x) Array([False, False, False, True], dtype=bool) """ - check_arraylike("isnan", x) + x = ensure_arraylike("isnan", x) return lax.ne(x, x) @@ -3591,9 +3643,9 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: .. math:: \mathrm{heaviside}(x1, x2) = \begin{cases} - 0., & x < 0\\ - x2, & x = 0\\ - 1., & x > 0. + 0, & x1 < 0\\ + x2, & x1 = 0\\ + 1, & x1 > 0. \end{cases} Args: diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e281c63ae654..bcfb12673806 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -27,7 +27,8 @@ from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding from jax._src.util import safe_zip, safe_map, set_module -from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape +from jax._src.typing import (Array, ArrayLike, DimSize, DType, DTypeLike, + Shape, SupportsNdim, SupportsShape, SupportsSize) from jax.sharding import Sharding import numpy as np @@ -69,13 +70,13 @@ def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]): msg = ("Following NumPy automatic rank promotion for {} on shapes {}. " "Set the jax_numpy_rank_promotion config option to 'allow' to " "disable this warning; for more information, see " - "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + "https://docs.jax.dev/en/latest/rank_promotion_warning.html.") warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes)))) elif config.numpy_rank_promotion.value == "raise": msg = ("Operands could not be broadcast together for {} on shapes {} " "and with the config option jax_numpy_rank_promotion='raise'. " "For more information, see " - "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + "https://docs.jax.dev/en/latest/rank_promotion_warning.html.") raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes)))) @@ -158,7 +159,7 @@ def ensure_arraylike(fun_name: str, /, *args: Any) -> Array | tuple[Array, ...]: return tuple(_arraylike_asarray(arg) for arg in args) # pytype: disable=bad-return-type -def ensure_arraylike_tuple(fun_name: str, tup: tuple[Any, ...]) -> tuple[Array, ...]: +def ensure_arraylike_tuple(fun_name: str, tup: Sequence[Any]) -> tuple[Array, ...]: """Check that argument elements are arraylike and convert to a tuple of arrays. This is useful because ensure_arraylike with a single argument returns a single array. @@ -258,7 +259,7 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]: def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None ) -> Array: - check_arraylike("broadcast_to", arr) + arr = ensure_arraylike("broadcast_to", arr) arr = arr if isinstance(arr, Array) else lax.asarray(arr) if not isinstance(shape, tuple) and np.ndim(shape) == 0: shape = (shape,) @@ -313,7 +314,7 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin @export -def ndim(a: ArrayLike) -> int: +def ndim(a: ArrayLike | SupportsNdim) -> int: """Return the number of dimensions of an array. JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function @@ -321,7 +322,7 @@ def ndim(a: ArrayLike) -> int: tuple. Args: - a: array-like object. + a: array-like object, or any object with an ``ndim`` attribute. Returns: An integer specifying the number of dimensions of ``a``. 
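# Usage sketch for the duck-typing added to jnp.ndim / jnp.shape / jnp.size in
# the hunks around here: any object exposing the corresponding attribute is
# handled directly, without the array-conversion (and deprecation-warning)
# path. jax.ShapeDtypeStruct is used here only as an example of such an object.
import jax
import jax.numpy as jnp

x = jax.ShapeDtypeStruct((3, 4), jnp.float32)
assert jnp.ndim(x) == 2
assert jnp.shape(x) == (3, 4)
assert jnp.size(x) == 12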
@@ -346,13 +347,18 @@ def ndim(a: ArrayLike) -> int: >>> x.ndim 1 """ + if hasattr(a, "ndim"): + return a.ndim # Deprecation warning added 2025-2-20. check_arraylike("ndim", a, emit_warning=True) - return np.ndim(a) # NumPy dispatches to a.ndim if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.ndim if available. + return np.ndim(a) # type: ignore[arg-type] @export -def shape(a: ArrayLike) -> tuple[int, ...]: +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: """Return the shape an array. JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function @@ -360,7 +366,7 @@ def shape(a: ArrayLike) -> tuple[int, ...]: tuple. Args: - a: array-like object. + a: array-like object, or any object with a ``shape`` attribute. Returns: An tuple of integers representing the shape of ``a``. @@ -385,13 +391,18 @@ def shape(a: ArrayLike) -> tuple[int, ...]: >>> x.shape (10,) """ + if hasattr(a, "shape"): + return a.shape # Deprecation warning added 2025-2-20. check_arraylike("shape", a, emit_warning=True) - return np.shape(a) # NumPy dispatches to a.shape if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.shape if available. + return np.shape(a) # type: ignore[arg-type] @export -def size(a: ArrayLike, axis: int | None = None) -> int: +def size(a: ArrayLike | SupportsSize | SupportsShape, axis: int | None = None) -> int: """Return number of elements along a given axis. JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function @@ -399,7 +410,8 @@ def size(a: ArrayLike, axis: int | None = None) -> int: tuple. Args: - a: array-like object + a: array-like object, or any object with a ``size`` attribute when ``axis`` is not + specified, or with a ``shape`` attribute when ``axis`` is specified. axis: optional integer along which to count elements. By default, return the total number of elements. @@ -428,6 +440,12 @@ def size(a: ArrayLike, axis: int | None = None) -> int: >>> y.size 6 """ + if (axis is None and hasattr(a, "size")) or (axis is not None and hasattr(a, "shape")): + # NumPy dispatches to a.size/a.shape if available. + return np.size(a, axis=axis) # type: ignore[arg-type] # Deprecation warning added 2025-2-20. check_arraylike("size", a, emit_warning=True) - return np.size(a, axis=axis) # NumPy dispatches to a.size if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.size/a.shape if available. + return np.size(a, axis=axis) # type: ignore[arg-type] diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index e6ad1386a52e..a60e681427f5 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -307,7 +307,7 @@ def wrapped(*args, **kwargs): f" promotion for jnp.vectorize function with signature {signature}." " Set the jax_numpy_rank_promotion config option to 'allow' to" " disable this message; for more information, see" - " https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + " https://docs.jax.dev/en/latest/rank_promotion_warning.html.") if config.numpy_rank_promotion.value == "warn": warnings.warn(msg) elif config.numpy_rank_promotion.value == "raise": diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index e19be6622168..eccbbdde006e 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -272,8 +272,7 @@ def segment_prod(data: ArrayLike, data: an array with the values to be reduced. 
segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -283,11 +282,11 @@ def segment_prod(data: ArrayLike, indices_are_sorted: whether ``segment_ids`` is known to be sorted. unique_indices: whether `segment_ids` is known to be free of duplicates. bucket_size: size of bucket to group indices into. ``segment_prod`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. + performed on each bucket separately to improve numerical stability. + Default ``None`` means no bucketing. mode: a :class:`jax.lax.GatherScatterMode` value describing how out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. + range [0, num_segments) are dropped and do not contribute to the result. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -328,8 +327,7 @@ def segment_max(data: ArrayLike, data: an array with the values to be reduced. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -342,7 +340,7 @@ def segment_max(data: ArrayLike, performed on each bucket separately. Default ``None`` means no bucketing. mode: a :class:`jax.lax.GatherScatterMode` value describing how out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. + range [0, num_segments) are dropped and do not contribute to the result. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -383,8 +381,7 @@ def segment_min(data: ArrayLike, data: an array with the values to be reduced. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -397,7 +394,7 @@ def segment_min(data: ArrayLike, performed on each bucket separately. Default ``None`` means no bucketing. mode: a :class:`jax.lax.GatherScatterMode` value describing how out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. + range [0, num_segments) are dropped and do not contribute to the result. 
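# Usage sketch for the segment reductions whose docstrings are adjusted above
# (public API under jax.ops; values shown assume all segment ids are in range).
import jax.numpy as jnp
from jax import ops

data = jnp.array([1, 2, 3, 4, 5])
segment_ids = jnp.array([0, 0, 1, 1, 2])
ops.segment_max(data, segment_ids, num_segments=3)   # Array([2, 4, 5], dtype=int32)
ops.segment_prod(data, segment_ids, num_segments=3)  # Array([ 2, 12,  5], dtype=int32)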
Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index fe4c46394832..6ed8804bc8e2 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -47,16 +47,15 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, JAX implementation of :func:`scipy.special.logsumexp`. .. math:: - \mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij}) + \operatorname{logsumexp} a = \log \sum_i b_i \exp a_i - where the :math:`j` indices range over one or more dimensions to be reduced. + where the :math:`i` indices range over one or more dimensions to be reduced. Args: a: the input array axis: int or sequence of ints, default=None. Axis along which the sum to be computed. If None, the sum is computed along all the axes. - b: scaling factors for :math:`\mathrm{exp}(a)`. Must be broadcastable to the - shape of `a`. + b: scaling factors for the exponentials. Must be broadcastable to the shape of `a`. keepdims: If ``True``, the axes that are reduced are left in the output as dimensions of size 1. return_sign: If ``True``, the output will be a ``(result, sign)`` pair, diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 206c2a73fbed..a74206c46ce7 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -67,6 +67,55 @@ def __repr__(self): SEMAPHORE_INTERPRET_DTYPE = jnp.int16 SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max +class AbstractSemaphoreTyRules: + @staticmethod + def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), SEMAPHORE_INTERPRET_DTYPE) + + @staticmethod + def physical_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), jnp.int32) + +# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy +class AbstractSemaphoreTy(dtypes.ExtendedDType): + name: str + _rules = AbstractSemaphoreTyRules + + def __repr__(self) -> str: + return self.name + + def __eq__(self, other): + return self.__class__ == other.__class__ + + def __hash__(self) -> int: + return hash(self.__class__) + +class semaphore_dtype(dtypes.extended): + """Common dtype for all kinds of semaphore dtypes. + + This is an abstract class that should never be instantiated, but rather + exists for the sake of `jnp.issubdtype`. + """ + +class semaphore(semaphore_dtype): + """Regular semaphore dtype. + + Like its superclass, this class should never be instantiated. + """ + +class Semaphore(AbstractSemaphoreTy): + name = "semaphore" + type = semaphore + +class barrier_semaphore(semaphore_dtype): + """Barrier semaphore dtype. + + Like its superclass, this class should never be instantiated. 
+ """ + +class BarrierSemaphore(AbstractSemaphoreTy): + name = "barrier_semaphore" + type = barrier_semaphore @runtime_checkable class CompilerParams(Protocol): @@ -1040,7 +1089,10 @@ def wrapped(f): debug_info=api_util.debug_info("pallas_core_map", f, (), {})), in_tree) - with jax_core.extend_axis_env_nd(mesh.shape.items()): + with ( + tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh, compiler_params=compiler_params, @@ -1095,6 +1147,7 @@ def default_mesh_discharge_rule( interpret, cost_estimate, name, + memory_space=MemorySpace.ANY, ): """Discharges a ``core_map`` over a mesh to a ``pallas_call``.""" del out_avals # Unused. @@ -1111,13 +1164,9 @@ def body(*args): for eff in jaxpr.effects if isinstance(eff, state_types.WriteEffect) ) - any_spec = BlockSpec(memory_space=MemorySpace.ANY) - grid_spec = GridSpec( - grid=tuple(mesh.shape.items()), - in_specs=[any_spec] * len(in_avals), - out_specs=[any_spec] * len(modified_idxs), - ) + spec = BlockSpec(memory_space=memory_space) from jax._src.pallas import pallas_call # Avoid circular dependency. + outs = pallas_call._pallas_call( body, name=name, @@ -1125,7 +1174,11 @@ def body(*args): input_output_aliases={ in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs) }, - grid_spec=grid_spec, + grid_spec=GridSpec( + grid=tuple(mesh.shape.items()), + in_specs=[spec] * len(in_avals), + out_specs=[spec] * len(modified_idxs), + ), mesh=mesh, compiler_params=compiler_params, interpret=interpret, diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index 66bbac33aabb..a62a9937d91d 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -33,7 +33,7 @@ pytype_strict_library( deps = [ ":block_spec", ":custom_evaluate", - ":fusable", + ":fusible", ":fusion", ":jaxpr_fusion", ], @@ -58,9 +58,9 @@ pytype_strict_library( ) pytype_strict_library( - name = "fusable", + name = "fusible", srcs = [ - "fusable.py", + "fusible.py", ], deps = [ ":fusion", @@ -91,25 +91,26 @@ pytype_strict_library( "jaxpr_fusion.py", ], deps = [ - ":fusable", - ":fusable_dtype", + ":fusible", + ":fusible_dtype", ":fusion", "//jax", "//jax:api_util", "//jax:core", "//jax:partial_eval", "//jax:tree_util", + "//jax:util", ], ) pytype_strict_library( - name = "fusable_dtype", + name = "fusible_dtype", srcs = [ - "fusable_dtype.py", + "fusible_dtype.py", ], deps = [ ":block_spec", - ":fusable", + ":fusible", "//jax", "//jax:api_util", "//jax:core", diff --git a/jax/_src/pallas/fuser/__init__.py b/jax/_src/pallas/fuser/__init__.py index 3295c8f1061a..39720100eb1d 100644 --- a/jax/_src/pallas/fuser/__init__.py +++ b/jax/_src/pallas/fuser/__init__.py @@ -17,6 +17,6 @@ from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index de0cdd204f3c..9524ce4ca4d2 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -170,8 +170,11 
@@ def get_out_block_indices(self): _illegal = object() -_sp_env = threading.local() -_sp_env.scalar_prefetch = None +class _SpEnv(threading.local): + def __init__(self): + self.scalar_prefetch = None + +_sp_env = _SpEnv() @contextlib.contextmanager @@ -236,9 +239,7 @@ def wrapped(*args, **kwargs): jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr( f, *args, **kwargs ) - # TODO(sharadmv): handle these consts better, they should correspond to - # scalar prefetch. - del consts, out_tree_ + del out_tree_ jaxpr_out_usages = [{Usage.REGULAR}] * len(jaxpr.outvars) block_specs_ = jax.tree.map( _unwrap_block_spec_scalar_prefetch, out_block_specs @@ -260,6 +261,7 @@ def wrapped(*args, **kwargs): ) kernel_fn = make_kernel_function( jaxpr, + consts, in_tree, out_tree, read_usage_env, @@ -405,6 +407,7 @@ def _get_in_block_spec(v, usage): def make_kernel_function( jaxpr: core.Jaxpr, + consts, in_tree, out_tree, read_usage_env, @@ -502,6 +505,8 @@ def read_env(atom): def write_env(var, val): env[var] = val + for const, constvar in zip(consts, jaxpr.constvars): + env[constvar] = const for invar, arg, usage in zip(jaxpr.invars, flat_args, invar_usages): if Usage.REGULAR in usage: env[invar] = arg @@ -1229,6 +1234,7 @@ def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs): ) kernel_fn = make_kernel_function( jaxpr, + (), in_tree, out_tree, read_usage_env, @@ -1286,6 +1292,7 @@ def _custom_jvp_call_eval_rule( ) kernel_fn = make_kernel_function( jaxpr, + (), in_tree, out_tree, read_usage_env, diff --git a/jax/_src/pallas/fuser/fusable.py b/jax/_src/pallas/fuser/fusible.py similarity index 51% rename from jax/_src/pallas/fuser/fusable.py rename to jax/_src/pallas/fuser/fusible.py index b075c6d136c9..289a9dc268b4 100644 --- a/jax/_src/pallas/fuser/fusable.py +++ b/jax/_src/pallas/fuser/fusible.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Fusable primitive.""" +"""Fusible primitive.""" +from typing import Any import jax from jax._src import api_util @@ -24,12 +25,8 @@ from jax._src.interpreters import partial_eval as pe from jax._src.pallas.fuser import fusion as fusion_lib -fusable_p = jax_core.Primitive('fusable') -fusable_p.multiple_results = True - - -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) +fusible_p = jax_core.Primitive('fusible') +fusible_p.multiple_results = True def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: @@ -40,44 +37,50 @@ def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: ) -def fusable(f): - def wrapper(*args): - def wrapped(*args): - in_fusions = tree_util.tree_map(_make_trivial_fusion, args) - return f(*in_fusions, None) - - flat_args, in_tree = tree_util.tree_flatten(args) - debug_info = api_util.debug_info('fusable', wrapped, args, {}) - flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(wrapped, debug_info=debug_info), in_tree - ) - flat_avals = [_get_aval(x) for x in flat_args] - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree_thunk() - out = fusable_p.bind( - *consts, - *flat_args, - jaxpr=jaxpr, - num_consts=len(consts), - in_tree=in_tree, - out_tree=out_tree, - func=f, - ) - return tree_util.tree_unflatten(out_tree, out) - - return wrapper - - -@fusable_p.def_impl +def fusible(f=None, *, output_fusion_prefix: Any = True): + def decorator(f): + def wrapper(*args): + def wrapped(*args): + in_fusions = tree_util.tree_map(_make_trivial_fusion, args) + return f(*in_fusions, None) + + flat_args, in_tree = tree_util.tree_flatten(args) + debug_info = api_util.debug_info('fusible', wrapped, args, {}) + flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(wrapped, debug_info=debug_info), in_tree + ) + flat_avals = [jax_core.get_aval(x) for x in flat_args] + jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + out_tree = out_tree_thunk() + out = fusible_p.bind( + *consts, + *flat_args, + jaxpr=jaxpr, + num_consts=len(consts), + in_tree=in_tree, + out_tree=out_tree, + func=f, + output_fusion_prefix=output_fusion_prefix, + ) + return tree_util.tree_unflatten(out_tree, out) + + return wrapper + + if f is not None: + return decorator(f) + return decorator + + +@fusible_p.def_impl def _(*consts_and_args, jaxpr, num_consts, **_): consts, args = util.split_list(consts_and_args, [num_consts]) return jax_core.eval_jaxpr(jaxpr, consts, *args) -mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl)) +mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl)) -@fusable_p.def_abstract_eval +@fusible_p.def_abstract_eval def _(*args, jaxpr, **kwargs): del args, kwargs return [v.aval for v in jaxpr.outvars] diff --git a/jax/_src/pallas/fuser/fusable_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py similarity index 95% rename from jax/_src/pallas/fuser/fusable_dtype.py rename to jax/_src/pallas/fuser/fusible_dtype.py index e5bc9ab683ab..628d253e090a 100644 --- a/jax/_src/pallas/fuser/fusable_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Custom fusable dtypes.""" +"""Custom fusible dtypes.""" import abc import dataclasses @@ -34,7 +34,7 @@ from jax._src.pallas import pallas_call from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.fuser import block_spec -from jax._src.pallas.fuser.fusable import fusable_p +from jax._src.pallas.fuser.fusible import fusible_p from jax._src.state import discharge as state_discharge from jax._src.state import primitives as state_primitives from jax._src.util import foreach @@ -54,7 +54,7 @@ @pack_dtype_p.def_abstract_eval def pack_dtype_abstract_eval(*xs, dtype): - if dtypes.issubdtype(dtype, FusableElementDType): + if dtypes.issubdtype(dtype, FusibleElementDType): return dtype.abstract_pack(*xs) raise ValueError("Attempted to pack non-fusion dtype: {dtype}") @@ -69,7 +69,7 @@ def pack(*xs, dtype): @unpack_dtype_p.def_abstract_eval def unpack_dtype_abstract_eval(x): - if dtypes.issubdtype(x.dtype, FusableElementDType): + if dtypes.issubdtype(x.dtype, FusibleElementDType): return x.dtype.abstract_unpack(x) elif isinstance(x.dtype, pallas_core.AbstractMemoryRef): raise NotImplementedError() @@ -80,22 +80,20 @@ def unpack(x): return unpack_dtype_p.bind(x) -class FusableElementDType(dtypes.extended): - """Scalar dtype for fusable dtypes.""" +class FusibleElementDType(dtypes.extended): + """Scalar dtype for fusible dtypes.""" - pass - -class FusableTyRules: +class FusibleTyRules: allow_conversion: bool = False class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta): - """Base class for fusable extended dtypes.""" + """Base class for fusible extended dtypes.""" _op_registry = {} - _rules = FusableTyRules - type = FusableElementDType + _rules = FusibleTyRules + type = FusibleElementDType @abc.abstractmethod def abstract_unpack(self, x) -> Sequence[Any]: @@ -126,7 +124,7 @@ def pull_block_spec_one_step(self, *args, **kwargs): def physicalize(f): - """Runs a function that contains fusable extended dtypes.""" + """Runs a function that contains fusible extended dtypes.""" def wrapper(*args, **kwargs): if kwargs: @@ -205,7 +203,7 @@ class Context: def physicalize_interp( jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value ): - """Physicalizes a jaxpr by replacing fusable dtypes with physical types.""" + """Physicalizes a jaxpr by replacing fusible dtypes with physical types.""" # TODO: Merge into JAX core. 
env: dict[core.Var, Any] = {} @@ -448,12 +446,12 @@ def _pack_dtype_pull_rule( return dtype.pull_block_spec_one_step(block_spec) # pytype: disable=attribute-error -def _fusable_physicalize_rule( +def _fusible_physicalize_rule( _, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func ): consts, _ = util.split_list(consts_and_args, [num_consts]) new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts)) - return fusable_p.bind( + return fusible_p.bind( *consts_and_args, jaxpr=new_jaxpr.jaxpr, num_consts=num_consts, @@ -463,4 +461,4 @@ def _fusable_physicalize_rule( ) -_physicalize_rules[fusable_p] = _fusable_physicalize_rule +_physicalize_rules[fusible_p] = _fusible_physicalize_rule diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 3d36b8f3e2fd..95768d71f792 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -14,35 +14,31 @@ """Fuses a function.""" +from collections.abc import Sequence +import functools from typing import Any - import jax from jax._src import api_util from jax._src import core as jax_core from jax._src import linear_util as lu from jax._src import tree_util from jax._src.interpreters import partial_eval as pe - -from jax._src.pallas.fuser import fusable_dtype +from jax._src.pallas.fuser import fusible_dtype from jax._src.pallas.fuser import fusion as fusion_lib -from jax._src.pallas.fuser.fusable import fusable_p - - -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) +from jax._src.pallas.fuser.fusible import fusible_p def fuse(f=None, *, physicalize: bool = False, debug: bool = False): - """Fuses a function into a single fusable. + """Fuses a function into a single fusible. Args: f: The function to fuse. physicalize: (experimental) whether to physicalize the function. debug: Whether to print debug information. - There should be a single call to a `fusable` inside the body of `f`. `fuse` + There should be a single call to a `fusible` inside the body of `f`. `fuse` returns a transformed function that will fuse the surrounding computation into - the fusable and invoke it. + the fusible and invoke it. 
""" def decorator(f): @@ -52,7 +48,7 @@ def wrapper(*args, **kwargs): flat_fun, out_tree_thunk = api_util.flatten_fun( lu.wrap_init(f, debug_info=debug_info), in_tree ) - flat_avals = [_get_aval(x) for x in flat_args] + flat_avals = [jax_core.get_aval(x) for x in flat_args] jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) if debug: print("Jaxpr before fusion:") @@ -62,7 +58,7 @@ def wrapper(*args, **kwargs): return tree_util.tree_unflatten(out_tree, out_flat) if physicalize: - wrapper = fusable_dtype.physicalize(wrapper) + wrapper = fusible_dtype.physicalize(wrapper) return wrapper if f is not None: @@ -70,12 +66,12 @@ def wrapper(*args, **kwargs): return decorator -_fusable: dict[jax_core.Primitive, Any] = {} +_fusible: dict[jax_core.Primitive, Any] = {} -def construct_fusion( +def _construct_fusion_jaxpr( candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs -) -> fusion_lib.Fusion: +): flat_outvars, out_tree = tree_util.tree_flatten(outvars) flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs)) new_jaxpr_no_dce = jaxpr.replace( @@ -94,12 +90,6 @@ def construct_fusion( c for used, c in zip(used_consts, candidate_values, strict=True) if used ) kernel_in_tree = tree_util.tree_structure((invars, kwargs)) - - def _fn(*args, **kwargs): - flat_args, _ = tree_util.tree_flatten((args, kwargs)) - out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args) - return tree_util.tree_unflatten(out_tree, out_flat) - flat_in_type = [ jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_invars ] @@ -108,9 +98,158 @@ def _fn(*args, **kwargs): out_tree, [jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_outvars], ) + return new_jaxpr, new_values, in_type, out_type, out_tree + + +def construct_fusion( + candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs +) -> fusion_lib.Fusion: + new_jaxpr, new_values, in_type, out_type, out_tree = _construct_fusion_jaxpr( + candidate_values, jaxpr, outvars, *invars, **kwargs + ) + + def _fn(*args, **kwargs): + flat_args, _ = tree_util.tree_flatten((args, kwargs)) + out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args) + return tree_util.tree_unflatten(out_tree, out_flat) + return fusion_lib.Fusion(_fn, in_type, out_type) +def _find_downstream( + jaxpr: jax_core.Jaxpr, in_used: Sequence[bool] +) -> tuple[bool, ...]: + # TODO(sharadmv): We use partial_eval to query downstream dependencies which + # is not an officially sanctioned way to do so, since PE is really used for + # AD. In the future, we should have a special Jaxpr API that queries this. + _, _, out_used, *_ = pe.partial_eval_jaxpr_custom( + jaxpr, + in_unknowns=in_used, + in_inst=in_used, + ensure_out_unknowns=False, + ensure_out_inst=False, + saveable=lambda *_, **__: False, + ) + return tuple(out_used) + + +def _construct_output_permutation( + used: list[tuple[bool, ...]], +) -> list[int]: + order = [] + for u in used: + true_vals = [i for i in range(len(u)) if u[i]] + order.extend(true_vals) + return [order.index(i) for i in range(len(order))] + + +def _construct_output_fusions( + candidate_values, + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn_outvars, # Flat list of vars output by the fusible eqn + fusion_eqn_out_tree, # Tree structure of the fusible eqn outputs + output_fusion_prefix, # Pytree defining output groups +): + # 1. 
Create jaxpr_out: represents computation *after* the fusible + # Inputs: fusion_eqn_outvars + # Outputs: jaxpr.outvars + jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr( + candidate_values, + jaxpr.replace( + eqns=jaxpr.eqns[:fusion_eqn_index] + + jaxpr.eqns[fusion_eqn_index + 1 :] + ), + tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs + tree_util.tree_unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ), # Fusible outputs as inputs + ) + + # 2. Group fusible outputs based on the mask + unflat_fusible_outvars = jax.tree.unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ) + partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to( + unflat_fusible_outvars + ) + + # 3. Calculate dependencies and check disjointness + downstream_outputs_used_masks = [] # List of bool tuples, one per group + already_used_final_outputs = set() # Indices of final outputs already claimed + for outvars_group in partial_flat: + # Identify vars in this group + used_fusible_outvars = set(jax.tree.leaves(outvars_group)) + # Create mask for jaxpr_out inputs corresponding to this group + in_used_mask = [ + True if v in used_fusible_outvars else False for v in jaxpr_out.invars + ] + # Trace dependencies through jaxpr_out to find which final outputs are affected + downstream_used_mask = _find_downstream( + jaxpr_out, in_used_mask + ) # Mask for jaxpr_out.outvars (== jaxpr.outvars) + + # Check for overlap in final output usage across groups + for i, used in enumerate(downstream_used_mask): + if used: + if i in already_used_final_outputs: + raise ValueError( + "Outputs must be disjoint in order to use separate output fusions" + ) + already_used_final_outputs.add(i) + downstream_outputs_used_masks.append(downstream_used_mask) + + # 4. 
Construct output permutation needed to restore original output order + output_permutation = _construct_output_permutation( + downstream_outputs_used_masks + ) + + # Construct fusions for each group by DCEing the jaxpr_out + output_fusions = [] + for i, outvars_group in enumerate(partial_flat): + flat_group_vars, _ = tree_util.tree_flatten(outvars_group) + downstream_used_mask = downstream_outputs_used_masks[i] + + used_jaxpr_invars = [False] * len(all_values) + [ + v in flat_group_vars for v in jaxpr_out.invars + ] + jaxpr_out_for_group, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr_out, downstream_used_mask, instantiate=used_jaxpr_invars + ) + values_for_jaxpr = tuple( + c for used, c in zip(used_consts, all_values, strict=True) if used + ) + + def _fn(jaxpr, vals, *args, **kwargs): + flat_args, _ = tree_util.tree_flatten((args, kwargs)) + out_flat = jax_core.eval_jaxpr(jaxpr, vals, *flat_args) + return tuple(out_flat) + + fn = functools.partial(_fn, jaxpr_out_for_group, values_for_jaxpr) + in_type = jax.tree.map( + lambda v: jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype), # pytype: disable=attribute-error + outvars_group, + ) + out_type = tuple( + jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype) # pytype: disable=attribute-error + for v in jaxpr_out_for_group.outvars + ) + fusion = fusion_lib.Fusion( + fn, + (in_type, {}), + out_type, + ) + output_fusions.append(fusion) + + return ( + tree_util.tree_unflatten( + tree_util.tree_structure(output_fusion_prefix), output_fusions + ), + output_permutation, + ) + + def fuse_jaxpr( jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args ): @@ -118,16 +257,25 @@ def fuse_jaxpr( # Collect input fusions for i, eqn in enumerate(jaxpr.eqns): - if eqn.primitive is fusable_p: + if eqn.primitive is fusible_p: fusion_eqn_index = i break if fusion_eqn_index is None: - raise ValueError("No fusable eqn found") + raise ValueError("No fusible eqn found") fusion_eqn = jaxpr.eqns[fusion_eqn_index] + # Now let's check if we need to do any fusion at all, e.g. do the outputs of + # the jaxpr have any dependence on the fusion at all? We can DCE the jaxpr + # with all the inputs and outputs to check if there is a dependence. + dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), + instantiate=True) + if not any(eqn.primitive is fusible_p for eqn in dced_jaxpr.eqns): + # Short circuit if there is nothing to fuse. + return jax_core.eval_jaxpr(dced_jaxpr, consts, *args) + candidate_values = [*consts, *args] - # Construct fusions for non-constant inputs to the fusable. + # Construct fusions for non-constant inputs to the fusible. in_fusions_flat = [ construct_fusion( candidate_values, @@ -141,21 +289,20 @@ def fuse_jaxpr( in_fusions = tree_util.tree_unflatten( fusion_eqn.params["in_tree"], in_fusions_flat ) - out_fusion = construct_fusion( + output_fusions, output_permutation = _construct_output_fusions( candidate_values, - jaxpr.replace( - eqns=jaxpr.eqns[:fusion_eqn_index] - + jaxpr.eqns[fusion_eqn_index + 1 :] - ), - tree_util.tree_unflatten(out_tree, jaxpr.outvars), - tree_util.tree_unflatten( - fusion_eqn.params["out_tree"], fusion_eqn.outvars - ), + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn.outvars, + fusion_eqn.params["out_tree"], + fusion_eqn.params["output_fusion_prefix"], ) - # Run the fusable. - out = fusion_eqn.params["func"](*in_fusions, out_fusion) - - # Now return the flattened output (the fuse_jaxpr caller should unflatten). 
- out_flat = tree_util.tree_leaves(out) - assert len(out_flat) == len(jaxpr.outvars) - return out_flat + out = fusion_eqn.params["func"](*in_fusions, output_fusions) + flat_out = jax.tree.leaves(out) + permuted_out = [flat_out[i] for i in output_permutation] + assert len(permuted_out) == len(jaxpr.outvars), ( + len(permuted_out), + len(jaxpr.outvars), + ) + return permuted_out diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 24e8341046b0..fdd3a56ac7c8 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -158,6 +158,7 @@ py_library( deps = [ ":core", ":primitives", + ":verification", "//jax", "//jax:core", "//jax:source_info_util", diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index f582248ee7c3..0fe825d44858 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -25,7 +25,6 @@ import jax from jax._src import config from jax._src import core as jax_core -from jax._src import dtypes from jax._src import util from jax._src.pallas import core as pallas_core import jax.numpy as jnp @@ -65,23 +64,23 @@ class TPUCompilerParams(pallas_core.CompilerParams): """Mosaic TPU compiler parameters. Attributes: - dimension_semantics: A list of dimension semantics for each grid - dimension of the kernel. Either "parallel" for dimensions that can - execute in any order, or "arbitrary" for dimensions that must be - executed sequentially. + dimension_semantics: A list of dimension semantics for each grid dimension + of the kernel. Either "parallel" for dimensions that can execute in any + order, or "arbitrary" for dimensions that must be executed sequentially. allow_input_fusion: A list of booleans indicating whether input fusion is allowed for each argument. - vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note - that this must be used in conjunction with the + vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note that + this must be used in conjunction with the --xla_tpu_scoped_vmem_limit_kib=N flag with N*1kib > vmem_limit_bytes. - collective_id: Indicates which barrier semaphore to use for the kernel. - Note that using the same collective_id does not guarantee that - the same barrier semaphore will be allocated between kernels. + collective_id: Indicates which barrier semaphore to use for the kernel. Note + that using the same collective_id does not guarantee that the same barrier + semaphore will be allocated between kernels. internal_scratch_in_bytes: The size of the internal scratch space used by Mosaic. flags: A dictionary of command line flags for the kernel. serialization_format: The serialization format for the kernel body. device_type: The device type to compile for. + disable_bounds_checks: Disable bounds checks in the kernel. """ PLATFORM: ClassVar[str] = "mosaic" dimension_semantics: ( @@ -95,7 +94,9 @@ class TPUCompilerParams(pallas_core.CompilerParams): internal_scratch_in_bytes: int | None = None serialization_format: int = 1 device_type: str | None = None + disable_bounds_checks: bool = False + # Replace is a method, not a field. replace = dataclasses.replace class TPUMemorySpace(enum.Enum): @@ -112,47 +113,12 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. 
return pallas_core.MemoryRef(shape, dtype, self) -class semaphore_dtype(dtypes.extended): pass -class semaphore(semaphore_dtype): pass -class dma_semaphore(semaphore_dtype): pass -class barrier_semaphore(semaphore_dtype): pass +class dma_semaphore(pallas_core.semaphore_dtype): pass -class AbstractSemaphoreTyRules: - @staticmethod - def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE) - - @staticmethod - def physical_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), jnp.int32) - -class AbstractSemaphoreTy(dtypes.ExtendedDType): - name: str - _rules = AbstractSemaphoreTyRules - - def __repr__(self) -> str: - return self.name - - def __eq__(self, other): - return self.__class__ == other.__class__ - - def __hash__(self) -> int: - return hash(self.__class__) - -# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy - -class SemaphoreTy(AbstractSemaphoreTy): - type = semaphore - name = "sem" - -class DmaSemaphoreTy(AbstractSemaphoreTy): +class DMASemaphore(pallas_core.AbstractSemaphoreTy): type = dma_semaphore name = "dma_sem" -class BarrierSemaphoreTy(AbstractSemaphoreTy): - type = barrier_semaphore - name = "barrier_sem" - class SemaphoreType(enum.Enum): REGULAR = "regular" DMA = "dma" @@ -161,11 +127,11 @@ class SemaphoreType(enum.Enum): def __call__(self, shape: tuple[int, ...]): dtype: Any if self == SemaphoreType.DMA: - dtype = DmaSemaphoreTy() + dtype = DMASemaphore() elif self == SemaphoreType.BARRIER: - dtype = BarrierSemaphoreTy() + dtype = pallas_core.BarrierSemaphore() else: - dtype = SemaphoreTy() + dtype = pallas_core.Semaphore() return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: diff --git a/jax/_src/pallas/mosaic/helpers.py b/jax/_src/pallas/mosaic/helpers.py index 76421cec3340..24cd7cad6086 100644 --- a/jax/_src/pallas/mosaic/helpers.py +++ b/jax/_src/pallas/mosaic/helpers.py @@ -88,8 +88,8 @@ def signal_core(i): # Don't signal ourself @pl_helpers.when(core_id != i) def _(): - plm_primitives.semaphore_signal(sem, 1, core_index=i) + pl_primitives.semaphore_signal(sem, 1, core_index=i) for i in range(num_cores): signal_core(i) - plm_primitives.semaphore_wait(sem, num_cores - 1) + pl_primitives.semaphore_wait(sem, num_cores - 1) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 1ad7be8154cd..5ac6bb6564ba 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -13,14 +13,14 @@ # limitations under the License. 
import collections -from collections.abc import Iterable, Sequence import dataclasses import enum import functools +import gc import itertools import math import threading -from typing import Any, Literal +from typing import Any, Callable,Literal import jax from jax import lax @@ -29,14 +29,16 @@ from jax._src.lax.control_flow import for_loop from jax._src import linear_util as lu from jax._src import source_info_util -from jax._src.pallas.mosaic import primitives as mosaic_primitives from jax._src.pallas.mosaic import core as mosaic_core +from jax._src.pallas.mosaic import primitives as mosaic_primitives +from jax._src.pallas.mosaic import verification from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src import pjit from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives +from jax._src.typing import Array from jax._src.util import ( safe_map, safe_zip, @@ -73,10 +75,10 @@ class TPUInterpretParams: is waiting on a DMA semaphore that will be signaled when the read or write is complete. Default: "on_wait". - detect_races: If True, a dynamic, happens-before race detector will be - used to detect data races during kernel interpretation. If any races are - detected, a message will be printed and `races.races_found` will be set - to True. + detect_races: If True, a dynamic, happens-before race detector will be used + to detect data races during kernel interpretation. If any races are + detected, a message will be printed and `races.races_found` will be set to + True. Default: False. skip_floating_point_ops: If True, operations that produce only floating point values will not be interpreted; instead, their results will be @@ -84,14 +86,25 @@ class TPUInterpretParams: operands to any operation will be replaced with (arrays of) `jnp.inf`. Default: False. uninitialized_memory: If "nan", allocated buffers are initialized to - to contain all NaNs (or to their maximum possible value for integers). - If "zero", allocated buffers are initialized to all zeros. + contain all NaNs (or to their maximum possible value for integers). If + "zero", allocated buffers are initialized to all zeros. Default: "nan". + random_seed: Seed for random number generator used during interpretation. + Currently random numbers are used to randomize the grid coordinates along + dimensions with 'parallel' semantics. + Default: None. + grid_point_recorder: Callback that is invoked by the interpreter for each + grid point in the order in which the grid points are traversed. This is + intended for inspecting the randomization of coordinates along grid + dimensions with 'parallel' semantics. + Default: None. """ dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" detect_races: bool = False skip_floating_point_ops: bool = False uninitialized_memory: Literal["nan", "zero"] = "nan" + random_seed: int | None = None + grid_point_recorder: Callable[[tuple[jnp.int32, ...]], None] | None = None VectorClock = np.ndarray @@ -452,6 +465,7 @@ class SharedMemory: num_devices: int clocks: list[VectorClock] barrier: threading.Barrier + clean_up_barrier: threading.Barrier # (memory_space, buffer_id, device_id) -> NumPy array # TODO(jburnim): Handle Megacore. @@ -477,6 +491,8 @@ class SharedMemory: next_dma_id: int = 100 + deallocated_bytes: int = 0 + # TODO(jburnim): Do we want to support multiple instances of SharedMemory? # Maybe for running multiple distinct interpreted computations in parallel? 
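# Worked example of the vector-clock join performed at the global barrier added
# below (_update_clocks_for_global_barrier): every device clock becomes the
# elementwise maximum over all device clocks, which is what the happens-before
# race detector needs after a synchronization point. Minimal NumPy sketch,
# assuming update_vector_clock(dst, src) takes an elementwise max as the
# comments describe.
import numpy as np

clocks = [np.array([3, 0, 1]), np.array([0, 2, 1]), np.array([1, 1, 4])]
joined = np.maximum.reduce(clocks)        # array([3, 2, 4])
clocks = [joined.copy() for _ in clocks]  # all devices now share the joined clock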
@@ -503,18 +519,35 @@ def _initialize_shared_memory(device_id, num_devices, *, interpret_params): interpret_params=interpret_params, num_devices=num_devices, clocks=[make_vector_clock(num_devices) for _ in range(num_devices)], - barrier=threading.Barrier(num_devices)) + barrier=threading.Barrier( + num_devices, action=_update_clocks_for_global_barrier), + clean_up_barrier=threading.Barrier( + num_devices, action=_clear_shared_memory)) assert _shared_memory.num_devices == num_devices global races races = RaceDetectionState(num_devices=num_devices) +def _update_clocks_for_global_barrier(): + shared_memory = _get_shared_memory() + with shared_memory.lock: + # Set the vector clock for device 0 to the max over all device clocks. + for c in shared_memory.clocks[1:]: + update_vector_clock(shared_memory.clocks[0], c) + # Set all other device vector clocks to the max over all the clocks. + for c in shared_memory.clocks[1:]: + update_vector_clock(c, shared_memory.clocks[0]) + +def _barrier(device_id): + device_id = int(device_id) + shared_memory = _get_shared_memory() + if shared_memory.num_devices > 1: + shared_memory.barrier.wait() + def _clean_up_shared_memory(device_id): device_id = int(device_id) shared_memory = _get_shared_memory() - shared_memory.barrier.wait() - if device_id == 0: - _clear_shared_memory() + shared_memory.clean_up_barrier.wait() def _validate(device_id): device_id = int(device_id) @@ -553,8 +586,18 @@ def _deallocate_buffer(device_id, memory_space, buffer_id): shared_memory = _get_shared_memory() with shared_memory.lock: - # TODO(jburnim): Error if buffer doesn't exist? - shared_memory.mem.pop((memory_space, buffer_id, device_id), None) + buff = shared_memory.mem.pop((memory_space, buffer_id, device_id)) + shared_memory.deallocated_bytes += buff.size * buff.itemsize + del buff + + should_collect = shared_memory.deallocated_bytes > 100_000_000 + if should_collect: + shared_memory.deallocated_bytes = 0 + + if should_collect: + # Periodic garbage collection here prevents OOMs -- although it's not clear + # why arrays are not getting freed without this. + gc.collect() def _allocate_semaphores(device_id, shape): device_id = int(device_id) @@ -948,9 +991,9 @@ def _device_coords_to_logical_id(device_coords, axis_sizes): def _device_id_to_logical(device_id, device_id_type, axis_sizes): if device_id is None: return None - if device_id_type == mosaic_primitives.DeviceIdType.MESH: + if device_id_type == primitives.DeviceIdType.MESH: return _device_coords_to_logical_id(device_id, axis_sizes) - elif device_id_type == mosaic_primitives.DeviceIdType.LOGICAL: + elif device_id_type == primitives.DeviceIdType.LOGICAL: return device_id else: raise ValueError(f'Unsupported device ID type: {device_id_type}') @@ -993,7 +1036,7 @@ def write(var, value): value = Placeholder(value.shape, value.dtype) env[var] = value - jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) + jax._src.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) # Get the device ID. axis_sizes = jax_core.get_axis_env().axis_sizes @@ -1019,7 +1062,7 @@ def write(var, value): # not need to do any reads if `interpret_params.skip_floating_point_ops` # is True. If this is the case, we want to avoid materializing the read # array into the jaxpr when this function is traced. 
- deferred_invals = functools.partial(jax.util.safe_map, read, eqn.invars) + deferred_invals = functools.partial(jax._src.util.safe_map, read, eqn.invars) if prim is primitives.load_p: (ref, transforms, mask, _) = jax.tree.unflatten( @@ -1050,6 +1093,21 @@ def write(var, value): ordered=True) elif prim is mosaic_primitives.delay_p: + # TODO(jburnim): Implement this properly? + out = [] + + elif prim is mosaic_primitives.prng_seed_p: + # TODO(jburnim): Implement this properly? + out = [] + + elif prim is mosaic_primitives.prng_random_bits_p: + # TODO(jburnim): Implement this properly? + out = jnp.zeros(eqn.params['shape'], jnp.int32) + + elif prim is verification.assume_p: + out = read(eqn.invars[0]) + + elif prim is verification.pretend_p: out = [] elif prim is lax.cond_p: @@ -1125,16 +1183,8 @@ def f(*args, jaxpr): out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) - for a in allocs: - if isinstance(a, tuple): - callback.io_callback( - _deallocate_buffer, - None, - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - a, - ordered=True) - else: + for a, v in zip(allocs, eqn.params['jaxpr'].invars): + if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: # TODO(jburnim): De-allocate semaphores. # callback.io_callback( # _deallocate_semaphores, @@ -1143,6 +1193,14 @@ def f(*args, jaxpr): # a, # ordered=True) pass + else: + callback.io_callback( + _deallocate_buffer, + None, + device_id, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + a, + ordered=True) elif prim is state_primitives.get_p: invals = deferred_invals() @@ -1229,7 +1287,7 @@ def f(*args, jaxpr): compiler_params['mosaic']['collective_id'], ordered=True) - elif prim is mosaic_primitives.semaphore_signal_p: + elif prim is primitives.semaphore_signal_p: sem, sem_transforms, inc, target_device_id, core_index = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) target_device_id = _device_id_to_logical( @@ -1245,7 +1303,7 @@ def f(*args, jaxpr): ordered=True) out = [] - elif prim is mosaic_primitives.semaphore_wait_p: + elif prim is primitives.semaphore_wait_p: sem, sem_transforms, value = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) callback.io_callback( @@ -1279,38 +1337,29 @@ def f(*args, jaxpr): out = prim.bind(*subfuns, *deferred_invals(), **bind_params) out = out if prim.multiple_results else [out] - jax.util.safe_map(write, eqn.outvars, out) - - return jax.util.safe_map(read, jaxpr.outvars) - -def _initialize_output_vals( - block_mappings_output: Iterable[BlockMapping], - input_args, input_output_aliases, - interpret_params: TPUInterpretParams, -) -> Sequence[jax.Array]: - oi_map = {v: k for k, v in input_output_aliases} - output_vals = [] - for i, bm in enumerate(block_mappings_output): - if i in oi_map: - output_vals.append(input_args[oi_map[i]]) - else: - output_vals.append(_uninitialized_value( - bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype, - interpret_params)) - return output_vals - -def _compute_start_indices(block_mapping, loop_idx, *args): - block_indices = ( - jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args)) - if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): - ret = tuple(i if b is pallas_core.mapped else b * i - for b, i in zip(block_mapping.block_shape, block_indices)) - elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - ret = block_indices - else: - raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") - return ret + jax._src.util.safe_map(write, eqn.outvars, 
out) + + return jax._src.util.safe_map(read, jaxpr.outvars) + +def _compute_start_indices( + block_mapping, loop_idx, *args, compiler_params, interpret_params): + jaxpr = block_mapping.index_map_jaxpr + block_indices = _interpret_jaxpr( + jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, + compiler_params=compiler_params, interpret_params=interpret_params) + if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): + ret = jnp.array( + tuple( + i if b is pallas_core.mapped else b * i + for b, i in zip(block_mapping.block_shape, block_indices) + ), + dtype=jnp.int32, + ) + elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): + ret = block_indices + else: + raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") + return ret def _get_next_indices(grid, indices): next_indices = [] @@ -1321,6 +1370,96 @@ def _get_next_indices(grid, indices): next_indices.append(jnp.where(carry, 0, i)) return tuple(reversed(next_indices)) +def _get_parallel_dim_semantics( + compiler_params: dict[str, Any], grid: tuple[int, ...] +) -> tuple[bool, ...]: + """Returns a tuple of booleans indicating whether the corresponding dimension in `grid` is parallel.""" + dimension_semantics = compiler_params.get('mosaic', {}).get( + 'dimension_semantics', None + ) + if dimension_semantics is None: + return (False,) * len(grid) + return tuple(ds == 'parallel' for ds in dimension_semantics) + +_GridPointCoordinatesPerDim = tuple[Array, ...] + +def _get_randomized_grid_coordinates( + grid: tuple[int, ...], + compiler_params: dict[str, Any], + random_seed: int | None, +) -> _GridPointCoordinatesPerDim: + """Returns a tuple of randomized coordinates for each 'parallel' dimension in `grid`. + + For a dimension with 'parallel' semantics at position `d` in the grid, the + returned tuple contains a random permutation of the sequence `[0,..., + grid[d] - 1]` at index `d`. For each dimension with 'arbitrary' semantics, + the resulting tuple contains an empty array. (Inserting an empty arry for an + 'arbitrary' dimension at position `d` in the grid, instead of the sequence + `[0,..., grid[d] - 1]`, allows `grid[d]` to be a dynamic value, i.e. a value + not known at Jax trace time.) + + Args: + grid: Tuple of sizes of the dimensions in the grid. + compiler_params: Representation of a `mosaic_core.TPUCompilerParams` object + as a dictionary. + parallel_semantics_per_dim: A tuple of booleans indicating whether the + corresponding dimension in the grid has parallel semantics. + random_seed: The seed to use for randomizing coordinates in parallel + dimensions. + """ + parallel_semantics_per_dim = _get_parallel_dim_semantics( + compiler_params, grid + ) + + key = jax.random.key(random_seed or 0) + grid_point_coordinates = [] + for dim_size, parallel_dim in zip(grid, parallel_semantics_per_dim): + if parallel_dim: + # The size of a dimension with `parallel` semantics must be known at Jax + # trace time. This ensures that the arguments to `jnp.arange` and + # `jax.random.permutation` below are valid. 
+ dim_size = jax_core.concrete_or_error(None, dim_size) + + coordindates_along_dim = jnp.arange(dim_size, dtype=jnp.int32) + key, subkey = jax.random.split(key) + coordindates_along_dim = jax.random.permutation( + subkey, coordindates_along_dim + ) + grid_point_coordinates.append(coordindates_along_dim) + else: + grid_point_coordinates.append(jnp.array((), dtype=jnp.int32)) + + return tuple(grid_point_coordinates) + + +def _get_grid_point( + loop_indices: tuple[Array, ...], + grid_point_coordinates: _GridPointCoordinatesPerDim, +) -> Array: + """Indexes each entry in `grid_point_coordinates` with the corresponding entry in `loop_indices`. + + If an entry in `grid_point_coordinates` is an empty array, the corresponding + entry in the returned array is the corresponding entry in `loop_indices`. + Otherwise, the returned array contains the entry in `grid_point_coordinates` + indexed with the corresponding entry in `loop_indices`. + + Args: + loop_indices: A tuple of loop indices. + grid_point_coordinates: A tuple of coordinate arrays for each dimension in + the grid. Dimensions with 'arbitrary' semantics are represented by empty + arrays. Dimensions with 'parallel' semantics are represented by arrays of + randomized coordinates. + + Returns: + A 1-dimensional array containing the coordinates for the grid point + corresponding to the specified `loop_indices`. + """ + grid_point = [] + for li, coords in zip(loop_indices, grid_point_coordinates): + grid_point.append(li if jnp.size(coords) == 0 else coords[li]) + return jnp.array(grid_point, dtype=np.int32) + + def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): start_idx = tuple(jnp.array(s, dtype=jnp.int32) for s in start_idx) output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) @@ -1374,7 +1513,7 @@ def interpret_pallas_call( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: Any, + compiler_params: dict[str, Any], cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], interpret_params: TPUInterpretParams, @@ -1423,39 +1562,65 @@ def interpret_pallas_call( for a, bs in zip(input_args, block_shapes[:num_inputs]) ] - # Allocate buffers in HBM for outputs. - output_buffer_ids = [] - output_buffer_shapes = [] - output_vals = _initialize_output_vals( - grid_mapping.block_mappings_output, - scalars + input_args, - input_output_aliases, - interpret_params) - num_outputs = grid_mapping.num_outputs - output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] - for out_val, bs in zip(output_vals, output_block_shapes): - padded_val = _pad_to_block_dimension(out_val, bs, interpret_params) - output_buffer_shapes.append(padded_val.shape) - output_buffer_ids.append(callback.io_callback( + # Allocate HBM buffers for pallas_call inputs. + # + # TODO(jburnim): As an optimization, skip allocating buffers for inputs that + # are neither aliased nor passed to the kernel in HBM? + input_buffer_ids = [] + for i, var in enumerate( + jaxpr.invars[grid_mapping.num_index_operands:][:grid_mapping.num_inputs]): + assert var.aval.dtype == input_args[i].dtype + input_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - padded_val, + input_args[i], ordered=True)) - # Allocate buffers for all kernel arguments (e.g., scalars, inputs, - # outputs, scratch). 
- io_alias_map = dict(input_output_aliases) + + # Allocate buffers in HBM for pallas_call outputs. oi_alias_map = {v: k for k, v in input_output_aliases} - kernel_buffer_ids = [] - for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): - kernel_buffer_ids.append(callback.io_callback( + output_buffer_ids = [] + output_buffer_shapes = [] + output_vals = [] + num_outputs = grid_mapping.num_outputs + output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] + for i, bm in enumerate(grid_mapping.block_mappings_output): + if i in oi_alias_map: + # Re-use the HBM buffer for the aliased pallas_call input. + output_buffer_ids.append(input_buffer_ids[oi_alias_map[i]]) + output_buffer_shapes.append(input_args[oi_alias_map[i]].shape) + output_vals.append(input_args[oi_alias_map[i]]) + else: + out_val = _uninitialized_value(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype, + interpret_params) + padded_val = _pad_to_block_dimension( + out_val, output_block_shapes[i], interpret_params) + output_buffer_ids.append(callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + padded_val, + ordered=True)) + output_buffer_shapes.append(padded_val.shape) + output_vals.append(out_val) + + # Allocate buffers for non-HBM kernel arguments (e.g., scalars, inputs, + # outputs, scratch). + scalar_buffer_ids = [] + for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): + assert var.aval.shape == val.shape + assert var.aval.dtype == val.dtype + scalar_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM], val, ordered=True)) + kernel_buffer_ids = scalar_buffer_ids.copy() for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]): output_idx = i - grid_mapping.num_inputs is_input = i < grid_mapping.num_inputs @@ -1467,23 +1632,18 @@ def interpret_pallas_call( device_id, var.aval.shape, ordered=True)) - elif is_output and _is_any(var.aval.memory_space): - # Use the already-allocated HBM output buffer. + elif _is_any(var.aval.memory_space): + # Use the already-allocated HBM input or output buffer. # - # TODO(jburnim): For kernel args in HBM, check that block shape is the - # same as for the corresponding pallas_call input, and that the index_map + # TODO(jburnim): For kernel args in HBM, check that block shape equals the + # shape of the corresponding pallas_call input, and that the index_map + # is trivial. - kernel_buffer_ids.append(output_buffer_ids[output_idx]) - elif is_output and (output_idx in oi_alias_map): - # Use the already-allocated (non-HBM) input buffer. - kernel_buffer_ids.append(kernel_buffer_ids[oi_alias_map[output_idx]]) - elif is_input and (i in io_alias_map) and _is_any(var.aval.memory_space): - # Use the already-allocated HBM output buffer. - kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]]) + assert is_input ^ is_output + if is_input: + kernel_buffer_ids.append(input_buffer_ids[i]) + if is_output: + kernel_buffer_ids.append(output_buffer_ids[output_idx]) else: - # TODO(jburnim): For kernel args in HBM, check that block shape is the - # same as for the corresponding pallas_call input, and that the index_map - # is trivial.
kernel_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), @@ -1493,74 +1653,141 @@ def interpret_pallas_call( device_id, TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], _uninitialized_value( var.aval.shape, var.aval.dtype, interpret_params), ordered=True)) + if compiler_params.get('mosaic', {}).get('collective_id', None) is None: + # The kernel doesn't specify its own barrier semaphore, so we do a global + # barrier before running the first iteration of the kernel. + callback.io_callback(_barrier, (), device_id, ordered=True) + _, input_ids, kernel_output_ids, _ = split_list( kernel_buffer_ids, [grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs]) input_vars, output_vars = split_list( jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs]) - # For kernel inputs that are in HBM, we populate the buffer once before - # any kernel invocations. - for buffer_id, var, val in zip(input_ids, input_vars, input_args): - if not _is_any(var.aval.memory_space): - continue - if (val.shape != var.aval.shape) or (val.dtype != var.aval.dtype): - # TODO(jburnim): Also check that the index_map is trivial. - raise ValueError() - callback.io_callback( - store, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - buffer_id, - (), - val, - ordered=True) - if grid: num_iterations = functools.reduce(jnp.multiply, grid) # type: ignore[arg-type] else: # Base case is always one iteration when grid is () num_iterations = 1 - def body(carry): - # The loop carry: (i, loop_idx) -- - # - i:int32 is the interation index - # - loop_idx: tuple[int32] are the program ids for each grid axis - i, loop_idx = carry + randomized_grid_coordinates = _get_randomized_grid_coordinates( + grid, compiler_params, interpret_params.random_seed # type: ignore[arg-type] + ) + def _get_local_grid_env(loop_idx): if grid_mapping.local_grid_env is not None: - local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + return grid_mapping.local_grid_env(loop_idx, grid) else: - local_grid_env = tuple( + return tuple( pallas_core.GridAxis(idx, b) for dim, (idx, b) in enumerate(zip(loop_idx, grid)) if dim not in grid_mapping.vmapped_dims ) - with pallas_core.grid_env(local_grid_env): + def body( + carry: tuple[ + jnp.int32, tuple[jnp.int32, ...], list[jnp.ndarray], list[jnp.ndarray] + ], + ): + """Performs a single iteration of `jaxpr` in the device grid. + + Execution of `jaxpr` is preceded by reading kernel input buffers and + followed by writing kernel output buffers. + + Args: + carry: (iteration_idx, loop_idx, prev_start_indices, cur_start_indices). + - iteration_idx is the iteration index. + - loop_idx are the program ids for each grid axis. + - prev_start_indices is a list of rank-1 arrays containing the start indices + for the slices of inputs and outputs processed in the previous loop + iteration. + - cur_start_indices is a list of rank-1 arrays containing the start indices + for the slices of inputs and outputs processed in the current loop + iteration. + + Note that by carrying the previous *and* current start indices between + loop iterations, it suffices to compute only one list of start indices, + i.e. `next_start_indices` (see below), per iteration. + + Returns: + The carry for the next iteration.
+ """ + iteration_idx, loop_idx, prev_start_indices, cur_start_indices = carry + if interpret_params.grid_point_recorder is not None: + grid_point = _get_grid_point(loop_idx, randomized_grid_coordinates) + callback.io_callback(interpret_params.grid_point_recorder, (), grid_point) + + with pallas_core.grid_env(_get_local_grid_env(loop_idx)): + next_loop_idx = _get_next_indices(grid, loop_idx) + next_grid_point = _get_grid_point( + next_loop_idx, randomized_grid_coordinates + ) + next_start_indices = [ + _compute_start_indices( + bm, + next_grid_point, + *scalar_buffer_ids, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] + # Copy slices of the input to the kernel buffers. - # - # TODO(jburnim): Only copy slices when the index mapping has changed? - start_indices = [_compute_start_indices(bm, loop_idx, *scalars) - for bm in grid_mapping.block_mappings] - for j, var in enumerate(input_vars): - if _is_any(var.aval.memory_space): - continue - sliced_val = _maybe_dynamic_slice(start_indices[j], block_shapes[j], - input_args[j], is_indexing_dim[j]) - assert(sliced_val.shape == var.aval.shape) + + def _store_slice_to_kernel_input(index, input_var): + # Copy from the HBM buffer for the pallas_call input to the kernel + # input buffer. + # TODO(jburnim): Just use input_args[j] when the input is not aliased? + transform = indexing.NDIndexer( + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[index], + block_shapes[index], + is_indexing_dim[index], + ) + ), + shape=input_args[index].shape, + int_indexer_shape=(), + ) + sliced_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # read is involved in a data race. + get, + jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), + device_id, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + input_buffer_ids[index], + (transform,), + ordered=True, + ) callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # store is involved in a data race. store, (), device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - input_ids[j], + TPU_MEMORY_SPACE_IDXS[input_var.aval.memory_space], + input_ids[index], (), sliced_val, - ordered=True) + ordered=True, + ) + + for j, var in enumerate(input_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[j].shape) == 1 + assert len(prev_start_indices[j].shape) == 1 + jax.lax.cond( + (iteration_idx == 0) + | jax.lax.reduce_or( + cur_start_indices[j] != prev_start_indices[j], axes=(0,) + ), + functools.partial(_store_slice_to_kernel_input, j, var), + lambda: None, + ) # Invoke the kernel. _interpret_jaxpr(jaxpr, *kernel_buffer_ids, @@ -1568,29 +1795,30 @@ def body(carry): interpret_params=interpret_params) # Copy from the kernel buffers to slices of the output in HBM. - # - # TODO(jburnim): Only copy if the index mapping will change in the - # next iteration (or if this is the last iteration)? - for j, var in enumerate(output_vars): - if _is_any(var.aval.memory_space): - continue + def _store_to_output_buffer(index, output_var): kernel_output_val = callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # get is involved in a data race. 
get, - var.aval, + output_var.aval, device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + TPU_MEMORY_SPACE_IDXS[output_var.aval.memory_space], kernel_output_ids[j], (), - ordered=True) + ordered=True, + ) transform = indexing.NDIndexer( - indices=tuple(indexing.ds(st, sz) if not iid else st - for st, sz, iid in zip(start_indices[num_inputs + j], - block_shapes[num_inputs + j], - is_indexing_dim[num_inputs + j])), - shape=output_vals[j].shape, - int_indexer_shape=()) + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[num_inputs + index], + block_shapes[num_inputs + index], + is_indexing_dim[num_inputs + index], + ) + ), + shape=output_vals[index].shape, + int_indexer_shape=(), + ) callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # store is involved in a data race. @@ -1598,18 +1826,55 @@ def body(carry): (), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - output_buffer_ids[j], + output_buffer_ids[index], (transform,), kernel_output_val, - ordered=True) + ordered=True, + ) + + for j, var in enumerate(output_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[num_inputs + j].shape) == 1 + assert len(next_start_indices[num_inputs + j].shape) == 1 + jax.lax.cond( + (iteration_idx + 1 == num_iterations) + | jax.lax.reduce_or( + cur_start_indices[num_inputs + j] + != next_start_indices[num_inputs + j], + axes=(0,), + ), + functools.partial(_store_to_output_buffer, j, var), + lambda: None, + ) - return i + 1, _get_next_indices(grid, loop_idx) + return iteration_idx + 1, next_loop_idx, cur_start_indices, next_start_indices + initial_loop_idx = (jnp.int32(0),) * len(grid) + initial_grid_point = _get_grid_point( + initial_loop_idx, randomized_grid_coordinates + ) + with pallas_core.grid_env(_get_local_grid_env(initial_loop_idx)): + initial_start_indices = [ + _compute_start_indices( + bm, + initial_grid_point, + *scalar_buffer_ids, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] # TODO(jburnim): Handle parallel grid dimensions + megacore. _ = lax.while_loop( lambda carry: carry[0] < num_iterations, body, - (jnp.int32(0), (jnp.int32(0),) * len(grid)) + ( + jnp.int32(0), + initial_loop_idx, + initial_start_indices, # Previous start indices are ignored on the first iteration. + initial_start_indices, + ), ) # Read the output from the allocated output buffers.
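As a rough, illustrative sketch of the randomized grid traversal added to the TPU interpreter above (not part of the patch; the helper names `randomized_coordinates` and `grid_point` are made up, loosely mirroring `_get_randomized_grid_coordinates` and `_get_grid_point`): each grid dimension with 'parallel' semantics gets a random permutation of its coordinates, while 'arbitrary' dimensions simply pass their loop index through.

import jax
import jax.numpy as jnp

def randomized_coordinates(grid, parallel_dims, seed=0):
  # One random permutation per 'parallel' dimension; an empty array marks an
  # 'arbitrary' dimension whose loop index is used unchanged.
  key = jax.random.key(seed)
  coords = []
  for size, is_parallel in zip(grid, parallel_dims):
    if is_parallel:
      key, subkey = jax.random.split(key)
      coords.append(jax.random.permutation(subkey, jnp.arange(size, dtype=jnp.int32)))
    else:
      coords.append(jnp.array((), dtype=jnp.int32))
  return tuple(coords)

def grid_point(loop_idx, coords):
  # Map each loop index through its (possibly permuted) coordinate array.
  return jnp.array(
      [li if c.size == 0 else c[li] for li, c in zip(loop_idx, coords)],
      dtype=jnp.int32)

coords = randomized_coordinates((4, 3), (True, False), seed=42)
print(grid_point((2, 1), coords))  # First axis permuted, second passed through.

In the patch itself the permuted coordinates are consumed inside the while-loop body, where `_get_grid_point` maps each `loop_idx` to the grid point whose input and output blocks are processed in that iteration.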
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 10b9de7487eb..065d2b4c3b14 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -45,6 +45,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func @@ -230,12 +231,12 @@ def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None def _dtype_to_ir_type(dtype: jnp.dtype, is_kernel_boundary: bool = False) -> ir.Type: - if jnp.issubdtype(dtype, tpu_core.semaphore_dtype): + if jnp.issubdtype(dtype, pallas_core.semaphore_dtype): if jnp.issubdtype(dtype, tpu_core.dma_semaphore): return ir.Type.parse("!tpu.dma_semaphore") - elif jnp.issubdtype(dtype, tpu_core.semaphore): + elif jnp.issubdtype(dtype, pallas_core.semaphore): return ir.Type.parse("!tpu.semaphore") - elif jnp.issubdtype(dtype, tpu_core.barrier_semaphore): + elif jnp.issubdtype(dtype, pallas_core.barrier_semaphore): return ir.Type.parse("!tpu.semaphore") else: raise NotImplementedError @@ -575,7 +576,7 @@ def err_details(): # TODO(necula): add index_map source location info f"and index_map {bm.index_map_jaxpr.jaxpr}, in " f"memory space {bm.block_aval.memory_space}." - "\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec") + "\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec") if rank < 1: raise ValueError( "The Pallas TPU lowering currently supports only blocks of " @@ -737,6 +738,15 @@ def dynamic_shape_replacement_fn( block_shape = [ 1 if b is pallas_core.mapped else b for b in bm.block_shape ] + + # Force single-buffering pipelining for trivial windowing in VMEM. + pipeline_mode = bm.pipeline_mode + if ( + tpu_memory_space == tpu_core.TPUMemorySpace.VMEM + and bm.has_trivial_window() + ): + pipeline_mode = pallas_core.Buffered(1) + # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval)) @@ -754,20 +764,20 @@ def dynamic_shape_replacement_fn( block_params["window_kind"] = ir.Attribute.parse( f"#tpu.element_window<{pad_low},{pad_high}>" ) - if bm.pipeline_mode is not None: - if not isinstance(bm.pipeline_mode, pallas_core.Buffered): + if pipeline_mode is not None: + if not isinstance(pipeline_mode, pallas_core.Buffered): raise LoweringException( - f"Unsupported pipeline mode: {bm.pipeline_mode}." + f"Unsupported pipeline mode: {pipeline_mode}." ) - buffer_count = bm.pipeline_mode.buffer_count + buffer_count = pipeline_mode.buffer_count if buffer_count < 1 or buffer_count > 2: raise LoweringException( "Only single (1) and double (2) buffering are supported. Got" f" {buffer_count}." 
) - pipeline_mode = "synchronous" if buffer_count == 1 else "double_buffered" + pipeline_mode_str = "synchronous" if buffer_count == 1 else "double_buffered" block_params["pipeline_mode"] = ir.Attribute.parse( - f"#tpu.pipeline_mode<{pipeline_mode}>" + f"#tpu.pipeline_mode<{pipeline_mode_str}>" ) window_params.append(ir.DictAttr.get(block_params)) m.body.append(mlir_func) @@ -1494,10 +1504,13 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): starts, ) if load_aval != aval_out: - vec_type = ir.VectorType.get(aval_out.shape, - _dtype_to_ir_type(aval_out.dtype, - is_kernel_boundary=True)) - load_val = vector.shape_cast(vec_type, load_val) + if aval_out.shape: + vec_type = ir.VectorType.get(aval_out.shape, + _dtype_to_ir_type(aval_out.dtype, + is_kernel_boundary=True)) + load_val = vector.shape_cast(vec_type, load_val) + else: + load_val = vector.extract(load_val, [], [0] * len(load_aval.shape)) return _maybe_cast_load_to_bool(ctx, aval_out, load_val) def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: @@ -1682,6 +1695,8 @@ def _masked_swap_lowering_rule( result = vector.load(mem_aval_vec_type, ref, starts) val = _maybe_cast_store_to_memref_type(ctx, val_aval, val) if mem_aval != aval_out: + if not aval_out.shape: + raise ValueError("Cannot swap scalars to VMEM.") # We are slicing a scalar so provided dummy 1 indices result_vec_type = ir.VectorType.get(aval_out.shape, _dtype_to_ir_type(aval_out.dtype, is_kernel_boundary=True)) @@ -2096,16 +2111,6 @@ def _convert_helper(x, *, to_dtype): # unsigned -> float is unsupported. We fall through and raise at the bottom. if not jnp.issubdtype(to_dtype, jnp.floating): return x.astype(to_dtype) - if jnp.issubdtype(from_dtype, jnp.floating) and jnp.issubdtype( - to_dtype, jnp.signedinteger - ): - if from_dtype.itemsize < 4: - x = x.astype(jnp.float32) - if to_dtype.itemsize < 4: - # Need to clip values to match XLA - minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max - x = jnp.clip(x, minval, maxval) - return x.astype(jnp.int32).astype(to_dtype) raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}") def _convert_element_type_lowering_rule( @@ -2149,10 +2154,7 @@ def _convert_element_type_lowering_rule( return x # TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer. elif _from(floating) and _to(signed): - # TODO(apaszke): Remove once a month has passed, along with the - # _convert_helper float -> signed conversion above. 
- if not ctx.forward_compatible or both_32bit: - return arith.fptosi(out_type, x) + return arith.fptosi(out_type, x) elif _from(signed) and _to(floating) and both_32bit: return arith.sitofp(out_type, x) elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4: @@ -2177,6 +2179,8 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, ), x, ) + if not ctx.avals_out[0].shape: + return vector.extract(x, [], [0] * len(ctx.avals_in[0].shape)) return vector.shape_cast( aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] @@ -2250,6 +2254,14 @@ def _split_lowering_rule( def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): + if len(shape) == 1: + if dimension != 0: + raise ValueError("Dimension must be 0 for 1D iota.") + def _1d_iota_helper(dtype, shape, dimension, sharding): + iota_2d = lax.iota_p.bind(dtype, (1,) + shape, dimension, sharding) + return iota_2d[0] + return lower_fun(_1d_iota_helper, multiple_results=False)( + ctx, dtype, shape, dimension, sharding) out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) @@ -2346,13 +2358,13 @@ def _bcast(x, y, x_aval, y_aval, out_aval): y_dtype = x_aval.dtype elif x_aval.weak_type: x_dtype = y_aval.dtype - if isinstance(x, (np.ndarray, np.number, int, float)): + if not isinstance(x, ir.Value): if getattr(y, "type", None) == ir.IndexType.get(): mlir_type = y.type else: mlir_type = _dtype_to_ir_type(x_dtype) x = ir_constant(x, mlir_type) - if isinstance(y, (np.ndarray, np.number, int, float)): + if not isinstance(y, ir.Value): if getattr(x, "type", None) == ir.IndexType.get(): mlir_type = x.type else: @@ -2549,14 +2561,18 @@ def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y): lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.rsqrt(x) lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule -def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _sqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sqrt(x) @@ -2572,7 +2588,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.square_p] = _square_lowering_rule -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.exp(x) @@ -2605,9 +2623,11 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y): lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior # here. 
+ if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return lower_fun( lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x), multiple_results=False, @@ -2618,7 +2638,9 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): skip_mlir_conversions.add(lax.exp2_p) -def _logistic_lowering_rule(ctx: LoweringRuleContext, x): +def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") neg_x = arith.negf(x) exp_neg_x = math.exp(neg_x) aval_out = ctx.avals_out[0] @@ -2636,42 +2658,54 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.logistic_p] = _logistic_lowering_rule -def _sin_lowering_rule(ctx: LoweringRuleContext, x): +def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sin(x) lowering_rules[lax.sin_p] = _sin_lowering_rule -def _cos_lowering_rule(ctx: LoweringRuleContext, x): +def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.cos(x) lowering_rules[lax.cos_p] = _cos_lowering_rule -def _tan_lowering_rule(ctx: LoweringRuleContext, x): +def _tan_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tan(x) lowering_rules[lax.tan_p] = _tan_lowering_rule -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tanh(x) lowering_rules[lax.tanh_p] = _tanh_lowering_rule -def _log_lowering_rule(ctx: LoweringRuleContext, x): +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log(x) lowering_rules[lax.log_p] = _log_lowering_rule -def _log1p_lowering_rule(ctx: LoweringRuleContext, x): +def _log1p_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log1p(x) @@ -3392,7 +3426,7 @@ def _alloc_value( ) -> ir.Value: if isinstance(aval, pallas_core.AbstractMemoryRef): memspace = _memory_space_to_mosaic_attribute(aval.memory_space) - if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): + if jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE memref_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, @@ -3442,8 +3476,8 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): def _device_id_to_logical( ctx: LoweringRuleContext, device_id, - device_id_type: tpu_primitives.DeviceIdType): - if device_id_type is tpu_primitives.DeviceIdType.MESH: + device_id_type: primitives.DeviceIdType): + if device_id_type is primitives.DeviceIdType.MESH: # Mesh means we are passed the mesh coordinates for the device device_ids = tree_util.tree_leaves(device_id) mesh_strides = ctx.lowering_context.mesh_context.mesh_strides @@ -3458,7 +3492,7 @@ def _device_id_to_logical( for a, b in zip(device_ids, mesh_strides) ), ) - elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL: + elif device_id_type is primitives.DeviceIdType.LOGICAL: return device_id raise NotImplementedError(f"Unsupported device id type: 
{device_id_type}") @@ -3468,19 +3502,30 @@ def _semaphore_read_lowering_rule( *args, args_tree, ): - sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) + sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, ctx.avals_in) + primitives.check_sem_avals( + sem_aval, + sem_transforms_avals, + "read", + allowed_semaphore_types={ + tpu_core.dma_semaphore, + pallas_core.semaphore, + pallas_core.barrier_semaphore, + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + }, + ) sem, transforms = tree_util.tree_unflatten(args_tree, args) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) return tpu.sem_read(sem) -lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule +lowering_rules[primitives.semaphore_read_p] = _semaphore_read_lowering_rule def _semaphore_signal_lowering_rule( ctx: LoweringRuleContext, *args, args_tree, - device_id_type: tpu_primitives.DeviceIdType, + device_id_type: primitives.DeviceIdType, ): sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( @@ -3493,7 +3538,7 @@ def _semaphore_signal_lowering_rule( return [] -lowering_rules[tpu_primitives.semaphore_signal_p] = ( +lowering_rules[primitives.semaphore_signal_p] = ( _semaphore_signal_lowering_rule) @@ -3503,10 +3548,16 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) tpu.sem_wait(sem, value) return [] -lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule +lowering_rules[primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule -def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): + +def _dma_start_lowering_rule( + ctx: LoweringRuleContext, + *args, + tree, + device_id_type: primitives.DeviceIdType, + priority: int, +): ( src_ref, src_transforms, @@ -3538,15 +3589,25 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) - tpu.enqueue_dma(src_ref, dst_ref, sem, source_semaphore=src_sem, - device_id=device_id) - + priority_kwarg = {"priority": priority} + if jaxlib_version < (0, 5, 4): + priority_kwarg = {} + tpu.enqueue_dma( + src_ref, + dst_ref, + sem, + source_semaphore=src_sem, + device_id=device_id, + **priority_kwarg, + ) return [] + + lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): + device_id_type: primitives.DeviceIdType): del device_id_type (src, src_transforms, dst, transforms, sem, sem_transforms, _, _, _) = ( tree_util.tree_unflatten(tree, args) @@ -3576,10 +3637,6 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule -def _device_id_lowering_rule(ctx: LoweringRuleContext): - return tpu.device_id() -lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule - def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.lowering_context.grid_names if grid_names and axis_name in grid_names: diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 
896af0c464c5..824eb7e89716 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -243,6 +243,7 @@ def _maybe_cast_inputs(*args): collective_id=mosaic_params.get("collective_id", None), has_side_effects=mosaic_params.get("has_side_effects", False), output_memory_spaces=output_memory_spaces, + disable_bounds_checks=mosaic_params.get("disable_bounds_checks"), ) _maybe_cast_to_bool = lambda x, aval: x.astype( jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 184b1497adf9..9b0a9322c94d 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -213,8 +213,8 @@ class BufferedRef: spec: pl.BlockSpec # static metadata dtype: Any # static metadata buffer_type: BufferType # static metadata - window_ref: REF | None - accum_ref: REF | None + window_ref: ArrayRef | None + accum_ref: ArrayRef | None current_slot: ArrayRef | None # TODO(ramiroleal): Unused by class. Remove argument from # BufferedRef instantiations. @@ -337,6 +337,7 @@ def memory_space(self): def current_ref(self): buffer_slice = tuple( 0 if x is None else slice(None) for x in self.block_shape) + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) if self.memory_space == VMEM: return self.window_ref.at[buffer_slice] else: @@ -368,10 +369,12 @@ def is_input_output(self): @property def current_slot_index(self): + """Index in double buffer corresponding to the current slot.""" return self.current_slot[0] @property def next_slot_index(self): + """Index in double buffer corresponding to the next slot.""" return lax.rem(self.current_slot_index + 1, 2) def bind_existing_ref(self, window_ref, indices): @@ -463,6 +466,8 @@ def copy_in(self, src_ref, grid_indices): """Starts copy of HBM dma slice into the current slot.""" assert self.is_input if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None if self.swap is not None: self.swap[0] = True next_slot = self.next_slot_index @@ -470,7 +475,7 @@ def copy_in(self, src_ref, grid_indices): dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( src_ref.at[src_slice], - self.window_ref.at[next_slot].at[dst_slice], + self.window_ref.at[(next_slot, *dst_slice)], self.sem_recvs.at[next_slot], ).start() @@ -478,13 +483,15 @@ def copy_out(self, dst_ref, grid_indices): """Starts copy of HBM dma slice from the current slot.""" assert self.is_output if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_sends is not None if self.swap is not None: self.swap[0] = True slot = self.current_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.window_ref.at[slot].at[src_slice], + self.window_ref.at[(slot, *src_slice)], dst_ref.at[dst_slice], self.sem_sends.at[slot], ).start() @@ -493,13 +500,15 @@ def wait_in(self, src_ref, grid_indices): """Waits for input copy to finish.""" assert self.is_input if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) 
current_slot = self.current_slot_index tpu_primitives.make_async_copy( src_ref.at[src_slice], # nb: doesn't matter - self.window_ref.at[current_slot].at[ - dst_slice + self.window_ref.at[ + (current_slot, *dst_slice) ], # only dst shape is important self.sem_recvs.at[current_slot], ).wait() @@ -508,12 +517,14 @@ def wait_out(self, dst_ref, grid_indices): """Waits for output copy to finish.""" assert self.is_output if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_sends is not None # In a double buffer, previous slot is the same as next slot. prev_slot = self.next_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.window_ref.at[prev_slot].at[src_slice], # nb: doesn't matter + self.window_ref.at[(prev_slot, *src_slice)], # nb: doesn't matter dst_ref.at[dst_slice], # only dst shape is important self.sem_sends.at[prev_slot], ).wait() @@ -533,16 +544,18 @@ def set_accumulator(self, init=False): """Set accumulator or zero it out to initialize.""" assert self.is_accumulator if self.accum_ref is not None: + accum_dtype = self.accum_ref.dtype def _init(): self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...]) def _set(): - self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref.dtype) + self.accum_ref[...] = self.current_ref[...].astype(accum_dtype) lax.cond(init, _init, _set) def accumulate(self): """Add into the current slot.""" assert self.is_accumulator if self.accum_ref is not None: + assert self.window_ref is not None accum_dtype = jnp.float32 if self.window_ref.dtype == jnp.int32: accum_dtype = jnp.int32 diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index fb0e0c2c55e3..59856c0ca7b2 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -16,7 +16,6 @@ from __future__ import annotations import dataclasses -import enum from typing import Any import jax @@ -28,6 +27,7 @@ from jax._src import util from jax._src.interpreters import mlir from jax._src.pallas import core as pl_core +from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge @@ -160,255 +160,6 @@ def _roll(x, shift): mlir.register_lowering(roll_p, _roll_lowering_rule) -class DeviceIdType(enum.Enum): - MESH = "mesh" - LOGICAL = "logical" - - -def check_sem_avals( - sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None -): - if allowed_semaphore_types is None: - allowed_semaphore_types = { - tpu_core.semaphore, - tpu_core.barrier_semaphore, - # For interpret mode. - pl_core.SEMAPHORE_INTERPRET_DTYPE, - } - if not isinstance(sem_aval, state.AbstractRef): - raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") - sem_shape = sem_aval.shape - if sem_transforms_avals: - sem_shape = sem_transforms_avals[-1].get_indexer_shape() - if sem_shape: - raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") - sem_dtype = sem_aval.dtype - if not any( - jnp.issubdtype(sem_dtype, sem_type) - for sem_type in allowed_semaphore_types - ): - raise ValueError( - f"Must {name} semaphores of the following types:" - f" {allowed_semaphore_types}. Got {sem_dtype}." 
- ) - - -def _transform_semaphore(ref_value, transforms, ref_aval): - """Helper function for indexing into a semaphore during state_discharge.""" - if ref_value.shape == ref_aval.shape: - return state_discharge.transform_array(ref_value, transforms) - elif len(ref_value.shape) == 0: - return ref_value - else: - raise ValueError( - f"Semaphore value shape {ref_value.shape} does not match aval shape" - f" {ref_aval.shape}" - ) - - -semaphore_read_p = jax_core.Primitive("semaphore_read") -semaphore_read_p.multiple_results = False - - -def semaphore_read(sem_or_view): - ref, transforms = _get_ref_and_transforms(sem_or_view) - args = [ref, transforms] - flat_args, args_tree = tree_util.tree_flatten(args) - return semaphore_read_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_read_p.def_abstract_eval -def _semaphore_read_abstract_eval( - *avals, - args_tree, -): - sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals( - sem_aval, - sem_transforms_avals, - "read", - allowed_semaphore_types={ - tpu_core.dma_semaphore, - tpu_core.semaphore, - tpu_core.barrier_semaphore, - pl_core.SEMAPHORE_INTERPRET_DTYPE, - }, - ) - return jax_core.ShapedArray((), jnp.dtype("int32")) - -def _semaphore_read_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - sem_value = sem_value.astype(jnp.int32) - return (None,) * len(in_avals), sem_value -state_discharge.register_discharge_rule(semaphore_read_p)( - _semaphore_read_discharge_rule -) - - -semaphore_signal_p = jax_core.Primitive('semaphore_signal') -semaphore_signal_p.multiple_results = True - - -def semaphore_signal( - sem_or_view, - inc: int | jax.Array = 1, - *, - device_id: int | jax.Array | None | tuple[int | jax.Array, ...] 
= None, - device_id_type: DeviceIdType = DeviceIdType.MESH, - core_index: int | jax.Array | None = None, -): - ref, transforms = _get_ref_and_transforms(sem_or_view) - inc = jnp.asarray(inc, dtype=jnp.int32) - args = [ref, transforms, inc, device_id, core_index] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_signal_p.bind( - *flat_args, - args_tree=args_tree, - device_id_type=device_id_type, - ) - - -@semaphore_signal_p.def_abstract_eval -def _semaphore_signal_abstract_eval( - *avals, - args_tree, - device_id_type: DeviceIdType, -): - del device_id_type - ( - sem_aval, - sem_transforms_avals, - value_aval, - device_id_avals, - core_index_aval, - ) = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals(sem_aval, sem_transforms_avals, "signal") - if value_aval.dtype != jnp.dtype("int32"): - raise ValueError("Must signal an int32 value.") - if device_id_avals is not None: - device_id_flat_avals = tree_util.tree_leaves(device_id_avals) - for aval in device_id_flat_avals: - if aval.dtype != jnp.dtype("int32"): - raise ValueError("`device_id`s must be an int32 value.") - return [] - - -def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - device_ids, - _, - ) = tree_util.tree_unflatten(tree, invars) - out = pp.concat([ - pp.text("semaphore_signal"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) - if device_ids is not None: - flat_device_ids = tree_util.tree_leaves(device_ids) - if not flat_device_ids: - return out - device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] - for device_id in flat_device_ids[1:]: - device_ids_pp.append(pp.text(" ")) - device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context))) - out = pp.concat([out, pp.concat(device_ids_pp)]) - return out -jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn - - -def _semaphore_signal_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree, - device_id_type): - del out_avals, device_id_type - [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) - if device_id is not None: - raise NotImplementedError("Remote signal not implemented.") - if core_index is not None: - raise NotImplementedError("Multiple core support not implemented.") - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value + inc - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_signal_p)( - _semaphore_signal_discharge_rule -) - - -semaphore_wait_p = jax_core.Primitive('semaphore_wait') -semaphore_wait_p.multiple_results = True - -def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): - ref, transforms = _get_ref_and_transforms(sem_or_view) - dec = jnp.asarray(dec, dtype=jnp.int32) - args = [ref, transforms, dec] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_wait_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_wait_p.def_abstract_eval -def _semaphore_wait_abstract_eval(*avals, args_tree): - sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( - args_tree, avals - ) - check_sem_avals(sem_aval, sem_transforms_avals, "wait") - if value_aval.dtype 
!= jnp.dtype("int32"): - raise ValueError("Must wait an int32 value.") - return [] - -def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - ) = tree_util.tree_unflatten(tree, invars) - return pp.concat([ - pp.text("semaphore_wait"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) -jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn - -def _semaphore_wait_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms, dec] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value - dec - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_wait_p)( - _semaphore_wait_discharge_rule -) - - @dataclasses.dataclass class AsyncCopyDescriptor: src_ref: Any @@ -420,7 +171,7 @@ class AsyncCopyDescriptor: src_sem: int | jax.Array | None src_sem_transforms: tuple[Transform, ...] | None device_id: int | jax.Array | None - device_id_type: DeviceIdType = DeviceIdType.MESH + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH def __post_init__(self): if (self.src_sem is None) ^ (self.device_id is None): @@ -457,9 +208,14 @@ def _get_args_and_tree(self, swap_src_and_dst: bool = False): self.device_id, )) - def start(self): + def start(self, priority: int = 0): flat_args, tree = self._get_args_and_tree() - dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type) + dma_start_p.bind( + *flat_args, + tree=tree, + device_id_type=self.device_id_type, + priority=priority, + ) def wait(self): if self.is_remote: @@ -488,7 +244,9 @@ def wait_send(self): dma_start_p.multiple_results = True @dma_start_p.def_effectful_abstract_eval -def _dma_start_abstract_eval(*args, tree, device_id_type): +def _dma_start_abstract_eval(*args, tree, device_id_type, priority): + if priority < 0: + raise ValueError(f"DMA start priority must be non-negative: {priority}") ( src_ref_aval, src_transforms_avals, @@ -523,6 +281,7 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, settings: jax_core.JaxprPpSettings): invars = eqn.invars tree = eqn.params["tree"] + priority = eqn.params["priority"] ( src_ref, src_transforms, @@ -539,7 +298,7 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, if src_sem or device_id: return jax_core._pp_eqn(eqn, context, settings) return pp.concat([ - pp.text("dma_start"), + pp.text(f"dma_start(p{priority})"), pp.text(" "), sp.pp_ref_transforms(context, src_ref, src_transforms), pp.text(" -> "), @@ -550,8 +309,12 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn -def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, - *args, tree, device_id_type): + +def dma_start_partial_discharge_rule( + should_discharge, in_avals, out_avals, *args, tree, device_id_type, priority +): + # Note: we ignore the DMA priority in discharge rules. + del priority ( src_ref, src_transforms, @@ -610,14 +373,14 @@ def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, # TODO(justinfu): Verify that code only works in SPMD mode. 
axis_env = jax_core.get_axis_env() nonempty_axes = [name for name in axis_env.axis_sizes if name is not None] - if device_id_type == DeviceIdType.LOGICAL: + if device_id_type == primitives.DeviceIdType.LOGICAL: if len(nonempty_axes) > 1: raise NotImplementedError("Sharding with more than one named axis not " "implemented in dma_start_p for LOGICAL " "device_id_type.") shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) - elif device_id_type == DeviceIdType.MESH: + elif device_id_type == primitives.DeviceIdType.MESH: device_id_len = 1 if isinstance(device_id, jax.Array): device_id_len = device_id.size @@ -667,7 +430,7 @@ def do_discharge_dst(dst_ref=dst_ref): def do_discharge_dst_sem(dst_sem=dst_sem): recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - dst_sem_value = _transform_semaphore( + dst_sem_value = primitives._transform_semaphore( dst_sem, dst_sem_transforms, dst_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -678,7 +441,7 @@ def do_discharge_dst_sem(dst_sem=dst_sem): def do_discharge_src_sem(src_sem=src_sem): send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE) send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - src_sem_value = _transform_semaphore( + src_sem_value = primitives._transform_semaphore( src_sem, src_sem_transforms, src_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -710,6 +473,7 @@ def do_discharge_src_sem(src_sem=src_sem): return new_vals, [] + state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule) @@ -778,7 +542,7 @@ def dma_wait_partial_discharge_rule(should_discharge, updates = state_discharge.transform_array(dst_ref, dst_ref_transforms) copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) + sem_value = primitives._transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) _, new_sem = state_discharge.transform_swap_array( dst_sem, dst_sem_transforms, sem_value - copy_size ) @@ -799,6 +563,7 @@ def _get_ref_and_transforms(ref): return ref.ref, ref.transforms return ref, () + def make_async_copy(src_ref, dst_ref, sem): """Issues a DMA copying from src_ref to dst_ref.""" src_ref, src_transforms = _get_ref_and_transforms(src_ref) @@ -814,17 +579,19 @@ def make_async_copy(src_ref, dst_ref, sem): None, None, None, - DeviceIdType.MESH, + primitives.DeviceIdType.MESH, ) -def async_copy(src_ref, dst_ref, sem): + +def async_copy(src_ref, dst_ref, sem, *, priority: int = 0): """Issues a DMA copying from src_ref to dst_ref.""" copy_descriptor = make_async_copy(src_ref, dst_ref, sem) - copy_descriptor.start() + copy_descriptor.start(priority=priority) return copy_descriptor + def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): """Creates a description of a remote copy operation. 
Copies data from src_ref on the current device to dst_ref on the device @@ -861,26 +628,18 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, ) def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): copy_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type) copy_descriptor.start() return copy_descriptor -device_id_p = jax_core.Primitive('device_id') - -@device_id_p.def_abstract_eval -def _device_id_abstract_eval(): - return jax_core.ShapedArray((), jnp.dtype("int32")) - -device_id = device_id_p.bind - get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore') @get_barrier_semaphore_p.def_abstract_eval def _get_barrier_semaphore_abstract_eval(): return pl_core.AbstractMemoryRef( - jax_core.ShapedArray((), tpu_core.BarrierSemaphoreTy()), + jax_core.ShapedArray((), pl_core.BarrierSemaphore()), tpu_core.TPUMemorySpace.SEMAPHORE, ) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index e5b491aef330..554b9db878f6 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -48,7 +48,7 @@ pytype_strict_library( "//jax:mlir", "//jax:mosaic_gpu", "//jax/_src/pallas", - ], + ] + py_deps("numpy"), ) pytype_strict_library( @@ -78,8 +78,10 @@ pytype_strict_library( "//jax:dtypes", "//jax:effects", "//jax:mosaic_gpu", + "//jax:pretty_printer", "//jax:state_types", "//jax:tree_util", + "//jax/_src/lib", "//jax/_src/pallas", "//jaxlib/mlir:ir", ] + py_deps("numpy"), @@ -93,8 +95,8 @@ pytype_strict_library( ":lowering", "//jax", "//jax:core", - "//jax:mlir", "//jax:mosaic_gpu", + "//jax:pretty_printer", "//jax:tree_util", "//jax:util", "//jax/_src/lib", diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 630c1b8f4bed..bb08d8f090a7 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -18,7 +18,7 @@ import abc import collections -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import dataclasses import enum import itertools as it @@ -29,11 +29,16 @@ from jax._src import dtypes from jax._src import effects from jax._src import tree_util +from jax._src import pretty_printer as pp +from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as pallas_primitives +import jax._src.pallas.utils as pallas_utils +from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import types as state_types -from jax._src.state import discharge as state_discharge import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp from jaxlib.mlir import ir @@ -86,7 +91,7 @@ class GPUCompilerParams(pallas_core.CompilerParams): delay_release: int = 0 profile_space: int = 0 profile_dir: str = "" - thread_semantics: mgpu.core.ThreadSemantics = mgpu.core.ThreadSemantics.Lane + lowering_semantics: mgpu.core.LoweringSemantics = mgpu.core.LoweringSemantics.Lane def __post_init__(self): if bool(self.profile_space) ^ bool(self.profile_dir): @@ -100,6 +105,8 @@ class GPUMemorySpace(enum.Enum): GMEM = "gmem" #: Shared memory. SMEM = "smem" + #: Tensor memory. + TMEM = "tmem" #: Registers. 
REGS = "regs" @@ -111,20 +118,64 @@ def __call__( shape: tuple[int, ...], dtype: jnp.dtype, transforms: Sequence[MemoryRefTransform] = (), - ) -> pallas_core.MemoryRef: # A convenience function for constructing MemoryRef types. return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) -def kernel(body, out_shape, compiler_params=None, **mesh_kwargs): +class SemaphoreType(enum.Enum): + REGULAR = "regular" + BARRIER = "barrier" + + def __call__(self, shape: tuple[int, ...]): + dtype: Any + if self == SemaphoreType.BARRIER: + dtype = pallas_core.BarrierSemaphore() + else: + dtype = pallas_core.Semaphore() + return pallas_core.MemoryRef(shape, dtype, GPUMemorySpace.GMEM) + + def get_array_aval(self) -> jax_core.ShapedArray: + return self(()).get_array_aval() + + def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: + return self(()).get_ref_aval() + + +class PrimitiveSemantics(enum.Enum): + """Thread semantics for primitives at the Pallas user level.""" + + Warp = enum.auto() + Warpgroup = enum.auto() + + +# Convenience constants for (lowering, primitive) thread semantics pairs. +LANExWG_SEMANTICS = ( + mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warpgroup) +LANExWARP_SEMANTICS = ( + mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warp) +WGxWG_SEMANTICS = ( + mgpu.LoweringSemantics.Warpgroup, PrimitiveSemantics.Warpgroup) + + +def kernel( + body: Callable[..., None], + out_shape: object, + *, + scratch_shapes: Sequence[pallas_core.ScratchShape] = (), + compiler_params: object | None = None, + **mesh_kwargs: object, +): if unwrap_out := not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) def wrapper(*operands): def stateful(operand_and_out_refs): operand_refs, out_refs = operand_and_out_refs def cmap_body(): - body(*operand_refs, *out_refs) + pallas_primitives.run_scoped( + lambda *scratch_refs: body(*operand_refs, *out_refs, *scratch_refs), + *scratch_shapes, + ) pallas_core.core_map( GPUMesh(**mesh_kwargs), compiler_params=compiler_params )(cmap_body) @@ -135,6 +186,24 @@ def cmap_body(): return wrapper +def _is_known_divisible(value, divisor, fuel=10) -> bool: + """Returns True if the value is statically known to be divisible by the divisor.""" + if fuel < 0: + return False + if not isinstance(value.owner, ir.Operation): + return False + def_op = value.owner.opview + match def_op: + case arith_dialect.IndexCastOp(): + return _is_known_divisible(value.owner.operands[0], divisor, fuel - 1) + case arith_dialect.ConstantOp(): + return ir.IntegerAttr(def_op.value).value % divisor == 0 + case arith_dialect.MulIOp(): + return (_is_known_divisible(value.owner.operands[0], divisor, fuel // 2) or + _is_known_divisible(value.owner.operands[1], divisor, (fuel + 1)// 2)) + return False + + @dataclasses.dataclass(frozen=True) class GPUMemoryRef(pallas_core.MemoryRef): transforms: Sequence[MemoryRefTransform] = () @@ -171,7 +240,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: shape=self.to_gpu_transform().transform_shape(aval.shape) ) -Index = slice | int | ir.Value +Index = mgpu.DynamicSlice | slice | int | ir.Value @dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): @@ -213,26 +282,74 @@ def transform_shape(self, shape): def transform_dtype(self, dtype): return dtype + def untransform_transpose( + self, perm: tuple[int, ...]
+ ) -> tuple[tuple[int, ...], state_types.Transform]: + # The transpose in question is applied to the untiled ref so we + # need to translate it by duplicating and offsetting the last part. + off = len(perm) + new_suffix = [i + off for i in perm[-len(self.tiling) :]] + if set(new_suffix) != set(range(off, off + len(self.tiling))): + raise ValueError( + "Transpose cannot be moved before a tiling transform when it changes" + f" the set of tiled dimensions. (permutation: {perm}, tiling:" + f" {self.tiling})" + ) + + new_tiling = tuple(self.tiling[i - off] for i in new_suffix) + return (*perm, *new_suffix), dataclasses.replace(self, tiling=new_tiling) + + def untransform_reshape( + self, dtype: jnp.dtype, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + del dtype + raise NotImplementedError("Reshapes don't commute with transposes.") + def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] - idxs_after_tiling = [] + idxs_after_tiling: list[Index] = [] for idx, tile in zip(tiled_idxs, self.tiling): - if not isinstance(idx, slice): - raise NotImplementedError("Non-slice indices are not supported") - assert isinstance(idx, slice) - if idx.step is not None and idx.step != 1: - raise NotImplementedError("Strided slices unsupported") - if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): - raise ValueError("Non-empty slices must be tile aligned") - idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + if isinstance(idx, slice): + if idx.step is not None and idx.step != 1: + raise NotImplementedError("Strided slices unsupported") + if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): + raise ValueError("Non-empty slices must be tile aligned") + idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + elif isinstance(idx, mgpu.DynamicSlice): + if idx.length % tile: + raise ValueError( + f"Dynamic slice length ({idx.length}) is not divisible by the" + f" tiling ({tile})" + ) + if isinstance(idx.base, ir.Value): + if not _is_known_divisible(idx.base, tile): + raise ValueError( + "Dynamic slice base index (which is a dynamic value) cannot be" + f" statically proven to be divisible by the tiling ({tile})" + ) + new_base = arith_dialect.divui(idx.base, mgpu.c(tile, idx.base.type)) + else: + if idx.base % tile: + raise ValueError( + f"Dynamic slice base ({idx.base}) is not divisible by the" + f" tiling ({tile})" + ) + new_base = idx.base // tile + idxs_after_tiling.append(mgpu.DynamicSlice(new_base, idx.length // tile)) + else: + raise TypeError(f"Unsupported index type: {type(idx)}") return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling)), self def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{untile({list(self.tiling)})}}") + def _perm_inverse(permutation: tuple[int, ...]) -> tuple[int, ...]: inverse = [-1] * len(permutation) @@ -271,7 +388,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: @tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class TransposeRef(state_types.Transform): - permutation: tuple[int, ...] + permutation: tuple[int, ...]
= dataclasses.field(metadata=dict(static=True)) def transform_shape(self, shape): if shape is None: @@ -281,11 +398,25 @@ def transform_shape(self, shape): def transform_dtype(self, dtype): return dtype + def untransform_transpose( + self, perm + ) -> tuple[tuple[int, ...], state_types.Transform]: + raise NotImplementedError( + "Commuting of transpose over transpose is not supported." + ) + + def untransform_reshape( + self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + del shape, dtype + raise NotImplementedError("Can't reshape a transposed memref.") + def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype removed_dims = [ - i for i, idx in enumerate(idxs) if not isinstance(idx, slice) + i for i, idx in enumerate(idxs) if not isinstance(idx, (slice, mgpu.ds)) ] new_perm = tuple( p - sum(d < p for d in removed_dims) @@ -298,19 +429,34 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(_perm_inverse(self.permutation)) + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{transpose({list(self.permutation)})}}") -def transpose_ref( - ref: pallas_core.TransformedRef | Any, - permutation: tuple[int, ...], + +def transform_ref( + ref: pallas_core.TransformedRef, + transform: state_types.Transform ) -> pallas_core.TransformedRef: if not isinstance(ref, pallas_core.TransformedRef): if not isinstance(jax_core.get_aval(ref), pallas_core.AbstractMemoryRef): raise TypeError("ref must be a reference") ref = pallas_core.TransformedRef(ref, transforms=()) return pallas_core.TransformedRef( - ref.ref, (*ref.transforms, TransposeRef(permutation)), + ref.ref, (*ref.transforms, transform), ) +def transpose_ref( + ref: pallas_core.TransformedRef | Any, + permutation: tuple[int, ...], +) -> pallas_core.TransformedRef: + return transform_ref(ref, TransposeRef(permutation)) + +def untile_ref(ref, tiling: tuple[int, ...]) -> pallas_core.TransformedRef: + return transform_ref(ref, UntileRef(tiling)) + +def unswizzle_ref(ref, swizzle: int) -> pallas_core.TransformedRef: + return transform_ref(ref, UnswizzleRef(swizzle)) + @dataclasses.dataclass(frozen=True) class SwizzleTransform(MemoryRefTransform): @@ -339,7 +485,7 @@ def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: raise NotImplementedError def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: - swizzle_elems = self.swizzle // aval.dtype.itemsize + swizzle_elems = (self.swizzle * 8) // pallas_utils.dtype_bitwidth(aval.dtype) if swizzle_elems != aval.shape[-1]: raise ValueError( f"Swizzle {self.swizzle} requires the trailing dimension to be of" @@ -353,25 +499,54 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: class UnswizzleRef(state_types.Transform): swizzle: int = dataclasses.field(metadata=dict(static=True)) + def swizzle_elems(self, dtype: jnp.dtype | ir.Type) -> int: + if not isinstance(dtype, ir.Type): + dtype = mgpu_utils.dtype_to_ir_type(dtype) + return (self.swizzle * 8) // mgpu.bitwidth(dtype) + + def untransform_transpose(self, perm) -> tuple[tuple[int, ...], state_types.Transform]: + if perm[-1] != len(perm) - 1: + raise ValueError("Can't transpose the swizzled dimension.") + + return perm, self + + def untransform_reshape( + self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] 
+ ) -> tuple[tuple[int, ...], state_types.Transform]: + if shape[-1] != self.swizzle_elems(dtype): + raise ValueError( + f"Reshape shape {shape} is not divisible by swizzle elements" + f" {self.swizzle_elems(dtype)}" + ) + return shape, self + def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + swizzle_elems = self.swizzle_elems(dtype) if not idxs: return idxs, self - if not all(isinstance(idx, slice) for idx in idxs[-2:]): + if not all(isinstance(idx, (slice, mgpu.ds)) for idx in idxs[-2:]): raise NotImplementedError( "Non-slice indices are not supported in 2 minormost dims" ) last_idx = idxs[-1] - assert isinstance(last_idx, slice) - if last_idx.step is not None and last_idx.step != 1: - raise NotImplementedError("Swizzled dims cannot be sliced") - if (last_idx.start is not None and last_idx.start != 0) or ( - last_idx.stop is not None and last_idx.stop != self.swizzle - ): - raise ValueError("Swizzled dims cannot be sliced") + if isinstance(last_idx, mgpu.DynamicSlice): + if last_idx.base != 0 or last_idx.length != swizzle_elems: + raise ValueError("Swizzled dims cannot be sliced") + else: + assert isinstance(last_idx, slice) + if ( + (last_idx.step is not None and last_idx.step != 1) + or (last_idx.start is not None and last_idx.start != 0) + or (last_idx.stop is not None and last_idx.stop != swizzle_elems) + ): + raise ValueError("Swizzled dims cannot be sliced") return idxs, self + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{unswizzle({self.swizzle})}}") + @dataclasses.dataclass class GPUBlockSpec(pallas_core.BlockSpec): @@ -408,6 +583,7 @@ def to_block_mapping( GMEM = GPUMemorySpace.GMEM SMEM = GPUMemorySpace.SMEM +TMEM = GPUMemorySpace.TMEM REGS = GPUMemorySpace.REGS @@ -426,6 +602,17 @@ def __str__(self): return self.name +@dataclasses.dataclass(frozen=True) +class ClusterBarrierType(dtypes.ExtendedDType): + type: ClassVar[Any] = barrier_dtype + name: ClassVar[str] = "cluster_barrier" + + collective_axes: tuple[str | tuple[str, ...], ...] + + def __str__(self): + return self.name + + @dataclasses.dataclass(frozen=True) class Barrier: num_arrivals: int @@ -438,6 +625,18 @@ def get_ref_aval(self) -> AbstractMemoryRef: return AbstractMemoryRef(aval, SMEM) +@dataclasses.dataclass(frozen=True) +class ClusterBarrier: + collective_axes: tuple[str | tuple[str, ...], ...] + num_barriers: int = 1 + + def get_ref_aval(self) -> AbstractMemoryRef: + aval = jax_core.ShapedArray( + [self.num_barriers], ClusterBarrierType(self.collective_axes) + ) + return AbstractMemoryRef(aval, SMEM) + + @dataclasses.dataclass(frozen=True) class WGMMAAccumulatorRef: shape: tuple[int, int] @@ -499,24 +698,35 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: @dataclasses.dataclass(frozen=True, kw_only=True) class GPUMesh: - grid: tuple[int, ...] = () - cluster: tuple[int, ...] = () + grid: Sequence[int] = () + grid_names: Sequence[str] = () + cluster: Sequence[int] = () + cluster_names: Sequence[str] = () # Those are NOT CUDA threads. On Hopper they correspond to warpgroups. num_threads: int | None = None - axis_names: tuple[str, ...] 
= () + thread_name: str | None = None def __post_init__(self): - if len(self.axis_names) != len(self.grid) + (self.num_threads is not None): - raise ValueError("Need as many axis names as grid dimensions + warp groups") + if len(self.cluster) > 3: + raise ValueError(f"cluster= must be at most 3D, got {self}.") + if len(self.grid_names) != len(self.grid): + raise ValueError( + f"grid_names must have the same length as grid, got {self}." + ) + if len(self.cluster_names) != len(self.cluster): + raise ValueError( + f"cluster_names must have the same length as cluster, got {self}." + ) + if (self.thread_name is None) != (self.num_threads is None): + raise ValueError( + "num_threads and thread_name must be either both set or both None," + f" got {self}" + ) if self.num_threads is not None and self.num_threads > 2048 // 128: raise ValueError( "Requested too many CUDA threads per block. Each Mosaic thread" " corresponds to 128 CUDA threads." ) - if self.cluster: - raise NotImplementedError( - "Pallas/MosaicGPU does not support clusters yet." - ) @property def backend(self) -> str: @@ -527,20 +737,40 @@ def shape(self) -> collections.OrderedDict[object, int]: pairs: Iterable[tuple[object, int]] if self.num_threads is not None: pairs = zip( - self.axis_names, (*self.grid, *self.cluster, self.num_threads) + (*self.grid_names, *self.cluster_names, self.thread_name), + (*self.grid, *self.cluster, self.num_threads), ) else: - pairs = tuple( - zip( - (*self.axis_names, _WARPGROUP_AXIS_NAME), - (*self.grid, *self.cluster, 1), - ) + pairs = zip( + (*self.grid_names, *self.cluster_names), + (*self.grid, *self.cluster), ) return collections.OrderedDict(pairs) def discharges_effect(self, effect: jax_core.Effect): return effect is _wgmma_pipeline_effect or effect is _memory_effect +@dataclasses.dataclass(frozen=True, kw_only=True) +class WarpMesh: + """Represents a mesh over individual warps within a warpgroup. + + When used in conjunction with `core_map`, the warp ID will be visible + within the body of the wrapped scope by querying `lax.axis_index` with + the specified axis name. 
+ """ + + _NUM_WARPS_PER_WARPGROUP: ClassVar[int] = 4 + axis_name: str + + @property + def shape(self): + return collections.OrderedDict([ + (self.axis_name, self._NUM_WARPS_PER_WARPGROUP), + ]) + + def discharges_effect(self, effect: jax_core.Effect): + del effect + return False def _gpu_mesh_discharge_rule( in_avals, @@ -556,8 +786,6 @@ def _gpu_mesh_discharge_rule( ): if not isinstance(mesh, GPUMesh): raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}") - if mesh.cluster: - raise NotImplementedError if compiler_params and not isinstance(compiler_params, GPUCompilerParams): raise TypeError( "Compiler params must be a GPUCompilerParams, got" @@ -576,6 +804,7 @@ def _gpu_mesh_discharge_rule( interpret=interpret, cost_estimate=cost_estimate, name=name, + memory_space=GMEM, ) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6b06e6b7dfc2..726d89bfffc5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -17,11 +17,13 @@ from __future__ import annotations import collections -from collections.abc import Callable, Hashable, MutableMapping, MutableSequence, Sequence +from collections.abc import Callable, Hashable, Iterable, MutableMapping, MutableSequence, Sequence import contextlib import dataclasses import functools +import itertools import math +import operator from typing import Any, Protocol, cast import jax @@ -57,6 +59,7 @@ from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import profiler as mgpu_profiler from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import tcgen05 import jax.numpy as jnp import numpy as np @@ -82,23 +85,28 @@ def _align_to(x: int, alignment: int): return x -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class ResourceEstimatorContext: - thread_semantics: mgpu.ThreadSemantics + axis_names: _AxisNames + lowering_semantics: mgpu.LoweringSemantics @property def arrival_multiplier(self) -> int: return ( WARPGROUP_SIZE - if self.thread_semantics == mgpu.ThreadSemantics.Lane + if self.lowering_semantics == mgpu.LoweringSemantics.Lane else 1 ) +AnyBarrier = mgpu.Barrier | mgpu.ClusterBarrier + + @dataclasses.dataclass(kw_only=True, frozen=True) class Resources: smem_scratch_bytes: int = 0 - barrier_counts: collections.Counter[mgpu.Barrier] = dataclasses.field( + tmem_scratch_cols: int = 0 + barrier_counts: collections.Counter[AnyBarrier] = dataclasses.field( default_factory=collections.Counter ) @@ -108,9 +116,15 @@ def __post_init__(self): "smem_scratch_bytes", _align_to(self.smem_scratch_bytes, _SMEM_ALIGNMENT), ) + object.__setattr__( + self, + "tmem_scratch_cols", + # TMEM must be allocated in 128x8 chunks. + _align_to(self.tmem_scratch_cols, 8), + ) @property - def barriers(self) -> Sequence[mgpu.Barrier]: + def barriers(self) -> Sequence[AnyBarrier]: return list(self.barrier_counts.elements()) def __add__(self, other: Resources) -> Resources: @@ -120,6 +134,7 @@ def __add__(self, other: Resources) -> Resources: # we will allocate two barriers, even though one would be enough. 
return Resources( smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes, + tmem_scratch_cols=self.tmem_scratch_cols + other.tmem_scratch_cols, barrier_counts=self.barrier_counts + other.barrier_counts, ) @@ -128,6 +143,9 @@ def __or__(self, other: Resources) -> Resources: smem_scratch_bytes=max( self.smem_scratch_bytes, other.smem_scratch_bytes ), + tmem_scratch_cols=max( + self.tmem_scratch_cols, other.tmem_scratch_cols + ), barrier_counts=self.barrier_counts | other.barrier_counts, ) @@ -216,10 +234,36 @@ def _run_scoped_resource_estimator( ) ]) ) - else: + elif isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + rs += Resources( + barrier_counts=collections.Counter( + [mgpu.ClusterBarrier(collective_dims, *aval.shape)] + ) + ) + elif aval.memory_space == gpu_core.TMEM: + if aval.dtype.itemsize != 4: + raise ValueError("TMEM only supports 32-bit types.") + if len(aval.shape) != 2: + raise ValueError("TMEM allocations must be 2D.") + if aval.shape[0] % tcgen05.TMEM_ROWS != 0: + raise ValueError("TMEM shape[0] must be a multiple of 128.") + if aval.shape[1] % 8 != 0: + raise ValueError("TMEM shape[1] must be a multiple of 8.") + rs += Resources(tmem_scratch_cols=aval.shape[1]) + elif aval.memory_space == gpu_core.SMEM: rs += Resources( smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize ) + elif aval.memory_space == gpu_core.REGS: + # Don't need to allocate anything. + pass + else: + raise NotImplementedError( + f"Unsupported memory space: {aval.memory_space}") return rs + _estimate_resources(ctx, jaxpr) @@ -233,23 +277,57 @@ def _reduce_sum_resource_estimator( return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize) +@dataclasses.dataclass(frozen=True) +class _AxisNames: + grid: Sequence[Hashable] + cluster: Sequence[Hashable] = () + wg: Hashable | None = None + + def __iter__(self) -> Iterable[Hashable]: + return itertools.chain( + self.grid, self.cluster, [self.wg] if self.wg is not None else [] + ) + + +AnyBarrierRef = mgpu.BarrierRef | mgpu.CollectiveBarrierRef + + @dataclasses.dataclass class ModuleContext: name: str - grid_names: Sequence[Hashable] | None + axis_names: _AxisNames | None program_ids: Sequence[ir.Value] | None approx_math: bool - single_wg_lane_predicate: ir.Value + single_wg_lane_predicate: ir.Value | None + single_warp_lane_predicate: ir.Value | None smem_requested_bytes: int smem_used_bytes: int - runtime_barriers: MutableMapping[ - mgpu.Barrier, MutableSequence[mgpu.BarrierRef] - ] + tmem_requested_cols: int + tmem_used_cols: int + tmem_base_ptr: ir.Value + runtime_barriers: MutableMapping[AnyBarrier, MutableSequence[AnyBarrierRef]] name_stack: source_info_util.NameStack traceback_caches: mlir.TracebackCaches squashed_dims: tuple[int, ...] - thread_semantics: mgpu.ThreadSemantics + lowering_semantics: mgpu.LoweringSemantics + primitive_semantics: gpu_core.PrimitiveSemantics + warp_axis_name: str | None = None + + @property + def single_lane_predicate(self) -> ir.Value: + """Returns a predicate that is True for a single lane within the current + thread semantics. 
+ """ + assert self.lowering_semantics == mgpu.LoweringSemantics.Lane + match self.primitive_semantics: + case gpu_core.PrimitiveSemantics.Warpgroup: + return self.single_wg_lane_predicate + case gpu_core.PrimitiveSemantics.Warp: + return self.single_warp_lane_predicate + case _: + raise ValueError(f"Unknown semantics: {self.primitive_semantics}") + @contextlib.contextmanager def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: """Reserves a barrier. @@ -259,7 +337,30 @@ def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: available = self.runtime_barriers.get(barrier, []) if not available: raise RuntimeError(f"Barrier {barrier} is already reserved") - return available.pop() + barrier = available.pop() + yield barrier + available.append(barrier) + + @contextlib.contextmanager + def alloc_tmem( + self, + struct: jax.ShapeDtypeStruct, + layout: tcgen05.TMEMLayout | None = None + ) -> ir.Value: + if self.tmem_used_cols > 0: + raise NotImplementedError( + "Multiple TMEM allocations are not implemented.") + if layout is None: + layout = tcgen05._infer_tmem_layout(struct.shape, collective=False) + cols_used = np.prod(struct.shape) // tcgen05.TMEM_ROWS + self.tmem_used_cols += cols_used + off = self.tmem_base_ptr + tmem_ref = tcgen05.TMEMRef(address=off, + shape=struct.shape, + dtype=mgpu_utils.dtype_to_ir_type(struct.dtype), + layout=layout) + yield tmem_ref + self.tmem_used_cols -= cols_used # TODO(cperivol): Only return the shapes and figure out the sizes when freeing. @contextlib.contextmanager @@ -286,7 +387,7 @@ def scratch_view( smem = ir.Attribute.parse("#gpu.address_space") i8 = ir.IntegerType.get_signless(8) i32 = ir.IntegerType.get_signless(32) - if self.thread_semantics == mgpu.ThreadSemantics.Lane: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: smem_base = gpu_dialect.dynamic_shared_memory( ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem) ) @@ -302,7 +403,7 @@ def scratch_view( # The below code emission relies on the assumption that the first scratch # operand provided by Mosaic GPU always begins at the beginning of # dynamic SMEM. Mosaic GPU is expected to uphold that invariant. - if self.thread_semantics == mgpu.ThreadSemantics.Lane: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: view = memref_dialect.view( scratch_ty, smem_base, _as_index(off), [] ) @@ -333,7 +434,10 @@ class LoweringRuleContext: @property def estimator_ctx(self) -> ResourceEstimatorContext: - return ResourceEstimatorContext(thread_semantics=self.module_ctx.thread_semantics) + return ResourceEstimatorContext( + axis_names=self.module_ctx.axis_names, + lowering_semantics=self.module_ctx.lowering_semantics, + ) @dataclasses.dataclass(frozen=True) @@ -341,8 +445,9 @@ class LoweringResult: module: ir.Module grid: tuple[int, ...] block: tuple[int, ...] - out_structs: tuple[jax.ShapeDtypeStruct, ...] + new_out_shapes: tuple[jax.ShapeDtypeStruct, ...] # Does not include gmem scratch! profiler_context: ProfilerContext | None + gmem_scratch_shapes: tuple[jax.ShapeDtypeStruct, ...] @dataclasses.dataclass(frozen=True) @@ -387,7 +492,7 @@ def err_details(bm: pallas_core.BlockMapping) -> str: f" and index_map {bm.index_map_jaxpr.jaxpr} in" f" memory space {bm.transformed_block_aval.memory_space}." " See details at" - " https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec." + " https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec." 
) for bm in block_mappings: @@ -474,12 +579,10 @@ def lower_pipelined_jaxpr_to_module( block_mappings, [grid_mapping.num_inputs] ) - if mesh is not None: + if mesh: assert isinstance(mesh, gpu_core.GPUMesh) - if mesh and mesh.num_threads is not None: - # Last dim corresponds to the warpgroup count. - block = (128 * grid_mapping.grid[-1], 1, 1) - grid = grid_mapping.grid[:-1] + block = (128 * (mesh.num_threads or 1), 1, 1) + grid = mesh.grid else: block = (128, 1, 1) grid = grid_mapping.grid @@ -511,23 +614,44 @@ def ref_for_aval(aval: jax_core.AbstractValue): else: return gpu_core.SMEM(aval.shape, aval.dtype) + sem_placeholder = None + semaphore_ref_avals = [] + scratch_avals = [] + # Need to unzip semaphores + for v in jaxpr.invars[grid_mapping.slice_scratch_ops]: + aval = v.aval + if (isinstance(aval, pallas_core.AbstractMemoryRef) and + jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype)): + if aval.memory_space != gpu_core.GPUMemorySpace.GMEM: + raise ValueError( + "Only GMEM memory space is supported for semaphores in Mosaic GPU." + ) + semaphore_ref_avals.append(aval) + scratch_avals.append(sem_placeholder) + else: + scratch_avals.append(aval) + def pipeline_fn(*refs): - return primitives.run_scoped( - functools.partial(scoped_pipeline_fn, *refs), + sem_refs = [] + if semaphore_ref_avals: + refs, sem_refs = util.split_list(refs, [-len(semaphore_ref_avals)]) + primitives.run_scoped( + functools.partial(scoped_pipeline_fn, *refs, sem_refs=sem_refs), scratch_refs=[ - ref_for_aval(v.aval) - for v in jaxpr.invars[grid_mapping.slice_scratch_ops] + ref_for_aval(aval) if aval is not sem_placeholder else aval + for aval in scratch_avals ], ) - - def scoped_pipeline_fn(*refs, scratch_refs): - def body_fn(*refs): - grid_env = pallas_core.current_grid_env() - assert grid_env is not None # Set by ``emit_pipeline``. + return () # ``wrap_init`` does not support functions returning None. + + def scoped_pipeline_fn(*refs, sem_refs, scratch_refs): + sem_refs_it = iter(sem_refs) + scratch_refs = [ + next(sem_refs_it) if r is sem_placeholder else r for r in scratch_refs + ] + def body_fn(indices, *refs): program_ids_template = util.merge_lists( - which_parallel, - [grid_axis.index for grid_axis in grid_env], - [None] * sum(which_parallel), + which_parallel, indices, [None] * sum(which_parallel) ) assert len(refs) + len(scratch_refs) == len(jaxpr.invars) return gpu_primitives.jaxpr_call( @@ -551,28 +675,33 @@ def body_fn(*refs): with grid_mapping.trace_env(): new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init( - # ``wrap_init`` does not support functions returning None. 
- lambda *args: pipeline_fn(*args) or (), - debug_info=jaxpr.debug_info, - ), + lu.wrap_init(pipeline_fn, debug_info=jaxpr.debug_info), [ gpu_core.GMEM( bm.array_shape_dtype.shape, bm.array_shape_dtype.dtype ).get_ref_aval() for bm in block_mappings - ], + ] + semaphore_ref_avals, ) assert not new_consts + axis_names = ( + _AxisNames(mesh.grid_names, mesh.cluster_names, mesh.thread_name) + if mesh is not None + else _AxisNames(grid_mapping.grid_names or ()) + ) with grid_mapping.trace_env(): return lower_jaxpr_to_module( parallel_grid, - grid_mapping.grid_names, + axis_names, block, mesh.cluster if mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], [bm.array_shape_dtype for bm in out_block_mappings], + [ + jax.ShapeDtypeStruct(r.shape, np.dtype(np.int32)) + for r in semaphore_ref_avals + ], new_jaxpr, compiler_params, new_consts, @@ -581,11 +710,12 @@ def body_fn(*refs): def lower_jaxpr_to_module( grid: Sequence[int], - grid_names: Sequence[str], + axis_names: _AxisNames, block: Sequence[int], cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], out_shapes: Sequence[jax.ShapeDtypeStruct], + gmem_scratch_shapes: Sequence[jax.ShapeDtypeStruct], jaxpr: jax_core.Jaxpr, compiler_params: dict[str, Any], consts=(), @@ -593,10 +723,15 @@ def lower_jaxpr_to_module( debug_info = jaxpr.debug_info params = compiler_params.get("mosaic_gpu", {}) approx_math = params.get("approx_math", False) - thread_semantics = params.get( - "thread_semantics", mgpu_core.ThreadSemantics.Lane + lowering_semantics = params.get( + "lowering_semantics", mgpu_core.LoweringSemantics.Lane ) + if len(cluster) < 3: + cluster = cluster + (1,) * (3 - len(cluster)) + else: + assert len(cluster) == 3 + if len(grid) <= 3: squashed_dims = () parallel_grid = grid + (1,) * (3 - len(grid)) @@ -607,24 +742,43 @@ def lower_jaxpr_to_module( parallel_grid = (math.prod(grid[:-2]), *grid[-2:]) def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): - *buffers_gmem, (runtime_smem, runtime_barriers) = buffers + *buffers_gmem, (runtime_smem, runtime_barriers, runtime_tmem) = buffers grouped_barriers = collections.defaultdict(list) for barrier, barrier_ref in zip(rs.barriers, runtime_barriers): grouped_barriers[barrier].append(barrier_ref) + if runtime_tmem is not None: + tmem_cols = math.prod(runtime_tmem.shape) // tcgen05.TMEM_ROWS + else: + tmem_cols = 0 + + if lowering_semantics == mgpu.LoweringSemantics.Lane: + single_wg_lane_predicate = mgpu.single_thread_predicate( + scope=mgpu.ThreadSubset.WARPGROUP) + single_warp_lane_predicate = mgpu.single_thread_predicate( + scope=mgpu.ThreadSubset.WARP) + else: # Warpgroup semantics do not have a single lane predicate. 
+ single_wg_lane_predicate = None + single_warp_lane_predicate = None + module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), - grid_names, + axis_names, [_program_id(axis, squashed_dims) for axis in range(len(grid))], approx_math, - mgpu.single_thread_predicate(per_block=False), + single_wg_lane_predicate, + single_warp_lane_predicate, smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape), smem_used_bytes=0, + tmem_requested_cols=tmem_cols, + tmem_used_cols=0, + tmem_base_ptr=runtime_tmem.address if runtime_tmem else None, runtime_barriers=grouped_barriers, name_stack=source_info_util.NameStack(), traceback_caches=mlir.TracebackCaches(), squashed_dims=squashed_dims, - thread_semantics=thread_semantics, + lowering_semantics=lowering_semantics, + primitive_semantics=gpu_core.PrimitiveSemantics.Warpgroup, ) del runtime_smem, grouped_barriers, runtime_barriers @@ -632,59 +786,82 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): module_ctx, launch_ctx, jaxpr, buffers_gmem, consts ) - rs = _estimate_resources(ResourceEstimatorContext(thread_semantics), jaxpr) + rs = _estimate_resources( + ResourceEstimatorContext( + axis_names=axis_names, lowering_semantics=lowering_semantics + ), + jaxpr, + ) smem_scratch_bytes = params.get("smem_scratch_bytes") if smem_scratch_bytes is None: smem_scratch_bytes = rs.smem_scratch_bytes + tmem_scratch_cols = rs.tmem_scratch_cols + + scratch_buffers = [ + jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8), + rs.barriers, + ] + if tmem_scratch_cols > 0: + scratch_buffers.append( + mgpu.TMEM(shape=[tcgen05.TMEM_ROWS, tmem_scratch_cols], dtype=np.int32), + ) + else: + scratch_buffers.append(None) prof_ctx = prof_spec = None if prof_space := params.get("profile_space", 0): # Each range is 2 events, each event is 4 bytes. prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4) prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) - module, out_structs_gmem, _, launch_ctx, scratch_arr = ( + module, new_out_shapes, _, launch_ctx, scratch_arr = ( mgpu_core._lower_as_gpu_kernel( body, - grid=parallel_grid, + grid=tuple(map(operator.mul, parallel_grid, cluster)), cluster=cluster, block=block, in_shapes=in_shapes, - out_shape=out_shapes, - smem_scratch_shape=( - jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8), - rs.barriers, - ), + out_shape=(*out_shapes, *gmem_scratch_shapes), + smem_scratch_shape=scratch_buffers, module_name=mlir.sanitize_name(debug_info.func_name), prof_spec=prof_spec, ) ) - if thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if lowering_semantics == mgpu.LoweringSemantics.Warpgroup: # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc mgpu.infer_layout(module) # pytype: disable=attribute-error + mgpu.infer_transforms(module) # pytype: disable=attribute-error mgpu.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error mgpu_core._initialize_scratch(launch_ctx, scratch_arr) + if gmem_scratch_shapes: + new_out_shapes = new_out_shapes[:-len(gmem_scratch_shapes)] + return LoweringResult( - module, parallel_grid, block, out_structs_gmem, prof_ctx + module, parallel_grid, block, new_out_shapes, prof_ctx, tuple(gmem_scratch_shapes) ) mosaic_lowering_rules = { # Lowering rules when using Mosaic GPU lane semantics. 
- mgpu.ThreadSemantics.Lane: {} , + (mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup): {} , + gpu_core.LANExWARP_SEMANTICS: {} , # Lowering rules when using Mosaic GPU warpgroup semantics. - mgpu.ThreadSemantics.Warpgroup: {}, + (mgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup): {}, } def register_lowering_rule( - primitive: jax_core.Primitive, thread_semantics: mgpu.ThreadSemantics + primitive: jax_core.Primitive, + lowering_semantics: mgpu.LoweringSemantics, + primitive_semantics: gpu_core.PrimitiveSemantics = gpu_core.PrimitiveSemantics.Warpgroup, ): def deco(fn): - mosaic_lowering_rules[thread_semantics][primitive] = fn + mosaic_lowering_rules[ + (lowering_semantics, primitive_semantics)][primitive] = fn return fn return deco @@ -720,7 +897,7 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): # TODO(apaszke): Handle other avals (refs, etc.). if isinstance(aval := var.aval, jax_core.ShapedArray): # TODO(apaszke): Clarify the type invariants for lane semantics? - if module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: # Shaped arrays must be vectors if and only if their shape is non-empty. # Those with empty shapes should be represented by their scalar type. mlir_dtype = mgpu_utils.dtype_to_ir_type(aval.dtype) @@ -745,8 +922,11 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): if val.type != mlir_dtype: raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}") - foreach(write_env, jaxpr.constvars, consts) - foreach(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args) + foreach( + functools.partial(write_env, require_value=False), jaxpr.constvars, consts + ) + foreach(functools.partial(write_env, require_value=False), jaxpr.invars, args) + # TODO(justinfu): Handle transform scopes. last_local_name_stack: list[str] = [] named_regions = [] @@ -757,10 +937,13 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): ) loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: - if eqn.primitive not in mosaic_lowering_rules[module_ctx.thread_semantics]: + if eqn.primitive not in mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics)]: raise NotImplementedError( "Unimplemented primitive in Pallas Mosaic GPU lowering: " - f"{eqn.primitive.name}. " + f"{eqn.primitive.name} for lowering semantics " + f"{module_ctx.lowering_semantics} and user thread semantics " + f"{module_ctx.primitive_semantics}. " "Please file an issue on https://github.com/jax-ml/jax/issues." 
) new_local_name_stack = [scope.name for scope in eqn.source_info.name_stack.stack] @@ -772,7 +955,9 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): wrapper_stack = contextlib.ExitStack() wrapper_stack.enter_context(launch_ctx.named_region(name)) named_regions.append(wrapper_stack) - rule = mosaic_lowering_rules[module_ctx.thread_semantics][eqn.primitive] + rule = mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics) + ][eqn.primitive] rule_ctx = LoweringRuleContext( module_ctx, launch_ctx, @@ -801,8 +986,9 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): return map(read_env, jaxpr.outvars) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.program_id_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.program_id_p, mgpu.LoweringSemantics.Warpgroup) def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): if ctx.module_ctx.program_ids is None: raise NotImplementedError("pl.program_id() is not supported in this context") @@ -869,8 +1055,9 @@ def lowering_rule(ctx: LoweringRuleContext, *args, **params): return lowering_rule -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.num_programs_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.num_programs_p, mgpu.LoweringSemantics.Warpgroup) def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): del ctx # Unused. return arith_dialect.index_cast( @@ -878,61 +1065,50 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): gpu_dialect.block_dim(gpu_dialect.Dimension(axis)), ) - -def _handle_reshaping( - ref: ir.Value, transforms: Sequence[gpu_core.Transform] -) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - is_trivial_indexer = lambda t: isinstance( - t, indexing.NDIndexer - ) and gpu_core.is_trivial_index(t.indices, t.shape) - - last_reshaper_idx = next( - reversed([i for i, t in enumerate(transforms) if isinstance(t, RefReshaper)]), - None, - ) - if last_reshaper_idx is None: - return ref, transforms - # Check that before the reshape are only trivial indexes and or - # other reshapes. - # TODO(cperivol): Reshapes should bubble up rather than being - # expected to effectively be the first ref transform. - if not all(isinstance(t, RefReshaper) or is_trivial_indexer(t) for t in transforms[:last_reshaper_idx]): - raise NotImplementedError( - "Reshapes do not compose with other transforms and indexers must be" - f" trivial (transforms: {transforms})" - ) - reshaper = cast(RefReshaper, transforms[last_reshaper_idx]) - # Skip all the reshapes and trivial indexes. 
- return mgpu.memref_reshape(ref, reshaper.shape), transforms[last_reshaper_idx + 1:] - - -def _handle_indexing( - ref: ir.Value, transforms: Sequence[gpu_core.Transform] +def _handle_transforms( + ref: ir.Value, + transforms: Sequence[gpu_core.Transform], + *, + handle_transposes=True, + handle_reshapes=True, ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - if not transforms: - pass - indexer_idxs = [ - i for i, t in enumerate(transforms) if isinstance(t, indexing.NDIndexer) - ] - if not indexer_idxs: - return ref, transforms - sliced_ref = ref + transformed_ref = ref + mlir_dtype = ir.MemRefType(ref.type).element_type new_transforms = [] - for t in transforms: - if not isinstance(t, indexing.NDIndexer): - new_transforms.append(t) - continue - indexer = cast(indexing.NDIndexer, t) - if indexer.int_indexer_shape: - raise NotImplementedError("int_indexer_shape non-empty") - indices = _ndindexer_indices(indexer) + def _bubble_up(untransform_fn, data): + nonlocal new_transforms new_transforms_rev = [] for t in reversed(new_transforms): - indices, new_t = t.untransform_index(indices) + data, new_t = untransform_fn(t, data) new_transforms_rev.append(new_t) - sliced_ref = mgpu.memref_slice(sliced_ref, indices) + new_transforms = list(reversed(new_transforms_rev)) - return sliced_ref, new_transforms + return data + + for t in transforms: + match t: + case indexing.NDIndexer(): + indexer = cast(indexing.NDIndexer, t) + if indexer.int_indexer_shape: + raise NotImplementedError("int_indexer_shape non-empty") + indices = _ndindexer_indices(indexer) + indices = _bubble_up( + lambda t, idxs: t.untransform_index(mlir_dtype, idxs), indices + ) + transformed_ref = mgpu.memref_slice(transformed_ref, indices) + case gpu_core.TransposeRef(perm) if handle_transposes: + perm = _bubble_up(lambda t, p: t.untransform_transpose(p), + perm) + transformed_ref = mgpu.memref_transpose(transformed_ref, perm) + case RefReshaper(dtype=dtype, shape=shape) if handle_reshapes: + shape = _bubble_up( + lambda t, p: t.untransform_reshape(dtype, p), # pylint: disable=cell-var-from-loop + shape) + transformed_ref = mgpu.memref_reshape(transformed_ref, shape) + case _: + new_transforms.append(t) + + return transformed_ref, new_transforms def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]: @@ -954,20 +1130,31 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... 
return tuple(indices) -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane) -def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): - if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only load from references (got {x_smem}).") +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Lane) +def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): + if isinstance(x_ref, tcgen05.TMEMRef): + transforms = jax.tree.unflatten(tree, leaves) + if len(transforms) != 1 or not isinstance( + transforms[0], indexing.NDIndexer): + raise NotImplementedError( + "Only a single indexing transform is supported for TMEM refs.") + indexer = cast(indexing.NDIndexer, transforms[0]) + if not gpu_core.is_trivial_index(indexer.indices, x_ref.shape): + raise NotImplementedError( + "Only trivial indexing is supported for TMEM refs.") + return x_ref[:] + + if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): + raise TypeError(f"Can only load from references (got {x_ref}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms(x_ref, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): + if tiling != (8, (swizzle * 8) // pallas_utils.dtype_bitwidth(x_aval.dtype)): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle @@ -986,7 +1173,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Warpgroup) def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only load from references (got {x_smem}).") @@ -994,8 +1181,7 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms(x_smem, transforms) if transforms: raise NotImplementedError( @@ -1012,7 +1198,7 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): return memref_dialect.load(x_smem, []) -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Lane) def _swap_lowering_rule( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): @@ -1022,28 +1208,62 @@ def _swap_lowering_rule( raise TypeError(f"Can only store to references (got {x_smem}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + transposed_value = value.layout == mgpu.WGMMA_TRANSPOSED_LAYOUT + x_smem, transforms = _handle_transforms( + x_smem, transforms, handle_transposes=not transposed_value + ) match transforms: - case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, 
swizzle // x_aval.dtype.itemsize): + case ( + gpu_core.UnswizzleRef(swizzle), + gpu_core.UntileRef(tiling), + *maybe_transpose, + ): + if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") + + if transposed_value != bool(maybe_transpose): + raise ValueError( + "Either both the ref and the value are transposed or neither is." + ) + + if maybe_transpose: + if maybe_transpose != [gpu_core.TransposeRef((1, 0))]: + raise NotImplementedError( + f"Unsupported transforms: {transforms} ({maybe_transpose})" + ) + + x_smem = mgpu.memref_transpose(x_smem, (1, 0, 3, 2)) + old_value = mgpu.FragmentedArray.load_tiled( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle + x_smem, + is_signed=mgpu_utils.is_signed(x_aval.dtype), + swizzle=swizzle, + layout=value.layout, ) value.store_tiled(x_smem, swizzle=swizzle) return old_value case (): - old_value = mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) - value.store_untiled(x_smem) - return old_value + match value.layout: + case mgpu.TiledLayout(): + old_value = mgpu.FragmentedArray.load_untiled( + x_smem, + layout=value.layout, + is_signed=mgpu_utils.is_signed(x_aval.dtype), + optimized=False, + ) + value.store_untiled(x_smem, optimized=False) + return old_value + case _: + old_value = mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + value.store_untiled(x_smem) + return old_value case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Warpgroup) def _swap_lowering_rule_wg( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): @@ -1055,9 +1275,7 @@ def _swap_lowering_rule_wg( x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) - + x_smem, transforms = _handle_transforms(x_smem, transforms) if transforms: raise NotImplementedError( "Transforms are not yet implemented for warpgroup semantics" @@ -1076,8 +1294,8 @@ def _swap_lowering_rule_wg( return old_value -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(pjit.pjit_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(pjit.pjit_p, mgpu.LoweringSemantics.Warpgroup) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): if jaxpr.consts: raise NotImplementedError @@ -1085,11 +1303,8 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args, ) -@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane) -def _mesh_cast_lowering_rule(ctx, x, dst_sharding): - return x -@register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.slice_p, mgpu.LoweringSemantics.Lane) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -1099,8 +1314,8 @@ def _slice_lowering_rule( return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))] -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane) 
+@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Warpgroup) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): if len(cases) != 2: raise NotImplementedError( @@ -1109,7 +1324,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): ) pred_aval, *cases_avals = ctx.avals_in [out_aval] = ctx.avals_out - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: pred = _ensure_fa(pred, pred_aval.dtype) cases = _bcast(*cases, *cases_avals, out_aval) # ``select`` expects the first case to be the true branch, but ``select_n`` @@ -1127,7 +1342,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): return arith_dialect.select(pred, *reversed(cases)) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Lane) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, x: mgpu.FragmentedArray, @@ -1146,12 +1361,19 @@ def _broadcast_in_dim_lowering_rule( and x.layout == mgpu.WGMMA_ROW_LAYOUT ): return x.broadcast_minor(y_aval.shape[-1]) + if ( + broadcast_dimensions == (1,) + and y_aval.ndim == x_aval.ndim + 1 + and x.layout == mgpu.WGMMA_COL_LAYOUT + ): + return x.broadcast_major(y_aval.shape[-2]) if broadcast_dimensions: raise NotImplementedError return x.broadcast(shape) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Warpgroup) def _broadcast_in_dim_lowering_rule_wg( ctx: LoweringRuleContext, x: ir.Value, @@ -1171,18 +1393,25 @@ def _broadcast_in_dim_lowering_rule_wg( ) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.convert_element_type_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.convert_element_type_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): del weak_type, sharding [x_aval] = ctx.avals_in + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if x_aval.shape != (): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") return _ensure_fa(x, x_aval.dtype).astype( mgpu_utils.dtype_to_ir_type(new_dtype), is_signed=mgpu_utils.is_signed(new_dtype) ) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.convert_element_type_p, mgpu.LoweringSemantics.Warpgroup) def _convert_element_type_lowering_rule_wg( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -1274,25 +1503,29 @@ def convert(ty, x): return convert(ty, x) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS].update({ lax.neg_p: lambda ctx, x: -x, lax.not_p: lambda ctx, x: ~x, }) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS].update({ lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), lax.not_p: _lower_fun( - lambda x: jnp.bitwise_xor(x, -1), multiple_results=False + lambda x: jnp.astype(jnp.bitwise_xor(jnp.astype(x, int), -1), jnp.dtype(x)), multiple_results=False, ), }) def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): + if ctx.module_ctx.primitive_semantics == 
gpu_core.PrimitiveSemantics.Warp: + if not all(aval_in.shape == () for aval_in in ctx.avals_in): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return impl(x, y) - -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +for semantics in [gpu_core.LANExWG_SEMANTICS, gpu_core.LANExWARP_SEMANTICS]: + mosaic_lowering_rules[semantics].update({ lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), @@ -1308,8 +1541,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): lax.ne_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x != y), lax.max_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x.max(y)), lax.min_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x.min(y)), -}) - + }) def _binary_op_lowering_rule_wg( ctx: LoweringRuleContext, x, y, *, ui_impl, si_impl, f_impl=None @@ -1353,7 +1585,7 @@ def _binary_op_lowering_rule_wg( arith_dialect.minimumf, ), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_op_lowering_rule_wg, si_impl=si_impl, ui_impl=ui_impl, @@ -1372,7 +1604,7 @@ def _binary_boolean_op_lowering_rule_wg( (lax.or_p, arith_dialect.ori), (lax.xor_p, arith_dialect.xori), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_boolean_op_lowering_rule_wg, impl=impl, ) @@ -1405,7 +1637,7 @@ def _comparison_lowering_rule_wg( (lax.gt_p, CmpIPred.sgt, CmpIPred.ugt, CmpFPred.OGT), (lax.ge_p, CmpIPred.sge, CmpIPred.uge, CmpFPred.OGE), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _comparison_lowering_rule_wg, si_pred=si_pred, ui_pred=ui_pred, @@ -1413,7 +1645,7 @@ def _comparison_lowering_rule_wg( ) -@register_lowering_rule(lax.div_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.div_p, mgpu.LoweringSemantics.Lane) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) if ir.FloatType.isinstance(x.mlir_dtype): @@ -1421,19 +1653,19 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): return x // y -@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Warpgroup) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): if y != 2: raise NotImplementedError return _square_lowering_rule(ctx, x) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Warpgroup) def _square_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: x = _ensure_fa(x, x_aval.dtype) return x * x if jnp.issubdtype(x_aval.dtype, jnp.integer): @@ -1443,11 +1675,13 @@ def _square_lowering_rule(ctx: 
LoweringRuleContext, x): raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Warpgroup) +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1457,11 +1691,13 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): ) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Warpgroup) +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1469,23 +1705,27 @@ def _tanh_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -def _logistic(x): +def _logistic(x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return 1.0 / (1 + lax.exp(-x)) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane][lax.logistic_p] = _lower_fun( +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS][lax.logistic_p] = _lower_fun( _logistic, multiple_results=False ) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = ( +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][lax.logistic_p] = ( _lower_fun(_logistic, multiple_results=False) ) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Warpgroup) +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1493,10 +1733,13 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.exp(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): 
+@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Warpgroup) +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1504,11 +1747,13 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) -def _log_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup) +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1516,7 +1761,7 @@ def _log_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Lane) def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: @@ -1536,7 +1781,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError(f"Unsupported layout {x.layout}") -@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Lane) def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: @@ -1577,7 +1822,7 @@ def _reduce_lowering_rule_wg( return vector_dialect.MultiDimReductionOp(kind, x, acc, axes) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Warpgroup) def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): op = _reduce_lowering_rule_wg( vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes @@ -1588,7 +1833,7 @@ def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): return op.result -@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Warpgroup) def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in if jnp.issubdtype(x_aval.dtype, jnp.floating): @@ -1605,52 +1850,101 @@ def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result -@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) +def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value: + result = gpu_dialect.block_id(dim) + 
cluster_size = ctx.launch_ctx.cluster_size + if math.prod(cluster_size) == 1 or cluster_size[dim.value] == 1: + return result + # We scale the grid in the presence of clusters, so we need to scale the + # block ID back here. + return arith_dialect.divui(result, _as_index(cluster_size[dim.value])) + + +def _resolve_cluster_axis(axis_names: _AxisNames | None, axis_name: str): + if not axis_names: + raise LookupError( + "No axis names are available. Make sure you are using `pl.core_map`" + " with a `plgpu.GPUMesh`." + ) + if not axis_names or axis_name not in axis_names.cluster: + raise LookupError( + f"Unknown cluster axis {axis_name}, available axes:" + f" {[*axis_names.cluster]}" + ) + return gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + + +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Warpgroup) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): - i32 = ir.IntegerType.get_signless(32) - grid_names = ctx.module_ctx.grid_names + axis_names = ctx.module_ctx.axis_names + if not axis_names: + raise LookupError( + "No axis names are available. Make sure you are using `pl.core_map`" + " with a `plgpu.GPUMesh`." + ) + if axis_name not in axis_names: + raise LookupError( + f"Unknown axis {axis_name}, available axes: {[*axis_names]}" + ) + + if axis_names.wg is not None and axis_name == axis_names.wg: + return mgpu.warpgroup_idx(sync=True) + + if axis_name in axis_names.cluster: + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.cluster_block_id( + gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + ), + ) + squashed_dims = ctx.module_ctx.squashed_dims if squashed_dims: - unsquashed_names = grid_names[-3:] - squashed_names = grid_names[:-3] + unsquashed_names = axis_names.grid[-2:] + squashed_names = axis_names.grid[:-2] else: # These are unused but initialized for type checkers. - unsquashed_names = () - squashed_names = () - if grid_names and axis_name in grid_names: - if axis_name == grid_names[-1]: - return mgpu.warpgroup_idx(sync=True) + unsquashed_names = squashed_names = () + + if squashed_dims: + if axis_name in unsquashed_names: + # We add 1 to the index because the first dimension is the + # squashed dimension. + # e.g. for the grid (a, b, c, d, wg) + # squashed = (a, b) Mapped to Dimension.x (0) + # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) + idx = unsquashed_names.index(axis_name) + 1 + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + _block_id(ctx, gpu_dialect.Dimension(idx)), + ) else: - if squashed_dims: - if axis_name in unsquashed_names: - # We add 1 to the index because the first dimension is the - # squashed dimension. - # e.g. for the grid (a, b, c, d, wg) - # squashed = (a, b) Mapped to Dimension.x (0) - # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) - idx = unsquashed_names.index(axis_name) + 1 - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) - elif axis_name in squashed_names: - # All squashed dimensions are mapped to Dimension.x. 
- block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) - axis = squashed_names.index(axis_name) - return _unravel_program_id(block_id, axis, squashed_dims) - else: - if axis_name in grid_names: - idx = grid_names.index(axis_name) - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) + assert axis_name in squashed_names + # All squashed dimensions are mapped to Dimension.x. + axis = squashed_names.index(axis_name) + return _unravel_program_id( + _block_id(ctx, gpu_dialect.Dimension.x), axis, squashed_dims + ) + else: + assert axis_name in axis_names.grid + idx = axis_names.grid.index(axis_name) + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + _block_id(ctx, gpu_dialect.Dimension(idx)), + ) + +@register_lowering_rule(lax.axis_index_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) +def _axis_index_warp_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): + if axis_name == ctx.module_ctx.warp_axis_name: + return mgpu.warp_idx(sync=True) raise ValueError( - "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" + "Named axes can only refer to the warp axis name inside of core_map." ) -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane) def _debug_print_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1678,8 +1972,8 @@ def _debug_print_lowering_rule( return () -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup) -def _debug_print_lowering_rule( +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Warpgroup) +def _debug_print_lowering_rule_wg( ctx: LoweringRuleContext, *args, fmt, @@ -1692,79 +1986,103 @@ def _debug_print_lowering_rule( return () -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Warpgroup) def _run_scoped_lowering_rule( ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr ): input_refs = [] should_discharge = [] - alloc_stack = contextlib.ExitStack() - for v in jaxpr.invars: - aval = v.aval - if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: - # TODO(bchetioui): Fix this and remove the NotImplementedError. - raise NotImplementedError( - "WGMMA accumulators are not supported with Warpgroup semantics." 
+ with contextlib.ExitStack() as alloc_stack: + for v in jaxpr.invars: + aval = v.aval + if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): + dtype = mlir.dtype_to_ir_type(aval.dtype) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) + else: + zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + acc = vector_dialect.splat(ir.VectorType.get(aval.shape, dtype), zero) + acc = mgpu.dialect.optimization_barrier([acc]) + nvvm_dialect.wgmma_fence_aligned() + input_refs.append(acc) + should_discharge.append(True) + elif isinstance(aval.dtype, gpu_core.BarrierType): + barrier_ref = alloc_stack.enter_context( + ctx.module_ctx.reserve_barrier( + mgpu.Barrier( + aval.dtype.num_arrivals + * ctx.estimator_ctx.arrival_multiplier, + *aval.shape, + ) + ) ) - mlir_dtype = mlir.dtype_to_ir_type(aval.dtype) - input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype)) - should_discharge.append(True) - elif isinstance(aval.dtype, gpu_core.BarrierType): - input_refs.append( - ctx.module_ctx.reserve_barrier( - mgpu.Barrier( - aval.dtype.num_arrivals - * ctx.estimator_ctx.arrival_multiplier, - *aval.shape, - ) - ) - ) - should_discharge.append(False) - elif aval.memory_space == gpu_core.SMEM: - [input_ref] = alloc_stack.enter_context( - ctx.module_ctx.scratch_view( - [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] - ) + input_refs.append(barrier_ref) + should_discharge.append(False) + elif isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.module_ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + barrier_ref = alloc_stack.enter_context( + ctx.module_ctx.reserve_barrier( + mgpu.ClusterBarrier(collective_dims, *aval.shape) + ) + ) + input_refs.append(barrier_ref) + should_discharge.append(False) + elif aval.memory_space == gpu_core.SMEM: + [input_ref] = alloc_stack.enter_context( + ctx.module_ctx.scratch_view( + [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) + elif aval.memory_space == gpu_core.TMEM: + input_ref = alloc_stack.enter_context( + ctx.module_ctx.alloc_tmem( + jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) + else: + raise ValueError(f"Can't convert to ref: {aval}") + + if any(should_discharge): + # We convert consts to args, because we only have ir.Values and + # not JAX values during lowering. discharge_state() produces JAX + # values for the arguments but expects them to be provided for the + # consts. We also don't want to wrap the values in refs. + no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) + should_discharge = [False] * len(consts) + should_discharge + discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) + new_input_vals = consts + tuple(input_refs) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + discharged_jaxpr, + new_input_vals, + (), ) - input_refs.append(input_ref) - should_discharge.append(False) + # Discharge appends to the output the refs that got discharged. + outs = outs[:-sum(should_discharge)] else: - raise ValueError(f"Can't convert to ref: {aval}") - - if any(should_discharge): - # We convert consts to args, because we only have ir.Values and - # not JAX values during lowering. 
discharge_state() produces JAX - # valiues for the aguments but expects them to be provided for the - # consts. We also don't want to wrap the values in refs. - no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) - should_discharge = [False] * len(consts) + should_discharge - discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) - new_input_vals = consts + tuple(input_refs) - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - discharged_jaxpr, - new_input_vals, - (), - ) - # Discharge appends to the output the refs that got discharged. - outs = outs[:-sum(should_discharge)] - else: - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - jaxpr, - input_refs, - consts, - ) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + jaxpr, + input_refs, + consts, + ) assert len(outs) == len(jaxpr.outvars), (jaxpr, outs) return outs -@register_lowering_rule(discharge.run_state_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Warpgroup) def _run_state_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1782,7 +2100,12 @@ def _run_state_lowering_rule( for arg, v, out_aval in zip(args, jaxpr.invars, ctx.avals_out): aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: + arg = mgpu.dialect.optimization_barrier([arg]) + nvvm_dialect.wgmma_fence_aligned() + new_input_vals.append(arg) + else: + new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) should_discharge.append(True) assert isinstance(out_aval, jax_core.ShapedArray) else: @@ -1817,12 +2140,12 @@ def _lower_jaxpr_to_for_loop( ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, start: ir.Value, - length: ir.Value, + length: ir.Value | int, consts, *args, has_loop_index: bool, + unroll: bool = False, ): - _consts_avals, arg_avals = util.split_list(ctx.avals_in, [len(consts)]) arg_avals = arg_avals[has_loop_index:] out_avals = [] @@ -1836,12 +2159,11 @@ def as_values(vals, avals): _ensure = ( _ensure_fa - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane else _ensure_ir_value ) return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)] - @mgpu.fori(length, as_values(args, arg_avals)) def loop(loop_index, body_args): if has_loop_index: loop_index = arith_dialect.addi(loop_index, start) @@ -1853,11 +2175,20 @@ def loop(loop_index, body_args): ) return as_values(outs, out_avals) - return loop.results + if unroll: + assert isinstance(length, int) + outs = as_values(args, arg_avals) + for i in range(length): + outs = loop(_ir_constant(i, start.type), outs) + return outs + else: + if not isinstance(length, ir.Value): + length = _ir_constant(length, start.type) + return mgpu.fori(length, as_values(args, arg_avals))(loop).results -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Warpgroup) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1874,10 +2205,10 @@ def _scan_lowering_rule( if ( (num_extensive := len(args) - num_consts - num_carry) or 
reverse - or unroll != 1 + or not (unroll == 1 or unroll == length) ): raise NotImplementedError - del linear, num_extensive, reverse, unroll + del linear, num_extensive, reverse jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts if jaxpr_consts: @@ -1893,17 +2224,24 @@ def _scan_lowering_rule( start, *args = args index_aval, *_ = arg_avals start: ir.Value = _ensure_ir_value(start, index_aval.dtype) - length = _ir_constant(length, start.type) else: start = _i32_constant(0) - length = _i32_constant(length) + for_out = _lower_jaxpr_to_for_loop( - ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index + ctx, + jaxpr, + start, + length, + consts, + *args, + has_loop_index=has_loop_index, + unroll=unroll == length, ) if has_loop_index: # Need to return the final loop index value if the outer scan expects # it as an output. - return [length, *for_out] + loop_index = arith_dialect.addi(start, _ir_constant(length, start.type)) + return [loop_index, *for_out] return for_out @@ -1945,8 +2283,8 @@ def _lower_while_via_fori( return ub, ub, *for_out -@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Warpgroup) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1970,7 +2308,7 @@ def _while_lowering_rule( _is_acc = lambda x: isinstance(x, mgpu.WGMMAAccumulator) _ensure = _ensure_ir_value - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: _ensure = lambda v, aval: v if _is_acc(v) else _ensure_fa(v, aval.dtype) # If we fail conversion to fori, fallback to an ordinary while loop. @@ -2004,26 +2342,29 @@ def _while_lowering_rule( ctx.module_ctx, ctx.launch_ctx, body_jaxpr.jaxpr, body_args ) loop_out = [*map(_ensure, loop_out, carry_avals)] - for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): - if _is_acc(carry_fa) != _is_acc(out_fa): - raise ValueError( - f"The loop body output has unexpected accumulator type: output[{idx}]" - f" is {out_fa}, when it should be {carry_fa}." - ) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): + if _is_acc(carry_fa) != _is_acc(out_fa): + raise ValueError( + f"The loop body output has unexpected accumulator type:" + f" output[{idx}] is {out_fa}, when it should be {carry_fa}." + ) - if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: - raise ValueError( - f"The loop body output has unexpected layout: output[{idx}] has" - f" layout {out_fa.layout}, when it should be {carry_fa.layout}." - ) + if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: + raise ValueError( + f"The loop body output has unexpected layout: output[{idx}] has" + f" layout {out_fa.layout}, when it should be {carry_fa.layout}." 
+ ) scf_dialect.yield_( carry_treedef.flatten_up_to(loop_out) if loop_out else [] ) return carry_treedef.unflatten(list(while_op.results)) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.cond_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Warpgroup) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in @@ -2036,16 +2377,16 @@ def _yielded_values(outs, avals): ret.append(_ensure_ir_value(out, aval.dtype)) return ret - # We need the branch return mlir types in order to construct the - # switch operation. To avoid leaking information about what kind of - # mlir types are internal to FragmentedArrays and other mgpu types, - # we run one of the branches in a dummy module that we throw away to - # extract the return types + # We need to know the result types ahead of time to construct the switch + # operation. Below we lower the first branch in a throw-away module to + # extract them. with ir.InsertionPoint(ir.Module.create().body): outs = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args ) - yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))] + yielded_types = [ + v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out)) + ] del outs switch_op = scf_dialect.IndexSwitchOp( @@ -2080,9 +2421,9 @@ def _yielded_values(outs, avals): return treedef.unflatten(list(switch_op.results)) -@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule( - lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup + lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Warpgroup ) def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype @@ -2098,7 +2439,7 @@ def _bitcast_convert_type_lowering_rule( " have different widths" ) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: x = _ensure_ir_value(x, x_aval.dtype) return arith_dialect.bitcast( ir.VectorType.get(x_aval.shape, dst_elem_type), x @@ -2114,10 +2455,61 @@ def _bitcast_convert_type_lowering_rule( ) -@register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.optimization_barrier_p, mgpu.LoweringSemantics.Lane) def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): - args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) - return mgpu.optimization_barrier(*args) + result = mgpu.optimization_barrier( + *(_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) + ) + return (result,) if len(ctx.avals_in) == 1 else result + + +@register_lowering_rule( + lax.optimization_barrier_p, mgpu.LoweringSemantics.Warpgroup +) +def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): + result = mgpu.dialect.optimization_barrier([ + _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) + ]) + return (result,) if len(ctx.avals_in) == 1 else result + + +@register_lowering_rule(pallas_core.core_map_p, mgpu.LoweringSemantics.Lane) +def _core_map_lowering_rule( + ctx: LoweringRuleContext, + *args, + jaxpr, + 
mesh, + **_, +): + if isinstance(mesh, gpu_core.WarpMesh): + # A core_map over a WarpMesh represents a fork/join over individual + # warps in a warpgroup. + if (ctx.module_ctx.warp_axis_name or + ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp): + raise LoweringError( + "Cannot nest core_maps. Already under core_map with warp_axis_name " + f"{ctx.module_ctx.warp_axis_name}.") + module_ctx = dataclasses.replace( + ctx.module_ctx, + warp_axis_name=mesh.axis_name, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp, + ) + for aval_in in ctx.avals_in: + if isinstance(aval_in, jax_core.ShapedArray) and aval_in.shape: + raise LoweringError( + "Can only close over scalars and Refs when using core_map with " + f"WarpMesh. Found array of shape {aval_in}." + ) + _ = lower_jaxpr_to_mosaic_gpu( + module_ctx, + ctx.launch_ctx, + jaxpr, + args=(), + consts=args, + ) + mgpu.warpgroup_barrier() + return [] + raise ValueError(f"Unsupported mesh: {mesh}") def _bcast( diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index d506349fe101..eb15aff21235 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -23,11 +23,13 @@ import warnings import jax +from jax import lax from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering -import jax.experimental.mosaic.gpu.core as mosaic_core +from jax.experimental.mosaic import gpu as mgpu +import numpy as np def pallas_call_lowering( @@ -56,11 +58,10 @@ def pallas_call_lowering( print(f"The grid mapping for pallas_call {debug_info.func_src_info}:") print(grid_mapping) - thread_semantics = compiler_params.get("mosaic_gpu", {}).get( - "thread_semantics", mosaic_core.ThreadSemantics.Lane + lowering_semantics = compiler_params.get("mosaic_gpu", {}).get( + "lowering_semantics", mgpu.LoweringSemantics.Lane ) - if thread_semantics == mosaic_core.ThreadSemantics.Warpgroup: - mosaic_core.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error + mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error lowering_result = lowering.lower_pipelined_jaxpr_to_module( grid_mapping, @@ -74,16 +75,31 @@ def pallas_call_lowering( print(lowering_result.module.operation) module = lowering_result.module - new_avals_out = [ - jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs - ] - outs = mosaic_core._mosaic_gpu_lowering_rule( - ctx.replace(avals_out=new_avals_out), - *args, + new_avals_in = list(ctx.avals_in) + new_avals_out = list(map(_as_shaped_array, lowering_result.new_out_shapes)) + scratch_args = () + if lowering_result.gmem_scratch_shapes: + input_output_aliases += tuple( + (len(new_avals_in) + i, len(new_avals_out) + i) + for i in range(len(lowering_result.gmem_scratch_shapes)) + ) + new_avals_in.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) + new_avals_out.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) + def zero_init_gmem_scratch(): + return [lax.zeros_like_array(s) for s in lowering_result.gmem_scratch_shapes] + scratch_args = mlir.lower_fun( + zero_init_gmem_scratch, multiple_results=True + )(ctx.replace(avals_in=())) + outs = mgpu.core._mosaic_gpu_lowering_rule( + ctx.replace(avals_in=new_avals_in, avals_out=new_avals_out), + *args, *scratch_args, module=module, - 
out_types=lowering_result.out_structs, + out_types=(*lowering_result.new_out_shapes, *lowering_result.gmem_scratch_shapes), input_output_aliases=input_output_aliases, + use_custom_barrier=False, # False until we add get_barrier_semaphore() feature ) + if lowering_result.gmem_scratch_shapes: # Drop the GMEM scratch. + outs = outs[:-len(lowering_result.gmem_scratch_shapes)] if (prof_ctx := lowering_result.profiler_context) is not None: *outs, prof_buffer = outs if (dump_path := prof_ctx.dump_path) == "sponge": @@ -112,3 +128,7 @@ def do_callback(prof_buffer): ctx.replace(avals_in=(new_avals_out[-1],)), prof_buffer ) return outs + + +def _as_shaped_array(t: jax.ShapeDtypeStruct) -> jax_core.ShapedArray: + return jax_core.ShapedArray(t.shape, np.dtype(t.dtype)) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index a48fec61b7af..1e52f8701fcb 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -33,7 +33,6 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives -from jax._src.util import foreach from jax.experimental import pallas as pl import jax.numpy as jnp @@ -115,7 +114,7 @@ def _uses_arguments( def _is_index_invariant( - spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid + spec: pallas_core.BlockSpec, grid: pallas_core.TupleGrid ) -> bool: if (index_map := spec.index_map) is None: return True @@ -123,7 +122,7 @@ def _is_index_invariant( def _inc_grid_by_1( - indices: tuple[jax.Array, ...], grid: Sequence[int] + indices: tuple[jax.Array, ...], grid: pallas_core.TupleGrid ) -> tuple[jax.Array, ...]: next_indices = [] carry: bool | jax.Array = True @@ -162,7 +161,7 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore def emit_pipeline( body: Callable[..., None], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, in_specs: Sequence[pallas_core.BlockSpec] = (), out_specs: Sequence[pallas_core.BlockSpec] = (), max_concurrent_steps: int = 1, @@ -171,7 +170,8 @@ def emit_pipeline( """Creates a function to emit a manual pipeline within a Pallas kernel. Args: - body: The pipeline body. + body: The pipeline body, called with the indices for the current step, the + input refs, followed by the output refs. grid: The grid to use for the pipeline. in_specs: The block specs for the inputs. out_specs: The block specs for the outputs. @@ -182,19 +182,19 @@ def emit_pipeline( ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you don't await the WGMMA in the body. """ - num_steps = math.prod(grid) - if max_concurrent_steps <= delay_release: raise ValueError( "max_concurrent_steps must be greater than delay_release, but" f" {max_concurrent_steps=}, {delay_release=}" ) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_steps: + if not has_dynamic_grid and max_concurrent_steps > num_steps: max_concurrent_steps = num_steps - delay_release = 0 # No need to delay anything. 
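# Editor's illustrative sketch (not part of the patch): the docstring change
# above means the pipeline body now receives the grid indices as its first
# argument, before the input and output refs. A minimal sketch, assuming
# `emit_pipeline` as defined in this file, `pl.BlockSpec` from
# `jax.experimental.pallas` (already imported as `pl` here), and hypothetical
# `num_steps`, `x_gmem`, `o_gmem` available inside the surrounding kernel.
from jax.experimental import pallas as pl

def _example_body(indices, x_smem, o_smem):
  (step,) = indices  # One grid axis in this sketch; unused below.
  del step
  o_smem[...] = x_smem[...] + 1.0

example_pipeline = emit_pipeline(
    _example_body,
    grid=(num_steps,),  # Hypothetical number of pipeline steps.
    in_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
    out_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
    max_concurrent_steps=2,
)
example_pipeline(x_gmem, o_gmem)  # Runs the pipelined loop over the grid.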
def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) @@ -244,11 +244,14 @@ def scoped_pipeline( ) ] - for step, indices in enumerate( - it.islice(it.product(*map(range, grid)), max_concurrent_steps) - ): - indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices)) - foreach(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) + # Initialize the pipeline. + indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + fetch_indices = indices + for step in range(max_concurrent_steps): + for bref in in_brefs: + bref.copy_in(step, fetch_indices, barrier_ref) + fetch_indices = _inc_grid_by_1(fetch_indices, grid) + del fetch_indices # This is true if any of the outputs need to be transferred inside the loop. copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs) @@ -266,11 +269,13 @@ def loop_body(step, carry): max_concurrent_steps - (1 + delay_release), wait_read_only=True ) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body(*( - bref.get_ref_for_slot(slot) - for bref in it.chain(in_brefs, out_brefs) - )) + body( + indices, + *( + bref.get_ref_for_slot(slot) + for bref in it.chain(in_brefs, out_brefs) + ), + ) if copies_out_in_loop: gpu_primitives.commit_smem() @@ -324,7 +329,6 @@ def do_fetch(): # Invariant: ``indices`` and ``fetch_indices`` are always # ``max_concurrent_steps-delay_release`` apart. - indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) fetch_indices = indices for _ in range(max_concurrent_steps-delay_release): fetch_indices = _inc_grid_by_1(fetch_indices, grid) @@ -355,13 +359,14 @@ def do_fetch(): return pipeline + def emit_pipeline_warp_specialized( body: Callable[..., None], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, memory_registers: int, - in_specs: Sequence[gpu_core.GPUBlockSpec] = (), - out_specs: Sequence[gpu_core.GPUBlockSpec] = (), + in_specs: Sequence[pl.BlockSpec] = (), + out_specs: Sequence[pl.BlockSpec] = (), max_concurrent_steps: int = 2, wg_axis: str, num_compute_wgs: int, @@ -376,14 +381,16 @@ def emit_pipeline_warp_specialized( ``manual_consumed_barriers`` argument is True. ``` - def body(*input_refs, *output_refs, [consumed_barriers]) -> None: + def body(indices, *input_refs, *output_refs, [consumed_barriers]) -> None: ``` or with a carries enabled (enabled via the ``carry_coroutine`` argument), where the body returns the next carry: ``` - def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: + def body( + indices, *input_refs, *output_refs, [consumed_barriers], carry + ) -> Carry: ``` Args: @@ -423,12 +430,13 @@ def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: # Trace the index maps to determine if they depend on the grid. # Grid-independent values will not be multiple-buffered. 
in_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in in_specs] + not _is_index_invariant(spec, grid) for spec in in_specs] out_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in out_specs] + not _is_index_invariant(spec, grid) for spec in out_specs] spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis] - num_pipeline_steps = math.prod(grid) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) def _get_slot(step, has_seq_dim): """Returns the buffer slot given the pipeline step.""" @@ -439,8 +447,8 @@ def _get_slot(step, has_seq_dim): # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_pipeline_steps: - max_concurrent_steps = num_pipeline_steps + if not has_dynamic_grid and max_concurrent_steps > num_steps: + max_concurrent_steps = num_steps def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) @@ -458,7 +466,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): gpu_core.SMEM( (slots, *spec.block_shape), # type: ignore gmem_ref.dtype, - transforms=spec.transforms, + transforms=getattr(spec, "transforms", ()), ) ) in_smem_refs, out_smem_refs = util.split_list( @@ -510,13 +518,13 @@ def scoped_pipeline( consumed_barrier_refs, ): in_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( in_specs, in_spec_has_seq_axis, in_gmem_refs, in_smem_refs ) ] out_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( out_specs, out_spec_has_seq_axis, out_gmem_refs, out_smem_refs ) @@ -545,18 +553,17 @@ def compute_loop_body(step, carry): if copies_out_in_loop: gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body_refs = [] - for bref in it.chain(in_brefs, out_brefs): - buf_slot = _get_slot(slot, ~bref.is_index_invariant) - body_refs.append(bref.get_ref_for_slot(buf_slot)) + body_refs = [] + for bref in it.chain(in_brefs, out_brefs): + buf_slot = _get_slot(slot, not bref.is_index_invariant) + body_refs.append(bref.get_ref_for_slot(buf_slot)) - body_args = body_refs - if manual_consumed_barriers: - body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs] - if has_carry: - body_args += [prev_body_carry] - next_body_carry = body(*body_args) + body_args = body_refs + if manual_consumed_barriers: + body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs] + if has_carry: + body_args += [prev_body_carry] + next_body_carry = body(indices, *body_args) if not manual_consumed_barriers: [consumed_barrier_ref] = consumed_barrier_refs @@ -581,7 +588,7 @@ def compute_loop_body(step, carry): new_store_slices[idx], ) slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) - bref.copy_out(_get_slot(slot, ~bref.is_index_invariant), + bref.copy_out(_get_slot(slot, not bref.is_index_invariant), indices, predicate=slices_changed) gpu_primitives.commit_smem_to_gmem_group() @@ -607,7 +614,7 @@ def compute_loop_body(step, carry): carry_init = None init_loop_carry = (init_indices, last_store_slices, carry_init) 
last_indices, _, final_body_carry = lax.fori_loop(0, - num_pipeline_steps, + num_steps, compute_loop_body, init_loop_carry) if has_carry: @@ -621,10 +628,11 @@ def compute_loop_body(step, carry): # written in the main pipeline loop. if not copies_out_in_loop: gpu_primitives.commit_smem() - last_slot = lax.rem(num_pipeline_steps - 1, max_concurrent_steps) + last_slot = lax.rem(num_steps - 1, max_concurrent_steps) for bref in out_brefs: if bref.is_index_invariant: - bref.copy_out(last_slot, last_indices, predicate=None) + bref.copy_out(_get_slot(last_slot, has_seq_dim=False), + last_indices, predicate=None) gpu_primitives.commit_smem_to_gmem_group() @@ -635,13 +643,22 @@ def compute_loop_body(step, carry): def memory_block(): gpu_primitives.set_max_registers(memory_registers, action="decrease") indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + if has_dynamic_grid: + prologue_steps = lax.min(max_concurrent_steps, num_steps) + else: + assert max_concurrent_steps <= num_steps + prologue_steps = max_concurrent_steps # Begin initial copies. - for step in range(max_concurrent_steps): + def _init_step(step, indices): for bref, barrier in zip(in_brefs, in_smem_barrier_refs): - buf_slot = _get_slot(step, ~bref.is_index_invariant) + buf_slot = _get_slot(step, not bref.is_index_invariant) bref.copy_in(buf_slot, indices, barrier) - indices = _inc_grid_by_1(indices, grid) + return _inc_grid_by_1(indices, grid) + + indices = jax.lax.fori_loop( + 0, prologue_steps, _init_step, indices, unroll=not has_dynamic_grid + ) def memory_loop_body(step, carry): indices, = carry @@ -662,11 +679,19 @@ def memory_loop_body(step, carry): if manual_consumed_barriers: gpu_primitives.barrier_wait(consumed_barrier.at[slot]) # pytype: disable=attribute-error bref.copy_in( - _get_slot(fetch_slot, ~bref.is_index_invariant), indices, barrier) + _get_slot(fetch_slot, not bref.is_index_invariant), indices, barrier) next_indices = _inc_grid_by_1(indices, grid) return (next_indices,) - lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps, + lax.fori_loop(0, num_steps - max_concurrent_steps, memory_loop_body, (indices,)) + # Await all the arrivals to not leave barriers in a bad state. + # We only need to account for the prologue steps. + def _epi_step(step, _): + for barrier in consumed_barrier_refs: + gpu_primitives.barrier_wait(barrier.at[step]) + jax.lax.fori_loop( + 0, prologue_steps, _epi_step, None, unroll=not has_dynamic_grid + ) wg_idx = lax.axis_index(wg_axis) lax.cond( @@ -680,8 +705,16 @@ def _compute_registers( memory_registers: int, num_compute_wgs: int, ) -> int: - """Returns the number of registers to use for the compute thread.""" - # TODO(justinfu): Configure this per-platform. - n_registers = (512 - memory_registers) / num_compute_wgs + """Returns the max number of registers to use in compute threads. + + We start with the theoretical max registers per thread if one warpgroup + (128 threads) used the entire SM's 64k register file (64k / 128 = 512). + Then reserve `memory_registers` for the producer warpgroup and distribute + the remaining registers evenly among the compute warpgroups. + + Note: The maximum number of registers per thread is 255, so we clamp + the value. + """ + n_registers = min(256, (512 - memory_registers) / num_compute_wgs) # Round down to the nearest multiple of 8.
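  # Editor's note (illustrative arithmetic, not part of the patch): with the
  # hypothetical values memory_registers=40 and num_compute_wgs=2, the line
  # above gives n_registers = min(256, (512 - 40) / 2) = 236.0, and the
  # rounding below returns int((236.0 // 8) * 8) = 232 registers per thread.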
return int((n_registers // 8) * 8) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 7f26f5d2b6a3..070e37f64e2c 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -16,7 +16,7 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Sequence, Callable import dataclasses import enum import itertools @@ -25,12 +25,14 @@ import jax from jax._src import core as jax_core +from jax._src import pretty_printer as pp from jax._src import state from jax._src import tree_util from jax._src import util from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import llvm as llvm_dialect +from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -62,6 +64,111 @@ def _check_ref( ) +load_p = jax_core.Primitive("load") + +@load_p.def_effectful_abstract_eval +def _load_abstract_eval(src, *avals_flat, args_tree, layout, optimized): + del layout, optimized # Unused. + transforms = args_tree.unflatten(avals_flat) + return ( + jax_core.ShapedArray(transforms[-1].get_indexer_shape(), src.dtype), + {state.ReadEffect(0)}, + ) + +@lowering.register_lowering_rule(load_p, mgpu.LoweringSemantics.Lane) +def _load_p_lowering_rule( + ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout, optimized +): + if not isinstance(x_ref, ir.Value) or not ir.MemRefType.isinstance(x_ref.type): + raise TypeError(f"Can only load from references (got {x_ref}).") + + x_aval = ctx.avals_in[0] + + transforms = jax.tree.unflatten(args_tree, leaves) + x_ref, transforms = lowering._handle_transforms(x_ref, transforms) + + if layout is not None: + layout = layout.to_mgpu() + + is_signed = mgpu_utils.is_signed(x_aval.dtype) + match transforms: + case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): + if tiling != (8, swizzle // x_aval.dtype.itemsize): + raise NotImplementedError("Tiling does not fit swizzle") + return mgpu.FragmentedArray.load_tiled( + x_ref, + is_signed=is_signed, + swizzle=swizzle, + layout=layout, + ) + case (): + # Handle scalar indexing. + if not ctx.avals_out[0].shape: + is_signed = mgpu_utils.is_signed(x_aval.dtype) + val = memref_dialect.load(x_ref, []) + return mgpu.FragmentedArray.splat( + val, shape=(), layout=layout, is_signed=is_signed + ) + match layout: + case mgpu.WGMMA_ROW_LAYOUT | mgpu.WGMMA_COL_LAYOUT: + return mgpu.FragmentedArray.load_untiled( + x_ref, + is_signed=is_signed, + layout=layout, + swizzle=16, + optimized=optimized, + ) + case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): + ref_ty = ir.MemRefType(x_ref.type) + if shape != tuple(ref_ty.shape): + raise ValueError( + f"Unsupported shape {shape}, (expected {tuple(ref_ty.shape)})" + ) + return mgpu.FragmentedArray.load_strided( + x_ref, is_signed=is_signed, vec_size=vec_size, + ) + case None: + return mgpu.FragmentedArray.load_strided(x_ref, is_signed=is_signed) + case _: + raise NotImplementedError(f"Unsupported layout: {layout}") + case _: + raise NotImplementedError(f"Unsupported transforms: {transforms}") + + +def load( + src: _Ref, + idx, + *, + layout: Layout | ParameterizedLayout | None = None, + optimized: bool = True, +) -> jax.Array: + """Loads from a reference into an array with the specified layout. 
+ + Args: + src: The reference to load from. Can be either in SMEM or GMEM. + idx: The index to load from. + layout: The optional layout to use for the resulting array. + optimized: If True, a compilation error will be raised if no optimized + implementation for the load is available. + + Returns: + The loaded array. + """ + src, src_transforms = state_primitives.get_ref_and_transforms( + src, idx, "load", force_trailing_indexer=True, + ) + flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( + src_transforms + ) + return load_p.bind( + src, + *flat_src_transforms, + args_tree=src_transforms_treedef, + layout=layout, + optimized=optimized, + ) + + copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") copy_smem_to_gmem_p.multiple_results = True @@ -74,9 +181,48 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} -@lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) +def _copy_smem_to_gmem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + pp_params = {} + if not (commit_group := eqn.params["commit_group"]): + pp_params["commit_group"] = commit_group + if eqn.params["has_user_predicate"]: + flat_args, user_predicate = flat_args[:-1], flat_args[-1] + pp_params["user_predicate"] = jax_core.pp_var(user_predicate, context) + if reduction_op := eqn.params["reduction_op"]: + pp_params["reduction_op"] = reduction_op + flat_src_transforms, flat_dst_transforms = util.split_list( + flat_args, + [src_transforms_treedef.num_leaves], + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + return pp.concat([ + pp.text("copy_smem_to_gmem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + ]) + + +jax_core.pp_eqn_rules[copy_smem_to_gmem_p] = _copy_smem_to_gmem_pp_eqn + + +@lowering.register_lowering_rule( + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp) @lowering.register_lowering_rule( - copy_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, @@ -87,27 +233,38 @@ def _copy_smem_to_gmem_lowering( dst_transforms_treedef, has_user_predicate, commit_group, + reduction_op, ): - predicate = ctx.module_ctx.single_wg_lane_predicate if has_user_predicate: flat_args, user_predicate = flat_args[:-1], flat_args[-1] - predicate = arith_dialect.andi( - predicate, lowering._ensure_ir_value(user_predicate, jnp.bool) - ) + predicate = lowering._ensure_ir_value(user_predicate, jnp.bool) + else: + predicate = None + + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + if predicate is not None: + assert ctx.module_ctx.single_lane_predicate is not None + predicate = arith_dialect.andi( + predicate, ctx.module_ctx.single_lane_predicate + ) + else: + predicate = ctx.module_ctx.single_lane_predicate + flat_src_transforms, 
flat_dst_transforms = util.split_list( flat_args, [src_transforms_treedef.num_leaves], ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - src, src_transforms = lowering._handle_indexing(src, src_transforms) + src, src_transforms = lowering._handle_transforms(src, src_transforms, handle_transposes=False) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: ctx.launch_ctx.async_copy( src_ref=src, dst_ref=dst, predicate=predicate, arrive=commit_group, + reduction_op=reduction_op, **copy_params, ) return () @@ -186,6 +343,7 @@ def copy_smem_to_gmem( predicate: jax.Array | None = None, *, commit_group: bool = True, + reduction_op: mgpu.ReductionOp | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. @@ -194,9 +352,12 @@ def copy_smem_to_gmem( dst: The GMEM reference to copy to. predicate: A boolean indicating whether the copy should be performed. If ``None``, the copy is always performed. - commit_group: If ``True``, this and any previously uncommitted copies - are committed to a group and can be awaited jointly via + commit_group: If ``True``, this and any previously uncommitted copies are + committed to a group and can be awaited jointly via :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`. + reduction_op: If set, perform the specified reduction operation when storing + to GMEM. For example, using ``"add"`` is conceptually equivalent to + doing ``dst += src``. See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` @@ -224,6 +385,7 @@ def copy_smem_to_gmem( dst_transforms_treedef=dst_transforms_treedef, has_user_predicate=predicate is not None, commit_group=commit_group, + reduction_op=reduction_op, ) return None @@ -241,9 +403,54 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} -@lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane) +def _copy_gmem_to_smem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, barrier, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + barrier_transforms_treedef = eqn.params["barrier_transforms_treedef"] + pp_params = {} + if collective_axes := eqn.params["collective_axes"]: + pp_params["collective_axes"] = collective_axes + flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( + util.split_list( + flat_args, + [ + src_transforms_treedef.num_leaves, + dst_transforms_treedef.num_leaves, + ], + ) + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + barrier_transforms = barrier_transforms_treedef.unflatten( + flat_barrier_transforms + ) + return pp.concat([ + pp.text("copy_gmem_to_smem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + pp.text(" using "), + state_primitives.pp_ref_transforms(context, barrier, barrier_transforms), + ]) + + 
+jax_core.pp_eqn_rules[copy_gmem_to_smem_p] = _copy_gmem_to_smem_pp_eqn + + @lowering.register_lowering_rule( - copy_gmem_to_smem_p, mgpu.ThreadSemantics.Warpgroup + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp) +@lowering.register_lowering_rule( + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_gmem_to_smem_lowering( ctx: lowering.LoweringRuleContext, @@ -254,6 +461,7 @@ def _copy_gmem_to_smem_lowering( src_transforms_treedef, dst_transforms_treedef, barrier_transforms_treedef, + collective_axes, ): flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( util.split_list( @@ -266,7 +474,7 @@ def _copy_gmem_to_smem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - dst, dst_transforms = lowering._handle_indexing(dst, dst_transforms) + dst, dst_transforms = lowering._handle_transforms(dst, dst_transforms, handle_transposes=False) copy_params = _extract_smem_copy_params(dst_transforms) | _extract_gmem_copy_params(src_transforms) barrier_indexer = _extract_barrier_indexer( barrier_transforms_treedef.unflatten(flat_barrier_transforms) @@ -275,9 +483,21 @@ def _copy_gmem_to_smem_lowering( barrier = barrier.__getitem__( *map(lowering._as_index, barrier_indexer.indices) ) + collective = None + if collective_axes is not None: + collective = tuple( + lowering._resolve_cluster_axis(ctx.module_ctx.axis_names, axis) + for axis in collective_axes + ) dst_ty = ir.MemRefType(dst.type) - bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + bits = math.prod(dst_ty.shape) * mgpu.bitwidth(dst_ty.element_type) + if bits % 8: + raise ValueError( + f"Can only transfer integer bytes (shape={dst_ty.shape}," + f" dtype={dst_ty.element_type})" + ) + bytes = bits // 8 + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: if bytes % WARPGROUP_SIZE: raise NotImplementedError("Only aligned copies are supported") # We arrive uniformly from each thread in the WG, so we need to divide the @@ -292,7 +512,8 @@ def _copy_gmem_to_smem_lowering( dst_ref=dst, barrier=barrier, arrive=False, - predicate=ctx.module_ctx.single_wg_lane_predicate, + predicate=ctx.module_ctx.single_lane_predicate, + collective=collective, **copy_params, ) return () @@ -318,7 +539,13 @@ def _copy_gmem_to_smem_lowering( return () -def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: +def copy_gmem_to_smem( + src: _Ref, + dst: _Ref, + barrier: _Ref, + *, + collective_axes: str | tuple[str, ...] | None = None, +) -> None: """Asynchronously copies a GMEM reference to a SMEM reference. 
See also: @@ -343,6 +570,8 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: flat_barrier_transforms, barrier_transforms_treedef = tree_util.tree_flatten( barrier_transforms ) + if isinstance(collective_axes, str): + collective_axes = (collective_axes,) copy_gmem_to_smem_p.bind( src, dst, @@ -353,6 +582,7 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, barrier_transforms_treedef=barrier_transforms_treedef, + collective_axes=collective_axes, ) return None @@ -390,7 +620,27 @@ def _barrier_arrive_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane) +def _barrier_arrive_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_treedef"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_arrive"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_arrive_p] = _barrier_arrive_pp_eqn + + +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Warpgroup) def _barrier_arrive_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -428,8 +678,27 @@ def _barrier_wait_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Warpgroup) +def _barrier_wait_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_treedef"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_wait"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_wait_p] = _barrier_wait_pp_eqn + + +@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup) def _barrier_wait_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -466,9 +735,10 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( - wait_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _wait_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, n, *, wait_read_only @@ -499,8 +769,9 @@ def _commit_group_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_group_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_group_p, mgpu.LoweringSemantics.Warpgroup) def _commit_group_lowering(ctx: 
lowering.LoweringRuleContext): del ctx # Unused. nvvm_dialect.cp_async_bulk_commit_group() @@ -517,11 +788,7 @@ def commit_smem_to_gmem_group() -> None: wgmma_ref_p.multiple_results = True -def wgmma( - acc: gpu_core.WGMMAAbstractAccumulatorRef, - a, - b: pallas_core.TransformedRef, -) -> None: +def wgmma(acc: gpu_core.WGMMAAbstractAccumulatorRef, a, b) -> None: """Performs an asynchronous warp group matmul-accumulate on the given references. Conceptually, this is equivalent to doing ``acc[...] += a[...] @ b[...]``, @@ -555,12 +822,17 @@ def wgmma( a = a.ref else: a_transforms_leaves, a_transforms_tree = [], None - b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + + if isinstance(b, pallas_core.TransformedRef): + b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + b = b.ref + else: + b_transforms_leaves, b_transforms_tree = [], None wgmma_ref_p.bind( acc, a, - b.ref, + b, *a_transforms_leaves, *b_transforms_leaves, a_transforms_tree=a_transforms_tree, @@ -582,6 +854,40 @@ def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params): } +def _wgmma_ref_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + acc, a, b, *leaves = eqn.invars + a_transforms_treedef = eqn.params["a_transforms_tree"] + b_transforms_treedef = eqn.params["b_transforms_tree"] + split = getattr(a_transforms_treedef, "num_leaves", 0) + a_transforms = ( + a_transforms_treedef.unflatten(leaves[:split]) + if a_transforms_treedef is not None + else [] + ) + b_transforms = ( + b_transforms_treedef.unflatten(leaves[split:]) + if b_transforms_treedef is not None + else [] + ) + return pp.concat([ + pp.text("wgmma_ref"), + pp.text(" "), + pp.text(jax_core.pp_var(acc, context)), + pp.text(" <- "), + state_primitives.pp_ref_transforms(context, a, a_transforms), + pp.text(" @ "), + state_primitives.pp_ref_transforms(context, b, b_transforms), + ]) + + +jax_core.pp_eqn_rules[wgmma_ref_p] = _wgmma_ref_pp_eqn + + @discharge.register_discharge_rule(wgmma_ref_p) def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): del in_avals, out_avals @@ -592,7 +898,7 @@ def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): wgmma_p = jax_core.Primitive("wgmma") -@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Lane) def _wgmma_lowering( ctx: lowering.LoweringRuleContext, acc, @@ -609,15 +915,25 @@ def _wgmma_lowering( transforms_leaves, [a_transforms_tree.num_leaves] ) a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) - a, a_transforms = lowering._handle_indexing(a, a_transforms) + a, a_transforms = lowering._handle_transforms( + a, a_transforms, handle_transposes=False, handle_reshapes=False + ) match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): - swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize - if tiling != (64, swizzle_elems): - raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") + lhs_transpose = False + case ( + gpu_core.UnswizzleRef(lhs_swizzle), + gpu_core.UntileRef(tiling), + gpu_core.TransposeRef((1, 0)), + ): + lhs_transpose = True case _: raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.") + swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize + if tiling != (8, swizzle_elems): + raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") else: + lhs_transpose = False b_transforms_leaves = 
transforms_leaves # type: ignore if not isinstance(a, mgpu.FragmentedArray): raise ValueError( @@ -626,16 +942,17 @@ def _wgmma_lowering( ) b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) - b, b_transforms = lowering._handle_indexing(b, b_transforms) + b, b_transforms = lowering._handle_transforms( + b, b_transforms, handle_transposes=False, handle_reshapes=False + ) match b_transforms: case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): rhs_transpose = False case ( gpu_core.UnswizzleRef(rhs_swizzle), - gpu_core.TransposeRef((1, 0, 2, 3)), # Only transpose between tiles gpu_core.UntileRef(rhs_tiling), - gpu_core.TransposeRef((1, 0)), # Transpose the two logical dims + gpu_core.TransposeRef((1, 0)), ): rhs_transpose = True case ( @@ -664,16 +981,66 @@ def _wgmma_lowering( swizzle_elems = rhs_swizzle // a_aval.dtype.itemsize if rhs_swizzle != lhs_swizzle: raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle") - if rhs_tiling != (swizzle_elems, swizzle_elems): + if rhs_tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") + if lhs_transpose: + a = mgpu.memref_transpose(a, (1, 0, 3, 2)) if rhs_transpose: - b = mgpu.memref_transpose(b, (0, 1, 3, 2)) + b = mgpu.memref_transpose(b, (1, 0, 3, 2)) new_acc = mgpu.wgmma(acc, a, b, swizzle=rhs_swizzle) nvvm_dialect.wgmma_commit_group_sync_aligned() return new_acc +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Warpgroup) +def _wgmma_warpgroup_lowering( + ctx: lowering.LoweringRuleContext, + acc, + a, + b, + *transforms_leaves, + a_transforms_tree, + b_transforms_tree, +): + del ctx # Unused. + + if a_transforms_tree is not None: + a_transforms_leaves, b_transforms_leaves = util.split_list( + transforms_leaves, [a_transforms_tree.num_leaves] + ) + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + a, a_transforms = lowering._handle_transforms(a, a_transforms) + match a_transforms: + case (gpu_core.TransposeRef((1, 0)),): + a = mgpu.memref_transpose(a, (1, 0)) + case (): + pass + case _: + raise ValueError( + f"WGMMA lhs has unsupported transforms: {a_transforms}." + ) + else: + b_transforms_leaves = transforms_leaves # type: ignore + + if b_transforms_tree is not None: + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + b, b_transforms = lowering._handle_transforms(b, b_transforms) + match b_transforms: + case (gpu_core.TransposeRef((1, 0)),): + b = mgpu.memref_transpose(b, (1, 0)) + case (): + pass + case _: + raise ValueError( + f"WGMMA rhs has unsupported transforms: {b_transforms}." 
+ ) + + new_acc = mgpu.dialect.wgmma(acc, a, b) + nvvm_dialect.wgmma_commit_group_sync_aligned() + return new_acc + + @wgmma_p.def_effectful_abstract_eval def _wgmma_effectful_abstract_eval(acc, lhs_ref, *args, **kwargs): del args, kwargs @@ -697,7 +1064,8 @@ def wgmma_wait_effectful_abstract_eval(_): return [], {gpu_core._wgmma_pipeline_effect} -@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Warpgroup) def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) @@ -728,11 +1096,19 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc): return (None,), wgmma_accumulator_deref_p.bind(acc) -@lowering.register_lowering_rule(wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Lane +) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Warpgroup +) def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): - del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(0) - return acc.value + return ( + acc.value + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane + else acc + ) class Layout(enum.Enum): @@ -740,6 +1116,9 @@ class Layout(enum.Enum): WGMMA = enum.auto() #: [m] matrix, where m % 64 == 0. WGMMA_ROW = enum.auto() + #: [n] matrix, where n % 8 == 0. + WGMMA_COL = enum.auto() + WGMMA_TRANSPOSED = enum.auto() WG_SPLAT = enum.auto() WG_STRIDED = enum.auto() @@ -753,16 +1132,22 @@ def check_no_args(): raise ValueError(f"Can't instantiate {self} with arguments.") match self: + case Layout.WGMMA_TRANSPOSED: + check_no_args() + return mgpu.WGMMA_TRANSPOSED_LAYOUT case Layout.WGMMA: check_no_args() return mgpu.WGMMA_LAYOUT case Layout.WGMMA_ROW: check_no_args() return mgpu.WGMMA_ROW_LAYOUT + case Layout.WGMMA_COL: + check_no_args() + return mgpu.WGMMA_COL_LAYOUT case Layout.WG_SPLAT: return mgpu.WGSplatFragLayout(*args, **kwargs) # pytype: disable=missing-parameter case Layout.WG_STRIDED: - return mgpu.WGStridedFragLayout(*args, **kwargs) + return mgpu.WGStridedFragLayout(*args, **kwargs) # pytype: disable=missing-parameter @dataclasses.dataclass(frozen=True) class ParameterizedLayout: @@ -783,12 +1168,20 @@ def _layout_cast_abstract_eval(x, new_layout): return x -@lowering.register_lowering_rule(layout_cast_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(layout_cast_p, mgpu.LoweringSemantics.Lane) def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout): del ctx # Unused. return x.to_layout(new_layout.to_mgpu()) +@lowering.register_lowering_rule(layout_cast_p, mgpu.LoweringSemantics.Warpgroup) +def _layout_cast_lowering_wg( + ctx: lowering.LoweringRuleContext, x, *, new_layout +): + del ctx # Unused. 
+ return mgpu.dialect.layout_cast(x, mgpu.to_layout_attr(new_layout.to_mgpu())) + + def layout_cast(x: Any, new_layout: Layout | ParameterizedLayout): """Casts the layout of the given array.""" return layout_cast_p.bind(x, new_layout=new_layout) @@ -804,7 +1197,10 @@ def _set_max_registers_abstract_eval(n, *, action): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Warpgroup) def _set_max_registers_lowering( ctx: lowering.LoweringRuleContext, n, *, action ): @@ -832,9 +1228,11 @@ def _commit_smem_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_smem_p, mgpu.LoweringSemantics.Warpgroup) def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): + # TODO(bchetioui): add primitive for commit smem to mosaic_gpu dialect. mgpu.commit_shared() return () @@ -852,7 +1250,8 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): return jax_core.ShapedArray(shape, dtype) -@lowering.register_lowering_rule(broadcasted_iota_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + broadcasted_iota_p, mgpu.LoweringSemantics.Lane) def _broadcasted_iota_lowering( ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout ): @@ -900,8 +1299,42 @@ def _jaxpr_call_abstract_eval(*args, jaxpr: jax_core.Jaxpr, **params): return [v.aval for v in jaxpr.outvars] -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Warpgroup) +def _jaxpr_call_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + flat_args = eqn.invars + ref_treedefs = eqn.params["ref_treedefs"] + flat_refs, _ = util.split_list( + flat_args, [sum(treedef.num_leaves for treedef in ref_treedefs)] + ) + flat_refs = util.split_list( + flat_refs, + [treedef.num_leaves for treedef in ref_treedefs[: len(ref_treedefs) - 1]], + ) + trailer = [] + for treedef, flat_ref in zip(ref_treedefs, flat_refs): + ref = treedef.unflatten(flat_ref) + transforms = [] + if isinstance(ref, tuple): + ref, transforms = ref + trailer.append(pp.text(" ")) + trailer.append(state_primitives.pp_ref_transforms(context, ref, transforms)) + return pp.concat([ + pp.text("jaxpr_call"), + pp.text("["), + jax_core.pp_kv_pair("jaxpr", eqn.params["jaxpr"], context, settings), + pp.text("]"), + pp.concat(trailer), + ]) + + +jax_core.pp_eqn_rules[jaxpr_call_p] = _jaxpr_call_pp_eqn + + +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Warpgroup) def _jaxpr_call_lowering_rule( ctx: lowering.LoweringRuleContext, *flat_args, @@ -920,9 +1353,12 @@ def _jaxpr_call_lowering_rule( for treedef, flat_ref in zip(ref_treedefs, flat_refs): ref = treedef.unflatten(flat_ref) if isinstance(ref, tuple): + ref, transforms = ref # We ignore other transforms here, because they are already embedded # in the jaxpr. 
- ref, _ = lowering._handle_indexing(*ref) + ref, _ = lowering._handle_transforms( + ref, transforms, handle_reshapes=False, handle_transposes=False + ) args.append(ref) program_ids = program_ids_treedef.unflatten(flat_program_ids) for axis, pid in enumerate(program_ids): @@ -1025,3 +1461,189 @@ def jaxpr_call( ref_treedefs=ref_treedefs, program_ids_treedef=program_ids_treedef, ) + + +@dataclasses.dataclass(frozen=True) +class GPUShapeDtypeStruct: + shape: tuple[int, ...] + dtype: jnp.dtype + layout: ParameterizedLayout | Layout + + +inline_mgpu_p = jax_core.Primitive("inline_mgpu_p") +inline_mgpu_p.multiple_results = True + + +@dataclasses.dataclass(frozen=True) +class RefType: + ... + + +def inline_mgpu(arg_types=(), return_type=None): + """Decorate a function that inlines mgpu code. + + Arguments provided to the decorated function may be Pallas + references or array values. The body will accept the corresponding + mgpu values. + + The decorated function may return a tree of `FragmentedArray`s. + + ``` + layout = plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4) + @plgpu.inline_mgpu( + arg_types=(plgpu.RefType(),), + return_type=plgpu.GPUShapeDtypeStruct( + (128, 128), dtype, layout=layout + ), + ) + def foo(ctx, smem_ref): + del ctx + x = mgpu.FragmentedArray.load_tiled(smem_ref, ) + y = mgpu.FragmentedArray.splat( + mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout + ) + return (x + y) + + arr = foo(smem_ref) + ``` + + Args: + + arg_types: a sequence of pytrees where the leaves are `RefType` for + reference arguments, or `Layout`/`ParameterizedLayout` for array + arguments. + + return_type: A pytree where the leaves are `GPUShapeDtypeStruct` + representing the arrays returned by the decorated function. + + Returns: + A decorator that creates a function that inlines mgpu code. + + """ + flat_arg_types, treedef_ty = jax.tree.flatten(tuple(arg_types)) + flat_ret_ty, pytree_ret_ty = jax.tree.flatten(return_type) + if return_type and not all(isinstance(r, GPUShapeDtypeStruct) for r in flat_ret_ty): + raise ValueError( + "inline_mgpu_p only supports GPUShapeDtypeStruct return types." + ) + if not all(isinstance(r, (Layout, ParameterizedLayout, RefType)) for r in flat_arg_types): + raise ValueError( + "inline_mgpu_p only supports Layout, ParameterizedLayout and" + " RefType arg types." + ) + + def inner(f): + def wrapper(*args): + flat_args, treedef = jax.tree.flatten(tuple(args)) + if treedef != treedef_ty: + raise ValueError(f"Mismatched type shape: {treedef} != {treedef_ty}") + + # Strip the transforms from the refs since they will be recorded in + # the types.
+ raw_refs_flat_args = [] + for a, t in zip(flat_args, flat_arg_types): + def traced_ty(ty): + return isinstance(a, jax_core.Tracer) and isinstance(a.aval, ty) + + if isinstance(t, ParameterizedLayout) and traced_ty(jax_core.ShapedArray): + raw_refs_flat_args.append(a) + elif isinstance(t, RefType) and traced_ty(_Ref): + ref, transforms = a, () + if isinstance(a, state_types.TransformedRef): + ref, transforms = ref.ref, ref.transforms + + raw_refs_flat_args.append(ref) + if transforms: + raise NotImplementedError("Transformed refs (or types) are not supported.") + else: + raise ValueError(f"Mismatched type: {a, t}") + + flat_ret = inline_mgpu_p.bind( + *flat_args, + args_treedef=treedef, + flat_ret_ty=flat_ret_ty, + pytree_ret_ty=pytree_ret_ty, + flat_arg_types=flat_arg_types, + mgpu_fn=f, + ) + return jax.tree.unflatten(pytree_ret_ty, flat_ret) + return wrapper + + return inner + + +@inline_mgpu_p.def_effectful_abstract_eval +def _inline_mgpu_abstract_eval( + *flat_args, + args_treedef, + flat_arg_types, + flat_ret_ty, + pytree_ret_ty, + mgpu_fn, +): + del args_treedef, flat_arg_types, pytree_ret_ty, mgpu_fn # Unused. + aval_return = tuple( + jax_core.ShapedArray(x.shape, x.dtype) for x in flat_ret_ty + ) + # TODO(cperivol): Let the user set the effects. + return aval_return, { + gpu_core._wgmma_pipeline_effect, + gpu_core._memory_effect, + *itertools.chain.from_iterable( + (state.ReadEffect(i), state.WriteEffect(i)) + for i, r in enumerate(flat_args) + if isinstance(r, pallas_core.AbstractMemoryRef) + ), + } + + +@discharge.register_partial_discharge_rule(inline_mgpu_p) +def _inline_mgpu_discharge(*args, **kwargs): + del args, kwargs + raise NotImplementedError("inline_mgpu_p does not support discharge.") + + +def _type_check_mgpu(v, ty): + match (ty, v): + case (RefType(), ir.Value()) if ir.MemRefType.isinstance(v.type): + pass + case (GPUShapeDtypeStruct(), mgpu.FragmentedArray()): + mlir_dtype = mgpu_utils.dtype_to_ir_type(ty.dtype) + if v.mlir_dtype != mlir_dtype or ty.shape != v.shape or v.layout != ty.layout.to_mgpu(): + raise ValueError(f"Array type mismatch at {v} != {ty}.") + case (Layout() , mgpu.FragmentedArray()) | (ParameterizedLayout(), mgpu.FragmentedArray()): + if ty.to_mgpu() != v.layout: + raise ValueError(f"Unexpected layout for {v} (expected: {ty})") + case _: + raise ValueError(f"Unexpected type {ty} for value {v}") + + +@lowering.register_lowering_rule(inline_mgpu_p, mgpu.LoweringSemantics.Lane) +def _inline_mgpu_lowering_rule( + ctx: lowering.LoweringRuleContext, + *flat_args, + mgpu_fn: Callable[..., Any], + flat_arg_types, + flat_ret_ty, + pytree_ret_ty, + args_treedef, +): + for a, t in zip(flat_args, flat_arg_types): + _type_check_mgpu(a, t) + + args = jax.tree.unflatten(args_treedef, flat_args) + ret = mgpu_fn(ctx.launch_ctx, *args) + ret_leaves, ret_tree = jax.tree.flatten( + ret, is_leaf=lambda x: isinstance(x, mgpu.FragmentedArray) + ) + + if ret_tree != pytree_ret_ty: + return_type = jax.tree.unflatten(pytree_ret_ty, flat_ret_ty) + raise ValueError( + f"inline_mgpu_p return type tree mismatch: {ret} != {return_type}" + ) + + for ty, r in zip(flat_ret_ty, ret_leaves): + _type_check_mgpu(r, ty) + + return ret_leaves diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d0b74b2e5148..f6a3757ca8d6 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1206,7 +1206,7 @@ def _trace_kernel_to_jaxpr( return jaxpr, tuple(consts) -_PALLAS_USE_MOSAIC_GPU = config.bool_flag( +_PALLAS_USE_MOSAIC_GPU = 
config.bool_state( "jax_pallas_use_mosaic_gpu", default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False), help=( @@ -1234,7 +1234,7 @@ def _unsupported_lowering_error(platform: str) -> Exception: f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU," " install jaxlib GPU 0.4.24 or newer. To use Pallas on TPU, install" " jaxlib TPU and libtpu. See" - " https://jax.readthedocs.io/en/latest/installation.html." + " https://docs.jax.dev/en/latest/installation.html." ) _Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"] @@ -1489,7 +1489,7 @@ def pallas_call( ) -> Callable[..., Any]: """Invokes a Pallas kernel on some inputs. - See `Pallas Quickstart `_. + See `Pallas Quickstart `_. Args: kernel: the kernel function, that receives a Ref for each input and output. diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 3306649f24f3..986a62571010 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -993,3 +993,247 @@ def _lower_fun(*lower_fun_args): return out[:num_return_values] return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args) + + +def _get_ref_and_transforms(ref): + if isinstance(ref, state.TransformedRef): + return ref.ref, ref.transforms + return ref, () + + +class DeviceIdType(enum.Enum): + MESH = "mesh" + LOGICAL = "logical" + + +def check_sem_avals( + sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None +): + if allowed_semaphore_types is None: + allowed_semaphore_types = { + pallas_core.semaphore, + pallas_core.barrier_semaphore, + # For interpret mode. + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + } + if not isinstance(sem_aval, state.AbstractRef): + raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") + sem_shape = sem_aval.shape + if sem_transforms_avals: + sem_shape = sem_transforms_avals[-1].get_indexer_shape() + if sem_shape: + raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") + sem_dtype = sem_aval.dtype + if not any( + jnp.issubdtype(sem_dtype, sem_type) + for sem_type in allowed_semaphore_types + ): + raise ValueError( + f"Must {name} semaphores of the following types:" + f" {allowed_semaphore_types}." 
+ ) + + +def _transform_semaphore(ref_value, transforms, ref_aval): + """Helper function for indexing into a semaphore during state_discharge.""" + if ref_value.shape == ref_aval.shape: + return state_discharge.transform_array(ref_value, transforms) + elif len(ref_value.shape) == 0: + return ref_value + else: + raise ValueError( + f"Semaphore value shape {ref_value.shape} does not match aval shape" + f" {ref_aval.shape}" + ) + + +semaphore_read_p = jax_core.Primitive("semaphore_read") +semaphore_read_p.multiple_results = False + + +def semaphore_read(sem_or_view): + ref, transforms = _get_ref_and_transforms(sem_or_view) + args = [ref, transforms] + flat_args, args_tree = tree_util.tree_flatten(args) + return semaphore_read_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_read_p.def_abstract_eval +def _semaphore_read_abstract_eval( + *avals, + args_tree, +): + del avals, args_tree + return jax_core.ShapedArray((), jnp.dtype("int32")) + +def _semaphore_read_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + sem_value = sem_value.astype(jnp.int32) + return (None,) * len(in_avals), sem_value +state_discharge.register_discharge_rule(semaphore_read_p)( + _semaphore_read_discharge_rule +) + + +semaphore_signal_p = jax_core.Primitive('semaphore_signal') +semaphore_signal_p.multiple_results = True + + +def semaphore_signal( + sem_or_view, + inc: int | jax.Array = 1, + *, + device_id: int | jax.Array | None | tuple[int | jax.Array, ...] = None, + device_id_type: DeviceIdType = DeviceIdType.MESH, + core_index: int | jax.Array | None = None, +): + ref, transforms = _get_ref_and_transforms(sem_or_view) + inc = jnp.asarray(inc, dtype=jnp.int32) + args = [ref, transforms, inc, device_id, core_index] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_signal_p.bind( + *flat_args, + args_tree=args_tree, + device_id_type=device_id_type, + ) + + +@semaphore_signal_p.def_abstract_eval +def _semaphore_signal_abstract_eval( + *avals, + args_tree, + device_id_type: DeviceIdType, +): + del device_id_type + ( + sem_aval, + sem_transforms_avals, + value_aval, + device_id_avals, + core_index_aval, + ) = tree_util.tree_unflatten(args_tree, avals) + check_sem_avals(sem_aval, sem_transforms_avals, "signal") + if value_aval.dtype != jnp.dtype("int32"): + raise ValueError("Must signal an int32 value.") + if device_id_avals is not None: + device_id_flat_avals = tree_util.tree_leaves(device_id_avals) + for aval in device_id_flat_avals: + if aval.dtype != jnp.dtype("int32"): + raise ValueError("`device_id`s must be an int32 value.") + return [] + + +def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + device_ids, + _, + ) = tree_util.tree_unflatten(tree, invars) + out = pp.concat([ + pp.text("semaphore_signal"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ]) + if device_ids is not None: + flat_device_ids = tree_util.tree_leaves(device_ids) + if not flat_device_ids: + return out + device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] + for device_id in flat_device_ids[1:]: + device_ids_pp.append(pp.text(" ")) + device_ids_pp.append(pp.text(jax_core.pp_var(device_id, 
context))) + out = pp.concat([out, pp.concat(device_ids_pp)]) + return out +jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn + + +def _semaphore_signal_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree, + device_id_type): + del out_avals, device_id_type + [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) + if device_id is not None: + raise NotImplementedError("Remote signal not implemented.") + if core_index is not None: + raise NotImplementedError("Multiple core support not implemented.") + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + inc = inc.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value + inc + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_signal_p)( + _semaphore_signal_discharge_rule +) + + +semaphore_wait_p = jax_core.Primitive('semaphore_wait') +semaphore_wait_p.multiple_results = True + +def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): + ref, transforms = _get_ref_and_transforms(sem_or_view) + dec = jnp.asarray(dec, dtype=jnp.int32) + args = [ref, transforms, dec] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_wait_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_wait_p.def_abstract_eval +def _semaphore_wait_abstract_eval(*avals, args_tree): + sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( + args_tree, avals + ) + check_sem_avals(sem_aval, sem_transforms_avals, "wait") + if value_aval.dtype != jnp.dtype("int32"): + raise ValueError("Must wait an int32 value.") + return [] + +def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + ) = tree_util.tree_unflatten(tree, invars) + return pp.concat([ + pp.text("semaphore_wait"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ]) +jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn + +def _semaphore_wait_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms, dec] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + dec = dec.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value - dec + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_wait_p)( + _semaphore_wait_discharge_rule +) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index f3a8dd175ec1..150ae9b8b2d7 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -654,7 +654,9 @@ def _make_dispatch_table( name: str, **tables: Sequence[_Extern | _Fallback] ) -> Callable[..., ir.Value]: - def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: + def inner( + ctx: LoweringRuleContext, *args: ir.Value, **_ + ) -> ir.Value: table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: @@ -1120,7 +1122,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: def _minus(x: ir.Value) -> ir.Value: if 
tt_dialect.PointerType.isinstance(_element_type(x.type)): raise NotImplementedError(f"unsupported type: {x.type}") - return _sub(_full(x.type, 0), x) + return _sub(_zeros_like(x), x) def _add(x: ir.Value, y: ir.Value): @@ -1260,6 +1262,10 @@ def _cmp( ) +def _is_nan(x: ir.Value) -> ir.Value: + return arith_dialect.cmpf(arith_dialect.CmpFPredicate.UNO, x, x) + + _JAX_TO_TRITON_BINARY = { lax.add_p: _add, lax.sub_p: _sub, @@ -1373,7 +1379,7 @@ def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]): @register_lowering(lax.integer_pow_p) def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): if y == 0: - return _full(x.type, 1) + return _ones_like(x) is_reciprocal = y < 0 if is_reciprocal: @@ -1393,14 +1399,14 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): acc = _cast(acc, x_aval.dtype, out_aval.dtype) if is_reciprocal: signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger) - return _truediv(_full(acc.type, 1), acc, signed=signed) + return _truediv(_ones_like(acc), acc, signed=signed) else: return acc _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), - lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), + lax.logistic_p: lambda a, accuracy: 1 / (1 + jnp.exp(-a)), } for prim, fn in _JAX_FN_MAPPING.items(): @@ -1514,6 +1520,22 @@ def _full(t: ir.Type, v: object) -> ir.Type: return result +def _zeros(t: ir.Type) -> ir.Value: + return _full(t, 0) + + +def _zeros_like(x: ir.Value) -> ir.Value: + return _full(x.type, 0) + + +def _ones(t: ir.Type) -> ir.Value: + return _full(t, 1) + + +def _ones_like(x: ir.Value) -> ir.Value: + return _full(x.type, 1) + + def _splat(x: ir.value, shape: Sequence[int]) -> ir.Value: if ir.RankedTensorType.isinstance(x.type): raise TypeError("cannot splat a tensor") @@ -1552,7 +1574,7 @@ def _int_int_cast(src: ir.Value, dst_type: ir.Type, signed: bool) -> ir.Value: dst_element_type = ir.IntegerType(_element_type(dst_type)) assert src_element_type != dst_element_type if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) if src_element_type.width == dst_element_type.width: return arith_dialect.bitcast(dst_type, src) @@ -1572,7 +1594,7 @@ def _float_int_cast( raise NotImplementedError(f"cannot cast {src} tp {dst_type}") dst_element_type = ir.IntegerType(_element_type(dst_type)) if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) else: # We clamp the float value to the min/max integer destination value # in order to match JAX/XLA casting behavior. 
Note that this differs @@ -1675,7 +1697,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, return tt_dialect.ptr_to_int(dst_type, src) elif dst_element_type.width == 1: x = _ir_cast(src, ir.IntegerType.get_signless(64), signed=signed) - zero = _full(x.type, 0) + zero = _zeros_like(x) return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed) if isinstance( src_element_type, ir.IntegerType @@ -1759,6 +1781,12 @@ def _reshape(a: ir.Value, shape: Sequence[int]) -> ir.Value: ) +def get_join_type(old_type: ir.RankedTensorType): + shape = old_type.shape + shape.append(2) + return ir.RankedTensorType.get(shape, old_type.element_type, old_type.encoding) + + @register_lowering(lax.concatenate_p) def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): if len(args) != 2: @@ -1773,9 +1801,32 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): raise NotImplementedError( "Only arguments with shape [..., 1] are supported." ) - return tt_dialect.join( - _reshape(x, x_aval.shape[:-1]), _reshape(y, y_aval.shape[:-1]) - ) + lhs = _reshape(x, x_aval.shape[:-1]) + rhs = _reshape(y, y_aval.shape[:-1]) + ret_type = get_join_type(ir.RankedTensorType(rhs.type)) + return tt_dialect.join(ret_type, lhs, rhs) + + +@register_lowering(lax.split_p) +def _split_lowering_rule(ctx: LoweringRuleContext, x, *, sizes, axis): + pass + # TODO(cjfj): Add support for larger powers of 2. + num_parts = len(sizes) + if num_parts != pallas_utils.next_power_of_2(num_parts): + raise NotImplementedError("Only power-of-2 num parts supported.") + if any(size != sizes[0] for size in sizes): + raise NotImplementedError("Only equal-sized splits are supported.") + + def split_into_2(x): + shape = ir.RankedTensorType(x.type).shape + x = _reshape(x, shape[:axis] + [2, shape[axis] // 2] + shape[axis + 1 :]) + permutation = tuple(d for d in range(len(shape) + 1) if d != axis) + (axis,) + return tuple(tt_dialect.split(tt_dialect.trans(x, permutation))) + + x_parts = (x,) + while len(x_parts) < num_parts: + x_parts = sum(map(split_into_2, x_parts), ()) + return x_parts def _compute_offsets_from_indices( @@ -1798,7 +1849,7 @@ def _compute_offsets_from_indices( # Use 64-bit indexing when offset might be >= 2**32 bytes. offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32) if indexer_shape: - offsets = _full(ir.RankedTensorType.get(indexer_shape, offset_eltype), 0) + offsets = _zeros(ir.RankedTensorType.get(indexer_shape, offset_eltype)) else: offsets = _ir_constant(0, offset_eltype) @@ -2060,17 +2111,18 @@ def _masked_load_lowering_rule( # most significant. Before jaxlib 0.5.2, the order was reversed. 
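Note (illustrative, not part of the patch): the contiguous int4 path below unpacks two 4-bit values per byte by shifting out the high nibble and interleaving it with the original byte via join/reshape. A small NumPy model of that interleave, with names and values chosen for illustration; the explicit `& 0xF` mask is added here for clarity, whereas the kernel presumably relies on the later cast to i4 to drop the high bits.

```
import numpy as np

packed = np.array([0x21, 0x43], dtype=np.uint8)   # two 4-bit values per byte
msb = packed >> 4                                  # mirrors shrui(values, 4)
# On jaxlib >= 0.5.2 the pair order is (low nibble, high nibble), matching
# tt_dialect.join(values, msb_values); older jaxlib used the reverse order.
pairs = np.stack([packed & 0xF, msb], axis=-1)
unpacked = pairs.reshape(-1)                       # array([1, 2, 3, 4], dtype=uint8)
```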
if is_contiguous_int4: msb_values = arith_dialect.shrui(values, _full(values.type, 4)) + join_type = get_join_type(ir.RankedTensorType(values.type)) if jaxlib_version < (0, 5, 2): - values = tt_dialect.join(msb_values, values) + values = tt_dialect.join(join_type, msb_values, values) else: - values = tt_dialect.join(values, msb_values) + values = tt_dialect.join(join_type, values, msb_values) shape = ir.RankedTensorType(values.type).shape values = _reshape(values, (*shape[:-2], shape[-2] * shape[-1])) else: offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False) in_msb = _mod(offsets, _full(offsets.type, 2), signed=False) if jaxlib_version < (0, 5, 2): - in_msb = arith_dialect.xori(in_msb, _full(in_msb.type, 1)) + in_msb = arith_dialect.xori(in_msb, _ones_like(in_msb)) shift = _mul(in_msb, _full(in_msb.type, 4)) shift = _ir_cast(shift, values.type, signed=False) values = arith_dialect.shrui(values, shift) @@ -2198,6 +2250,14 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation): _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT) +def _as_bf16(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.bfloat16), signed=False) + + +def _as_f32(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.float32), signed=False) + + @register_lowering(lax.dot_general_p) def _dot_general_lowering( ctx: LoweringRuleContext, @@ -2237,6 +2297,9 @@ def _dot_general_lowering( | lax.DotAlgorithmPreset.F16_F16_F32 | lax.DotAlgorithmPreset.BF16_BF16_BF16 | lax.DotAlgorithmPreset.BF16_BF16_F32 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X3 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X6 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X9 ): input_precision = None case _: @@ -2275,7 +2338,40 @@ def _dot_general_lowering( m, _ = a_type.shape _, n = b_type.shape - acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0) + acc = _zeros(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype))) + + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + a_bf16 = _as_bf16(a) + b_bf16 = _as_bf16(b) + a_err0 = _sub(a, _as_f32(a_bf16)) + b_err0 = _sub(b, _as_f32(b_bf16)) + a_err0_bf16 = _as_bf16(a_err0) + b_err0_bf16 = _as_bf16(b_err0) + a_err1_bf16 = _as_bf16(_sub(a_err0, _as_f32(a_err0_bf16))) + b_err1_bf16 = _as_bf16(_sub(b_err0, _as_f32(b_err0_bf16))) + # Accumulate the smallest values first to reduce the numeric error. + if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X9: + acc = tt_dialect.dot(a_err1_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err1_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err1_bf16, acc) + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + acc = tt_dialect.dot(a_err1_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err0_bf16, acc) + # If `a` rounding error is zero and `b` is `inf` then `acc` may contain + # `NaN`s (as `0 * inf = NaN`), and vice versa. 
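Note (illustrative, not part of the patch): the BF16_BF16_F32_X3/X6/X9 presets emulate an f32 matmul by splitting each operand into up to three bfloat16 pieces and accumulating bf16 dot products of those pieces; the arith.select that follows zeroes the spurious NaNs described in the comment above. A standalone JAX sketch of the decomposition (bf16_pieces and the variable names are illustrative, not from the patch):

```
import jax
import jax.numpy as jnp

def bf16_pieces(x):
  # x ~= p0 + p1 + p2, with each piece exactly representable in bfloat16.
  p0 = x.astype(jnp.bfloat16).astype(jnp.float32)
  p1 = (x - p0).astype(jnp.bfloat16).astype(jnp.float32)
  p2 = (x - p0 - p1).astype(jnp.bfloat16).astype(jnp.float32)
  return p0, p1, p2

ka, kb = jax.random.split(jax.random.key(0))
a = jax.random.normal(ka, (8, 16), jnp.float32)
b = jax.random.normal(kb, (16, 4), jnp.float32)
a0, a1, a2 = bf16_pieces(a)
b0, b1, b2 = bf16_pieces(b)

# X3 keeps the three largest partial products; X6 adds a2 @ b0, a0 @ b2 and
# a1 @ b1; X9 adds the remaining three smallest terms.
x3 = a1 @ b0 + a0 @ b1 + a0 @ b0
print(jnp.max(jnp.abs(x3 - a @ b)))   # error well below plain bf16 precision
```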
+ acc = arith_dialect.select(_is_nan(acc), _zeros_like(acc), acc) + a, b = a_bf16, b_bf16 + acc = tt_dialect.dot(a, b, acc, input_precision=input_precision) return _cast(acc, acc_dtype, out_aval.dtype) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f7a4361ffee2..a10856cbced3 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -21,7 +21,6 @@ from functools import partial import inspect import logging -import operator as op import weakref from typing import NamedTuple, Any, Union, cast import warnings @@ -68,7 +67,7 @@ NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, prepare_axis_resources, parse_flatten_op_sharding, canonicalize_sharding, - flatten_spec) + flatten_spec, _internal_use_concrete_mesh) from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef from jax._src.traceback_util import api_boundary @@ -186,9 +185,7 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): args_flat = [*init_states, *args_flat] try: - if (core.trace_state_clean() and - not config.debug_key_reuse.value and - not config.data_dependent_tracing_fallback.value): + if core.trace_state_clean() and not config.debug_key_reuse.value: args_flat = map(core.full_lower, args_flat) core.check_eval_args(args_flat) out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) @@ -233,20 +230,30 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr + from jax.experimental.attrs import jax_setattr, jax_extendattr valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) - for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): - val = tree_unflatten(treedef, leaves) - jax_setattr(obj, attr, val) + for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): + if kind is pe.ReadWrite: + val = tree_unflatten(treedef, leaves) + jax_setattr(obj, attr, val) + elif kind is pe.Append: + del treedef + val, = leaves + jax_extendattr(obj, attr, val) def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr + from jax.experimental.attrs import jax_getattr, dne_sentinel vals = [] - for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) - leaves, treedef_ = tree_flatten(tree) - assert treedef == treedef_ - vals.extend(leaves) + for treedef, _, (obj, attr, kind) in attrs_tracked: + if kind is pe.ReadWrite: + tree = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.Append: + pass + else: + assert False return vals def _need_to_rebuild_with_fdo(pgle_profiler): @@ -300,12 +307,6 @@ def _get_fastpath_data( return fastpath_data -def _cpp_pjit_evict_fn(self): - self._clear_cache() - _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error - _infer_params_cached.cache_clear() - - # The entries are doubled here from the default 4096 because _pjit_call_impl # also has a cpp dispatch path and that would double the number of entries in # the global shared cache. 
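Note (not part of the patch): the hunk below moves trace, lower and eval_shape from per-instance Python closures onto the C++ jit wrapper class; the public ahead-of-time API is intended to behave as before. A minimal usage sketch for reference:

```
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return x @ x.T

x = jnp.ones((4, 4), jnp.float32)
traced = f.trace(x)        # a stages.Traced object; nothing is compiled yet
lowered = traced.lower()   # equivalent to f.lower(x)
compiled = lowered.compile()
print(f.eval_shape(x))     # ShapeDtypeStruct with shape (4, 4) and dtype float32
```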
@@ -366,9 +367,50 @@ def cache_miss(*args, **kwargs): cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun - type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn + cpp_pjitted_f._jit_info = jit_info + # TODO(necula): move these to top-level; we don't need to do this for + # every jit + cpp_jitted_f_class = type(cpp_pjitted_f) + # TODO(necula): make clear_cache private, no need to have it part of the API + cpp_jitted_f_class.clear_cache = jit_evict_fn + cpp_jitted_f_class.lower = jit_lower + cpp_jitted_f_class.trace = jit_trace + cpp_jitted_f_class.eval_shape = jit_eval_shape + # We return directly the function produced by _xla.pjit, because we do not + # want to have Python in the dispatch path. return cpp_pjitted_f +@api_boundary +def jit_trace(jit_func, *args, **kwargs) -> stages.Traced: + p, args_flat = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs) + donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) + args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) + lower_callable = partial(_resolve_and_lower, args_flat, **p.params, + pgle_profiler=None) + return stages.Traced( + p.params['jaxpr'], args_info, p.params["name"], p.out_tree, + lower_callable, args_flat, p.arg_names, p.num_consts) + + +@api_boundary +def jit_lower(jit_func, *args, **kwargs): + return jit_trace(jit_func, *args, **kwargs).lower() + +@api_boundary +def jit_eval_shape(jit_func, *args, **kwargs): + p, _ = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs) + out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] + # TODO(yashkatariya): Add `Layout` to SDS. + out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, + weak_type=x.weak_type) + for x, s in zip(p.params['jaxpr'].out_avals, out_s)] + return tree_unflatten(p.out_tree, out) + +def jit_evict_fn(self): + self._clear_cache() + _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error + _infer_params_cached.cache_clear() + def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) @@ -391,14 +433,16 @@ def _split_layout_and_sharding(entries): return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings) -def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, - donate_argnums: int | Sequence[int] | None, - donate_argnames: str | Iterable[str] | None, +def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, + out_shardings: Any, static_argnums: int | Sequence[int] | None, static_argnames: str | Iterable[str] | None, - device: xc.Device | None, backend: str | None, - abstracted_axes: Any | None, keep_unused: bool, - inline: bool, compiler_options: dict[str, Any] | None, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + keep_unused: bool, device: xc.Device | None, + backend: str | None, inline: bool, + abstracted_axes: Any | None, + compiler_options: dict[str, Any] | None, use_resource_env: bool) -> PjitInfo: """Parses the arguments to jit/pjit. 
@@ -476,56 +520,30 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, use_resource_env=use_resource_env, compiler_options_kvs=compiler_options_kvs) - -def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): - - @api_boundary - def lower(*args, **kwargs): - return trace(*args, **kwargs).lower() - - @api_boundary - def eval_shape(*args, **kwargs): - p, _ = _infer_params(fun, jit_info, args, kwargs) - out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] - # TODO(yashkatariya): Add `Layout` to SDS. - out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, - weak_type=x.weak_type) - for x, s in zip(p.params['jaxpr'].out_avals, out_s)] - return tree_unflatten(p.out_tree, out) - - @api_boundary - def trace(*args, **kwargs) -> stages.Traced: - p, args_flat = _infer_params(fun, jit_info, args, kwargs) - donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) - args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) - lower_callable = partial(_resolve_and_lower, args_flat, **p.params, - pgle_profiler=None) - return stages.Traced( - p.params['jaxpr'], args_info, p.params["name"], p.out_tree, - lower_callable, args_flat, p.arg_names, p.num_consts) - - wrapped = _cpp_pjit(fun, jit_info) - wrapped.lower = lower - wrapped.eval_shape = eval_shape - wrapped.trace = trace - return wrapped - - -def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, - donate_argnums: int | Sequence[int] | None, - donate_argnames: str | Iterable[str] | None, +def make_jit(fun: Callable, + *, + in_shardings: Any, + out_shardings: Any, static_argnums: int | Sequence[int] | None, static_argnames: str | Iterable[str] | None, - device: xc.Device | None, backend: str | None, - abstracted_axes: Any | None, keep_unused: bool, - inline: bool, compiler_options: dict[str, Any] | None, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + keep_unused: bool, + device: xc.Device | None, + backend: str | None, + inline: bool, + abstracted_axes: Any | None, + compiler_options: dict[str, Any] | None, use_resource_env: bool) -> Any: """jit() and pjit() are thin wrappers around this function.""" jit_info = _parse_jit_arguments( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env) - return _make_jit_wrapper(fun, jit_info) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + abstracted_axes=abstracted_axes, compiler_options=compiler_options, + use_resource_env=use_resource_env) + return _cpp_pjit(fun, jit_info) class PjitParams(NamedTuple): @@ -537,7 +555,7 @@ class PjitParams(NamedTuple): donated_invars: tuple[bool, ...] arg_names: tuple[str, ...] 
num_consts: int - attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] + attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, Any]]] def _infer_params_impl( @@ -613,14 +631,14 @@ def _infer_params_impl( ji.in_layouts_treedef, ji.in_layouts_leaves, in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs) - attr_token = _attr_token(flat_fun, in_type) + attr_token = _attr_cache_index(flat_fun, in_type) jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( flat_fun, in_type, attr_token, IgnoreKey(ji.inline)) if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) - _attr_update(flat_fun, in_type, attr_token, attrs_tracked) + _attr_cachedata_update(flat_fun, in_type, attr_token, attrs_tracked) out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef, @@ -636,13 +654,14 @@ def _infer_params_impl( implicit_args = [] args_flat = [*implicit_args, *explicit_args] - num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked) - num_extra_args = len(implicit_args) + num_states_in + len(consts) + num_attrs_in = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) + in attrs_tracked if kind is pe.ReadWrite) + num_extra_args = len(implicit_args) + num_attrs_in + len(consts) in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat donated_invars = (False,) * num_extra_args + donated_invars assert (len(in_shardings_flat) == len(in_layouts_flat) == - len(donated_invars) == num_states_in + len(consts) + len(args_flat)) + len(donated_invars) == num_attrs_in + len(consts) + len(args_flat)) params = dict( jaxpr=jaxpr, @@ -672,7 +691,7 @@ def __init__(self): # We use an outer cache that is keyed on the signature of the arguments, but # when populating a cache entry using _infer_params_impl, we need to provide -# actual arguments. In principle we could refactor _infer_params_impl to look +# actual arguments. In principle, we could refactor _infer_params_impl to look # only at an argument signature instead of args/kwargs in those cases that we # cache, but this was a more minimal change. @util.weakref_lru_cache @@ -689,8 +708,10 @@ def _infer_params_cached( def _infer_params( fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> tuple[PjitParams, list[Any]]: - if ji.use_resource_env: - with sharding_impls.use_mesh(mesh_lib.thread_resources.env.physical_mesh): + if ji.use_resource_env: # pjit + phys_mesh = mesh_lib.thread_resources.env.physical_mesh + with (_internal_use_concrete_mesh(phys_mesh), + mesh_lib.use_abstract_mesh(phys_mesh.abstract_mesh)): return _infer_params_internal(fun, ji, args, kwargs) return _infer_params_internal(fun, ji, args, kwargs) @@ -717,7 +738,7 @@ def _infer_params_internal( if entry.pjit_params is None: p, args_flat = _infer_params_impl( fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) - if p.attrs_tracked: # if attrs, don't popoulate the cache + if p.attrs_tracked: # if attrs, don't populate the cache return p, p.consts + args_flat entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs @@ -942,7 +963,7 @@ def pjit( be donated. For more details on buffer donation see the - `FAQ `_. + `FAQ `_. donate_argnames: An optional string or collection of strings specifying which named arguments are donated to the computation. 
See the comment on ``donate_argnums`` for details. If not @@ -984,9 +1005,12 @@ def pjit( [ 0.5 2. 4. 6. 8. 10. 12. 10. ] """ return make_jit( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env=True) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + abstracted_axes=abstracted_axes, compiler_options=compiler_options, + use_resource_env=True) def hashable_pytree(pytree): @@ -1131,139 +1155,267 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type] return in_shardings_flat, in_layouts_flat -callsites: set[str] = set() +callsites_with_tracing_cache_miss: set[str] = set() + +def diff_tracing_cache_keys( + k: tuple, oldk: tuple, debug_info: lu.DebugInfo) -> tuple[Sequence[str], int]: + """Explanations of differences between the cache keys, along with diff sizes. + + Result: a pair of a list of explanations for differences, and the total size + of the differences. The sizes are used to pick the old key with the smallest + different size for the explanation that is shown to the user. + """ + (fun_transforms_k, fun_params_k, fun_in_type_k, + (arg_in_type_k, arg_attr_data_k, arg_inline_k), ctx_k) = k + (fun_transforms_ok, fun_params_ok, fun_in_type_ok, + (arg_in_type_ok, arg_attr_data_ok, arg_inline_ok), ctx_ok) = oldk + + diffs: list[tuple[str, int]] = [] # each difference with its size + def unavailable(key_field: str, what_k, what_ok): + diffs.append( + (f"different {key_field}:\n now: {what_k}\n != before: {what_ok}.\n" + "explanation unavailable! 
" + "please open an issue at https://github.com/jax-ml/jax.", + 10)) + + def list_diff_size(s1: Sequence, s2: Sequence) -> int: + min_len = min(len(s1), len(s2)) + diff_size = max(len(s1), len(s2)) - min_len + diff_size += sum(e1 != e2 for e1, e2 in zip(s1[:min_len], + s2[:min_len])) + return diff_size + + different_leaf_count = False + + def explain_transform_argnums_partial(param_k: tuple, param_ok: tuple): + dyn_argnums_k, static_args_k = param_k + dyn_argnums_ok, static_args_ok = param_ok + if dyn_argnums_k != dyn_argnums_ok: + diffs.append( + ("different static_argnums:\n" + f" dynamic argnums now {dyn_argnums_k} and before {dyn_argnums_ok}", + 1)) + if static_args_k != static_args_ok: + diffs.append( + ("different value of static args:\n" + f" now {', '.join(repr(a.val) for a in static_args_k)}" + f" and before {', '.join(repr(a.val) for a in static_args_ok)}", + list_diff_size(static_args_k, static_args_ok))) + + def explain_transform_argnames_partial(param_k: tuple, param_ok: tuple): + static_kwargs_k, = param_k + static_kwargs_ok, = param_ok + static_kwargs_k = [(k, v.val) for k, v in + sorted(static_kwargs_k.val.items())] + static_kwargs_ok = [(k, v.val) for k, v in + sorted(static_kwargs_ok.val.items())] + if static_kwargs_k != static_kwargs_ok: + diffs.append( + ("different value of static kwargs:\n" + f" now {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_k)}}}" + f" and before {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_ok)}}}", + list_diff_size(static_kwargs_k, static_kwargs_ok))) + + def explain_in_tree_diff(in_tree_k: PyTreeDef, in_tree_ok: PyTreeDef): + nonlocal different_leaf_count + different_leaf_count = (in_tree_k.num_leaves != in_tree_ok.num_leaves) + if not different_leaf_count: + # Look for the special case of passing positional args as kwargs or + # vice-versa; the common prefix of positional args match. 
+ args_tree_k, kwargs_tree_k = treedef_children(in_tree_k) + nr_args_k = len(treedef_children(args_tree_k)) + args_tree_ok, kwargs_tree_ok = treedef_children(in_tree_ok) + nr_args_ok = len(treedef_children(args_tree_k)) + if (treedef_children(args_tree_k)[:min(nr_args_k, nr_args_ok)] == + treedef_children(args_tree_ok)[:min(nr_args_k, nr_args_ok)]): + keys_k = kwargs_tree_k.node_data()[1] # type: ignore[index] + keys_ok = kwargs_tree_ok.node_data()[1] # type: ignore[index] + diffs.append( + (("different number of args and kwargs, but same total number.\n" + f" now {nr_args_k} args and kwargs " + f"with keys {keys_k}\n" + f" before {nr_args_ok} args and kwargs " + f"with keys {keys_ok}"), + abs(nr_args_ok - nr_args_k))) + return + + in_tree_k_str = str(in_tree_k) + in_tree_k_str = (in_tree_k_str if len(in_tree_k_str) < 73 + else in_tree_k_str[:73] + "...") + in_tree_ok_str = str(in_tree_ok) + in_tree_ok_str = (in_tree_ok_str if len(in_tree_ok_str) < 73 + else in_tree_ok_str[:73] + "...") + diff = [f"different input pytree:\n now: {in_tree_k_str}\n" + f" before: {in_tree_ok_str}"] + + errs = list(tree_util.equality_errors_pytreedef(in_tree_k, in_tree_ok)) + for path, thing1, thing2, explanation in errs: + fst, *path = path # type: ignore + base = ["args", "kwargs"][fst.idx] + diff.append( + f" * at {base}{keystr(tuple(path))}, now {thing1} and before {thing2}," + f" so {explanation}") + diffs.append(("\n".join(diff), len(errs))) + + def explain_args_type_diff(args_k: tuple[core.AbstractValue], + args_ok: tuple[core.AbstractValue]): + diff_size = 0 + arg_names = debug_info.safe_arg_names(len(args_k)) + def arg_type_to_str(at): + if hasattr(at, "str_short"): + return at.str_short(short_dtypes=True) + else: + return str(at) + args_k_str = ", ".join(f"{an}: {arg_type_to_str(at)}" + for an, at in zip(arg_names, args_k)) + args_k_str = args_k_str if len(args_k_str) < 73 else args_k_str[:73] + "..." 
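Note (illustrative, not part of the patch): the explanations assembled in this section are reported when tracing-cache-miss explanations are enabled. A minimal way to trigger and inspect them, assuming the pre-existing jax_explain_cache_misses flag, which is unchanged by this patch:

```
import jax
import jax.numpy as jnp

jax.config.update("jax_explain_cache_misses", True)

@jax.jit
def f(x):
  return x * 2

f(jnp.arange(4))                      # first call traces and fills the cache
f(jnp.arange(4, dtype=jnp.float32))   # dtype change: an explanation is logged
```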
+ diff = [f"different input types:\n types now: {args_k_str}"] + add_weak_type_hint = False + + for name, arg_t_k, arg_t_ok in zip(arg_names, args_k, args_ok): + if arg_t_k == arg_t_ok: continue + this_arg_diff_size = 0 + if type(arg_t_k) == type(arg_t_ok) == core.ShapedArray: + s1, s2 = arg_type_to_str(arg_t_k), arg_type_to_str(arg_t_ok) + this_arg_diff_size += list_diff_size(arg_t_k.shape, arg_t_ok.shape) # type: ignore + + if arg_t_k.weak_type != arg_t_ok.weak_type: # type: ignore + s1 += f"{{weak_type={arg_t_k.weak_type}}}" # type: ignore + s2 += f"{{weak_type={arg_t_ok.weak_type}}}" # type: ignore + add_weak_type_hint = True + this_arg_diff_size += 1 + elif arg_t_k.sharding != arg_t_ok.sharding: # type: ignore + s1 = arg_t_k.str_short(short_dtypes=True, mesh_axis_types=True) # type: ignore + s2 = arg_t_ok.str_short(short_dtypes=True, mesh_axis_types=True) # type: ignore + this_arg_diff_size += 1 + else: + s1, s2 = str(arg_t_k), str(arg_t_ok) + diff_size += max(1, this_arg_diff_size) + diff.append(f" * at {name}, now {s1} and before {s2}") + + if add_weak_type_hint: + diff.append( + "where weak_type=True often means a Python builtin numeric value, and \n" + "weak_type=False means a jax.Array.\n" + "See https://docs.jax.dev/en/latest/type_promotion.html#weak-types.") + diffs.append(("\n".join(diff), diff_size)) + + if fun_transforms_k != fun_transforms_ok: + if len(fun_transforms_k) != len(fun_transforms_ok): + different_leaf_count = True # Skip other more precise checks + unavailable("fun_transforms length", + fun_transforms_k, fun_transforms_ok) + else: + for i, (t, ot) in enumerate(zip(fun_transforms_k, fun_transforms_ok)): + t_name = t[0].__name__ + if t == ot: continue + if t[0] != ot[0]: + unavailable(f"fun_transforms[{i}] transform", t, ot) + continue + + if t_name == "flatten_fun": + explain_in_tree_diff(t[1][0], ot[1][0]) + continue + if t_name == "_argnums_partial": + explain_transform_argnums_partial(t[1], ot[1]) + continue + if t_name == "_argnames_partial": + explain_transform_argnames_partial(t[1], ot[1]) + continue + unavailable(f"fun_transforms.{t_name} params", t[1:], ot[1:]) + continue + + # If we had different leaf counts, we can discard the _argnums_partial + # difference. That transform sometimes occurs before the flatten_fun + if different_leaf_count: + diffs = [d for d in diffs if "fun_transforms._argnums_partial" not in d[0]] + if fun_params_k != fun_params_ok: + unavailable("fun_params", fun_params_k, fun_params_ok) + if fun_in_type_k != fun_in_type_ok: + unavailable("fun_in_type", fun_params_k, fun_params_ok) + if arg_in_type_k != arg_in_type_ok and not different_leaf_count: + explain_args_type_diff(arg_in_type_k, arg_in_type_ok) + if arg_attr_data_k != arg_attr_data_ok: + unavailable("arg_attr_data", arg_attr_data_k, arg_attr_data_ok) + if arg_inline_k != arg_inline_ok: + unavailable("arg_inline", arg_inline_k, arg_inline_ok) + if ctx_k != ctx_ok: + assert len(ctx_k) == len(ctx_ok) + idxs = [f" [{i}]: now {c_k} and before {c_ok}" + for i, (c_k, c_ok) in enumerate(zip(ctx_k, ctx_ok)) if c_k != c_ok] + diffs.append( + ("different tracing context, e.g. 
due to config or context manager.\n" + "found differences at positions\n" + + ", and\n".join(idxs) + + "\ncompare to tuple returned by " + "config.trace_context() in jax/_src/config.py.", + len(idxs))) + if not diffs: # Should never happen, but let's not crash + unavailable("something (unexpected empty diffs)", k, oldk) + diffs_and_sizes = util.unzip2(sorted(diffs, key=lambda d: d[1])) + return (diffs_and_sizes[0], sum(diffs_and_sizes[1])) + def explain_tracing_cache_miss( - fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple): + fun: lu.WrappedFun, unseen_f: bool, cache: dict, + key: tuple, elapsed_sec: float): if config.check_tracer_leaks.value: return - - def unpack(key): - transforms, (), _, (in_type, _, inline), *_, ctx = key - # TODO(dougalm,mattjj): enable cache miss explanation with attrs - _, (_, (in_tree,)), *_ = transforms - return in_tree, in_type, inline.val, ctx - in_tree, in_type, inline, ctx = unpack(key) - if inline: return + if key[3][2].val: return # No explanations for "inline" functions debug_info = fun.debug_info + func_filename = debug_info.func_filename + if func_filename and not source_info_util.is_user_filename(func_filename): + return + msg: list[str] = [] p = msg.append - done = lambda: logger.log(logging.WARNING, '\n'.join(msg)) + done = lambda: logger.log(logging.WARNING, "\n".join(msg)) callsite = source_info_util.summarize(source_info_util.current()) - p(f"TRACING CACHE MISS at {callsite} because:") + p(f"TRACING CACHE MISS at {callsite} costing {elapsed_sec * 1e3:.3f} ms because:") # have we seen this function before at all? - fun_name = getattr(fun.f, '__qualname__', fun.f) - if debug_info.func_src_info: - # TODO(necula): clean up the extraction of the source info - _, *rest = debug_info.func_src_info.split(' at ') - src_info = " defined at " + ' '.join(rest) - else: - src_info = '' - if unseen_f: - p(f" never seen function:\n {fun_name} id={id(fun.f)}{src_info}") - if callsite in callsites: + src_info = "" + if func_filename: + src_info += f" defined at {func_filename}" + if func_lineno := debug_info.func_lineno: + src_info += f":{func_lineno}" + func_name = debug_info.func_name + if unseen_f or not cache: + p(f" never seen function:\n {func_name} id={id(fun.f)}{src_info}") + if callsite in callsites_with_tracing_cache_miss: p(" but seen another function defined on the same line; maybe the function is\n" " being re-defined repeatedly, preventing caching?") - callsites.add(callsite) - return done() - else: - p(f" for {fun_name}{src_info}") - - seen_keys = map(unpack, cache.keys()) - - # have we maybe switched some args to be kwargs or visa-versa? 
- args_tree, kwargs_tree = treedef_children(in_tree) - args_kwargs_trees = [treedef_children(k) for k, *_ in seen_keys] - args_kwargs_match = [t for t in args_kwargs_trees - if t == [args_tree, kwargs_tree]] - if not args_kwargs_match: - num_args = len(treedef_children(args_tree)) - _, kwarg_keys = kwargs_tree.node_data() # type: ignore - p(f" never seen passing {num_args} positional args and {len(kwarg_keys)} " - "keyword args with keys:\n" - f" {', '.join(map(repr, kwarg_keys))}") - dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees # type: ignore - if t != [args_tree, kwargs_tree]] - close_kwargs = min( - dont_match, key=set(kwarg_keys).symmetric_difference, default=None - ) - if not close_kwargs: - p(" closest seen is passing no keyword args") else: - p(f" closest seen passes {len(close_kwargs)} keyword args with keys:\n" - f" {', '.join(map(repr, close_kwargs))}") - return done() - - # have we never seen this tracing context before? - ctxs_match = [c for *_, c in seen_keys if c == ctx] - if not ctxs_match: - p(" tracing context doesn't match, e.g. due to config or context manager") - dont_match = [c for *_, c in seen_keys if c != ctx] - closest_ctx = min(dont_match, key=lambda c: sum(map(op.ne, c, ctx))) - idxs = [i for i, (c1, c2) in enumerate(zip(ctx, closest_ctx)) if c1 != c2] - p(" closest seen context tuple differs at positions:\n" - f" {', '.join(map(str, idxs))}\n" - " compare to tuple returned by config._trace_context() in jax/_src/config.py.") + callsites_with_tracing_cache_miss.add(callsite) return done() - # have we never seen this input pytree before? - trees_match = [k for k in seen_keys if k[0] == in_tree] - if not trees_match: - in_tree_str = f':\n {in_tree}' if len(str(in_tree)) < 76 else '' - p(f" never seen input pytree{in_tree_str}") - dont_match = [t for t, *_ in seen_keys if t != in_tree] - closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves)) - errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type] - p(f" closest seen input pytree has {len(errs)} mismatches, including:") - for path, thing1, thing2, explanation in errs: - fst, *path = path # type: ignore - base = ['args', 'kwargs'][fst.idx] - p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1}," - f" so {explanation}") - return done() + p(f" for {func_name}{src_info}") + + diffs = [diff_tracing_cache_keys(key, ok, debug_info) + for ok in cache.keys() if key != ok] + assert diffs, "we must find some diffs if key differs from all cache keys" + min_diff = min(diffs, key=lambda v: v[1]) + smallest_diffs: Sequence[Sequence[str]] # the diffs for the closest keys + smallest_diffs = [d[0] for d in diffs if d[1] == min_diff[1]] + def indent_subsequent_lines(indent: int, msg: str) -> str: + return msg.replace("\n", "\n" + " " * indent) + def p_one_diff(diff: Sequence[str]): + for d in diff: + p(" * key with " + indent_subsequent_lines(4, d)) + + if len(smallest_diffs) == 1: + p(" all previously seen cache keys are different. Closest previous key:") + p_one_diff(smallest_diffs[0]) + else: + p(" all previously seen cache keys are different. " + "Several previous keys are closest:") + for d in smallest_diffs: + p_one_diff(d) - # have we never seen these input types (eg shapes, dtypes) before? 
- types_match = [k for k in trees_match if k[1] == in_type] - if not types_match: - if len(in_type) < 5: - in_type_str = ':\n {}'.format(', '.join( - f'{n}: {ty.str_short(short_dtypes=True)}' - for n, ty in zip(debug_info.arg_names, in_type))) - else: - in_type_str = '' - p(f" never seen input type signature{in_type_str}") - dont_match = [t for _, t, *_ in trees_match if t != in_type] - closest_ty = min(dont_match, key=lambda t: sum(map(op.ne, t, in_type))) - num_mismatch = sum(map(op.ne, closest_ty, in_type)) - p(f" closest seen input type signature has {num_mismatch} mismatches, including:") - add_weak_type_hint = False - arg_names = debug_info.safe_arg_names(len(in_type)) - - for name, ty1, ty2 in zip(arg_names, closest_ty, in_type): - if ty1 != ty2: - if type(ty1) == type(ty2) == core.ShapedArray: - s1, s2 = ty1.str_short(True), ty2.str_short(True) - if ty1.weak_type != ty2.weak_type: - s1 += f'{{weak_type={ty1.weak_type}}}' - s2 += f'{{weak_type={ty2.weak_type}}}' - add_weak_type_hint = True - elif ty1.sharding != ty2.sharding: - s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True) - s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True) - else: - s1, s2 = str(ty1), str(ty2) - p(f" * at {name}, seen {s1}, but now given {s2}") - if add_weak_type_hint: - p('where weak_type=True often means a Python builtin numeric value, and ') - p('weak_type=False means a jax.Array.') - p('See https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types') - return done() + done() + return - # we think this is unreachable... - p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax") - return done() @partial(lu.cache, explain=explain_tracing_cache_miss) def _create_pjit_jaxpr( @@ -1272,7 +1424,7 @@ def _create_pjit_jaxpr( attr_data: int, ignored_inline: IgnoreKey ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, Any]]]]: util.test_event("create_pjit_jaxpr") del ignored_inline # just for explain_cache_miss if config.no_tracing.value: @@ -1348,32 +1500,31 @@ def seen_attrs_get( assert fun.in_type is None or fun.in_type == in_type return cache[(fun.transforms, fun.params, in_type)] -def _attr_token( +def _attr_cache_index( fun: lu.WrappedFun, in_type: core.InputType | tuple[core.AbstractValue, ...] 
) -> int: - from jax.experimental.attrs import jax_getattr + from jax.experimental.attrs import dne_sentinel cases = seen_attrs_get(fun, in_type) for i, records in enumerate(cases): - for obj, attr, treedef, avals in records: - val = jax_getattr(obj, attr) - vals, treedef_ = tree_flatten(val) - avals_ = map(core.shaped_abstractify, vals) - if treedef != treedef_ or avals != avals_: break + for obj, attr, kind, treedef, avals in records: + if kind is pe.ReadWrite: + val = getattr(obj, attr, dne_sentinel) + vals, treedef_ = tree_flatten(val) + avals_ = map(core.shaped_abstractify, vals) + if treedef != treedef_ or avals != avals_: break else: return i return len(cases) -def _attr_update(fun, in_type, i, attrs_tracked): - from jax.experimental.attrs import jax_getattr - leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr)) - records = [(obj, attr, init_tree, map(core.shaped_abstractify, leaves(obj, attr))) - for init_tree, _, (obj, attr) in attrs_tracked] +def _attr_cachedata_update(fun, in_type, i, attrs_tracked): + from jax.experimental.attrs import dne_sentinel + leaves = lambda obj, attr: tree_leaves(getattr(obj, attr, dne_sentinel)) + records = [(obj, attr, kind, init_tree, map(core.typeof, leaves(obj, attr))) + for init_tree, _, (obj, attr, kind) in attrs_tracked] cases = seen_attrs_get(fun, in_type) if i == len(cases): cases.append(records) - else: - assert i < len(cases) and cases[i] == records @dataclasses.dataclass(frozen=True) @@ -1538,6 +1689,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) resolved_in_shardings: list[PjitSharding] = [] + assert len(args) == len(pjit_in_shardings) for arg, pjit_in_s in zip(args, pjit_in_shardings): # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does # not allow None as the sharding. @@ -1571,14 +1723,12 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] 'Passing non-trivial shardings for numpy ' 'inputs is not allowed. To fix this error, either specify a ' 'replicated sharding explicitly or use ' - '`jax.experimental.multihost_utils.host_local_array_to_global_array(...)` ' + '`jax.make_array_from_process_local_data(...)` ' 'to convert your host local numpy inputs to a jax.Array which you ' - 'can pass to pjit. ' + 'can pass to jit. ' 'If the numpy input is the same on each process, then you can use ' '`jax.make_array_from_callback(...) to create a `jax.Array` which ' - 'you can pass to pjit. ' - 'Please see the jax.Array migration guide for more information ' - 'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. ' + 'you can pass to jit. 
' f'Got arg shape: {arg.shape}, arg value: {arg}') if not isinstance(arg_s, UnspecifiedValue) and arg_s._is_concrete: # jax.jit does not allow resharding across different memory kinds even @@ -1777,11 +1927,12 @@ def pjit_staging_rule(trace, *args, **params): return pe.inline_jaxpr_into_trace( trace, jaxpr.jaxpr, jaxpr.consts, *args) - jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( - params['jaxpr'], params['out_shardings'], params['out_layouts']) - params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, - out_layouts=out_layouts) + jaxpr = params['jaxpr'] if config.dynamic_shapes.value: + jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( + jaxpr, params['out_shardings'], params['out_layouts']) + params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, + out_layouts=out_layouts) source_info = source_info_util.current() out_tracers = [] for aval in _out_type(jaxpr): @@ -1795,6 +1946,10 @@ def pjit_staging_rule(trace, *args, **params): map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params, jaxpr.effects, source_info) trace.frame.add_eqn(eqn) + out_tracers_ = iter(out_tracers) + out_tracers = [args[f] if type(f) is int else next(out_tracers_) + for f in in_fwd] + assert next(out_tracers_, None) is None elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) consts = map(trace.new_const, consts) @@ -1807,19 +1962,14 @@ def pjit_staging_rule(trace, *args, **params): pjit_p, (*args, *consts), new_params) else: out_tracers = trace.default_process_primitive(pjit_p, args, params) - - out_tracers_ = iter(out_tracers) - out_tracers = [args[f] if type(f) is int else next(out_tracers_) - for f in in_fwd] - assert next(out_tracers_, None) is None return out_tracers pe.custom_staging_rules[pjit_p] = pjit_staging_rule def _pjit_forwarding(jaxpr, out_shardings, out_layouts): in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr) - in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol - in zip(in_fwd, out_shardings, out_layouts)] + in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None + for fwd, os, ol in zip(in_fwd, out_shardings, out_layouts)] keep = [f is None for f in in_fwd] jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) out_shardings = [o for o, k in zip(out_shardings, keep) if k] @@ -1827,6 +1977,8 @@ def _pjit_forwarding(jaxpr, out_shardings, out_layouts): return jaxpr, in_fwd, out_shardings, out_layouts def pjit_forwarding_rule(eqn): + if not config.dynamic_shapes.value: + return [None] * len(eqn.outvars), eqn jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( eqn.params['jaxpr'], eqn.params['out_shardings'], eqn.params['out_layouts']) new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None] @@ -1835,6 +1987,7 @@ def pjit_forwarding_rule(eqn): new_eqn = eqn.replace(params=new_params, outvars=new_outvars) fwd_vars = [eqn.invars[f] if f is not None else None for f in in_fwd] return fwd_vars, new_eqn +# TODO(mattjj): Remove pjit_forwarding_rule and also in staging rule. pe.forwarding_rules[pjit_p] = pjit_forwarding_rule @@ -2062,14 +2215,47 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) - # constvars will become residuals. Move them to the end of the ordinary args. 
res_shardings = (UNSPECIFIED,) * num_residuals res_layouts = (None,) * num_residuals res_donated = (False,) * num_residuals + primal_out_shardings = res_shardings + tuple(out_shardings) + primal_out_layouts = res_layouts + tuple(out_layouts) + + def keep_where(l, should_keep): + return tuple(x for x, keep in zip(l, should_keep) if keep) + + # Input-to-output forwarding. + in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr) + in_fwd_res, in_fwd_primal = split_list(in_fwd, [num_residuals]) + in_fwd = in_fwd_res + [ + fwd if isinstance(os, UnspecifiedValue) and ol is None else None + for os, ol, fwd in zip(out_shardings, out_layouts, in_fwd_primal) + ] + del in_fwd_res, in_fwd_primal + keep = [f is None for f in in_fwd] + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) + primal_out_shardings = keep_where(primal_out_shardings, keep) + primal_out_layouts = keep_where(primal_out_layouts, keep) + kept_res, _ = split_list(keep, [num_residuals]) + num_kept_residuals = sum(kept_res) + del keep, kept_res + + # Output-to-output forwarding. + num_out_primals = len(primal_jaxpr.jaxpr.outvars) - num_kept_residuals + res_vars, out_vars = split_list(primal_jaxpr.jaxpr.outvars, [num_kept_residuals]) + idx_map = {id(v): i for i, v in enumerate(out_vars)} + offset = sum(id(v) not in idx_map for v in res_vars) + idx_map = {k: v + offset for k, v in idx_map.items()} + out_fwd = [idx_map.get(id(v)) for v in res_vars] + [None] * num_out_primals + keep = [f is None for f in out_fwd] + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) + primal_out_shardings = keep_where(primal_out_shardings, keep) + primal_out_layouts = keep_where(primal_out_layouts, keep) + del keep + def tangent_fun(consts_, *tangents): tangents_nz = _filter_zeros(nzs, tangents) - assert len(consts_) == num_residuals - nz_tangents_out = pjit_p.bind(*(*tangents_nz, *consts_), + nz_tangents_out = pjit_p.bind(*tangents_nz, *consts_, jaxpr=tangent_jaxpr, in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings, out_shardings=_filter_zeros(nzs_out, out_shardings), @@ -2092,15 +2278,17 @@ def _filter_zeros(is_nz_l, l): ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr, in_shardings=in_shardings, - out_shardings=(*res_shardings, *out_shardings), + out_shardings=primal_out_shardings, in_layouts=in_layouts, - out_layouts=(*res_layouts, *out_layouts), + out_layouts=primal_out_layouts, donated_invars=donated_invars, ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) + ans = subs_list(out_fwd, ans, ans) + ans = subs_list(in_fwd, primals_in, ans) residuals_ans, primal_ans = split_list(ans, [num_residuals]) return primal_ans, nzs_out, residuals_ans, tangent_fun @@ -2334,11 +2522,12 @@ def prune_type(ty, xs, maybe_zeros): if attrs_tracked: init_states = _get_states(attrs_tracked) + num_attr_outs = sum(final_tree.num_leaves for _, final_tree, _ in attrs_tracked) primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in] - transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings - transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings - transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts - transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts + transpose_in_shardings = (UNSPECIFIED,) * len(init_states) + transpose_in_shardings + transpose_out_shardings = (UNSPECIFIED,) * num_attr_outs + transpose_out_shardings + transpose_in_layouts = (None,) * 
len(init_states) + transpose_in_layouts + transpose_out_layouts = (None,) * num_attr_outs + transpose_out_layouts try: nz_cts_out = pjit_p.bind( @@ -2367,7 +2556,7 @@ def prune_type(ty, xs, maybe_zeros): dispatch._raise_no_nan_in_deoptimized(e) if attrs_tracked: - final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) + final_states, nz_cts_out = split_list(nz_cts_out, [num_attr_outs]) _set_states(attrs_tracked, final_states) return tree_unflatten(cts_out_treedef, nz_cts_out) @@ -2516,7 +2705,7 @@ def with_sharding_constraint(x, shardings): Returns: x_with_shardings: PyTree of jax.Arrays with specified sharding constraints. - .. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html + .. _Distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html """ x_flat, tree = tree_flatten(x) @@ -2821,29 +3010,50 @@ def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding): # -------------------- auto and user mode ------------------------- def _get_new_mesh(axes: str | tuple[str, ...] | None, - axis_type: mesh_lib.AxisType, name: str, - error_on_manual_to_auto_explict=False): + axis_type: mesh_lib.AxisType, name: str, shardings=None, + error_on_manual_to_auto_explicit=False): cur_mesh = mesh_lib.get_abstract_mesh() - # TODO(yashkatariya): Maybe allow fetching mesh from the args to enable - # computation follows data? - if cur_mesh.empty: + flat_shardings, _ = tree_flatten(shardings) + sharding_mesh = mesh_lib.empty_abstract_mesh + for i in flat_shardings: + if isinstance(i, NamedSharding): + if not sharding_mesh.empty and sharding_mesh != i.mesh: + raise ValueError( + f'Shardings passed to {name} should have the same mesh. Got one' + f' mesh {sharding_mesh} and another {i.mesh}') + sharding_mesh = i.mesh.abstract_mesh + + if sharding_mesh.empty and cur_mesh.empty: raise ValueError( f'Context mesh {cur_mesh} cannot be empty. Please use' ' `jax.sharding.use_mesh` API to enter into a mesh context when using' f' `{name}` API.') + if not sharding_mesh.empty and not cur_mesh.empty: + if sharding_mesh != cur_mesh: + raise ValueError( + f'Context mesh {cur_mesh} must match the mesh passed to shardings' + f' {sharding_mesh}. Recommended approach is to use' + ' `jax.sharding.use_mesh` context manager.') + mesh_to_use = cur_mesh + elif sharding_mesh.empty and not cur_mesh.empty: + mesh_to_use = cur_mesh + else: + assert not sharding_mesh.empty and cur_mesh.empty + mesh_to_use = sharding_mesh + if axes is None: - axes = cur_mesh.axis_names + axes = mesh_to_use.axis_names if not isinstance(axes, tuple): axes = (axes,) for a in axes: - if (error_on_manual_to_auto_explict and - cur_mesh._name_to_type[a] == mesh_lib.AxisType.Manual and + if (error_on_manual_to_auto_explicit and + mesh_to_use._name_to_type[a] == mesh_lib.AxisType.Manual and axis_type in {mesh_lib.AxisType.Auto, mesh_lib.AxisType.Explicit}): raise NotImplementedError( 'Going from `Manual` AxisType to `Auto` or `Explicit` AxisType is not' ' allowed. Please file a bug at https://github.com/jax-ml/jax/issues' ' with your use case') - return cur_mesh.update_axis_types({a: axis_type for a in axes}) + return mesh_to_use.update_axis_types({a: axis_type for a in axes}) def auto_axes(fun, *, axes: str | tuple[str, ...] 
| None = None, out_shardings=None): @@ -2855,8 +3065,9 @@ def decorator(*args, **kwargs): raise TypeError("Missing required keyword argument: 'out_shardings'") else: _out_shardings = out_shardings - new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'auto_axes', - error_on_manual_to_auto_explict=True) + new_mesh = _get_new_mesh( + axes, mesh_lib.AxisType.Auto, 'auto_axes', shardings=_out_shardings, + error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual( core.get_aval(a).sharding.spec, new_mesh), args) @@ -2883,7 +3094,7 @@ def decorator(*args, **kwargs): else: _in_shardings = in_shardings new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'explicit_axes', - error_on_manual_to_auto_explict=True) + error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): args = mesh_cast(args, _in_shardings) out = fun(*args, **kwargs) @@ -2899,47 +3110,53 @@ def use_explicit_axes(*axes): with mesh_lib.use_abstract_mesh(new_mesh): yield -# -------------------- helpers -------------------- - -def get_unconstrained_dims(sharding: NamedSharding): - assert sharding.spec is not None - return {i for i, axes in enumerate(sharding.spec) - if axes is PartitionSpec.UNCONSTRAINED} - - -def get_op_sharding_from_executable( - executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: - in_op_shardings: list[xc.OpSharding] = [] - parameter_shardings_from_xla = executable.get_parameter_shardings() - if parameter_shardings_from_xla is not None: - in_op_shardings = parameter_shardings_from_xla +# -------------------- with_dll_constraint -------------------- - out_op_shardings: list[xc.OpSharding] = [] - output_shardings_from_xla = executable.get_output_shardings() - if output_shardings_from_xla is not None: - out_op_shardings = output_shardings_from_xla +def with_dll_constraint(x, layouts): + x_flat, tree = tree_flatten(x) + layouts_flat = tuple(flatten_axes("with_dll_constraint layouts", tree, + layouts)) + if any(not isinstance(l, DeviceLocalLayout) for l in layouts_flat): + raise ValueError( + 'layouts passed to `with_dll_constraint` must be of type' + f' `DeviceLocalLayout`. Got {[type(l) for l in layouts_flat]}') + check_aval_layout_compatibility( + layouts_flat, x_flat, ("",) * len(layouts_flat), + "with_dll_constraint arguments") + outs = [dll_constraint_p.bind(xf, layout=l) + for xf, l in zip(x_flat, layouts_flat)] + return tree_unflatten(tree, outs) - return in_op_shardings, out_op_shardings +dll_constraint_p = core.Primitive('dll_constraint') +dll_constraint_p.def_abstract_eval(lambda x, **_: x) +ad.deflinear2(dll_constraint_p, + lambda ct, _, **params: (dll_constraint_p.bind(ct, **params),)) +def _dll_constraint_impl(x, *, layout): + if not isinstance(x, xc.ArrayImpl): + raise ValueError( + 'with_dll_constraint in eager mode can only be applied to' + f' jax.Arrays. 
Got {type(x)}') + if x.layout.device_local_layout == layout: # type: ignore + return x + return api.jit(_identity_fn, out_shardings=Layout(layout, x.sharding))(x) +dll_constraint_p.def_impl(_dll_constraint_impl) -def _get_ppspec_from_executable( - executable, mesh - ) -> tuple[Sequence[PartitionSpec], Sequence[PartitionSpec]]: - input_op_shardings, output_op_sharding = get_op_sharding_from_executable( - executable - ) - in_pspec: list[PartitionSpec] = [] - for s in input_op_shardings: - in_pspec.extend(parse_flatten_op_sharding(s, mesh)) +def _dll_constraint_hlo_lowering(ctx, x_node, *, layout): + aval, = ctx.avals_in + out_aval, = ctx.avals_out + return [mlir.wrap_with_layout_op(ctx, x_node, out_aval, layout, aval)] +mlir.register_lowering(dll_constraint_p, + _dll_constraint_hlo_lowering) - out_pspec: list[PartitionSpec] = [] - for s in output_op_sharding: - out_pspec.extend(parse_flatten_op_sharding(s, mesh)) - return in_pspec, out_pspec +def _dll_constraint_batcher(axis_data, vals_in, dims_in, layout): + raise NotImplementedError +batching.fancy_primitive_batchers[dll_constraint_p] = _dll_constraint_batcher +batching.skippable_batchers[dll_constraint_p] = lambda _: () +# -------------------- helpers -------------------- -def get_pspec_from_executable( - executable, mesh: pxla.Mesh -) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: - in_pspec, out_pspec = _get_ppspec_from_executable(executable, mesh) - return tuple(in_pspec), tuple(out_pspec) +def get_unconstrained_dims(sharding: NamedSharding): + assert sharding.spec is not None + return {i for i, axes in enumerate(sharding.spec) + if axes is PartitionSpec.UNCONSTRAINED} diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index e8fdff497445..d02b6d9962e0 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -201,26 +201,20 @@ def __init__(self, child: Doc, *, foreground: Color | None = None, # non-recursive formulation using an explicit stack, necessary because Python # doesn't have a tail recursion optimization. -def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]] - ) -> bool: +def _fits(doc: Doc, width: int) -> bool: + agenda = [doc] while width >= 0 and len(agenda) > 0: - i, m, doc = agenda.pop() + doc = agenda.pop() if isinstance(doc, _NilDoc): pass elif isinstance(doc, _TextDoc): width -= len(doc.text) elif isinstance(doc, _ConcatDoc): - agenda.extend((i, m, d) for d in reversed(doc.children)) + agenda.extend(reversed(doc.children)) elif isinstance(doc, _BreakDoc): - if m == _BreakMode.BREAK: - return True width -= len(doc.text) - elif isinstance(doc, _NestDoc): - agenda.append((i + doc.n, m, doc.child)) - elif isinstance(doc, _GroupDoc): - agenda.append((i, _BreakMode.FLAT, doc.child)) - elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): - agenda.append((i, m, doc.child)) + elif isinstance(doc, (_NestDoc, _GroupDoc, _ColorDoc, _SourceMapDoc)): + agenda.append(doc.child) else: raise ValueError("Invalid document ", doc) @@ -372,8 +366,7 @@ def _format( elif isinstance(doc, _GroupDoc): # In Lindig's paper, _fits is passed the remainder of the document. # I'm pretty sure that's a bug and we care only if the current group fits! 
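The pretty-printer change just below simplifies `_fits` so that it only measures the flat (single-line) width of a group, dropping the indentation and break-mode bookkeeping that the old signature carried. As a standalone illustration of that width check, here is a toy sketch; the `Text`, `Break`, and `Concat` classes are hypothetical stand-ins, not JAX's internal `Doc` types.

```python
from dataclasses import dataclass

# Hypothetical toy document types, for illustration only.
@dataclass
class Text:
    s: str

@dataclass
class Break:
    s: str = " "      # what a soft break prints when its group is laid out flat

@dataclass
class Concat:
    children: list

def fits(doc, width: int) -> bool:
    """True if `doc`, printed flat (no breaks taken), fits within `width` columns."""
    agenda = [doc]
    while width >= 0 and agenda:
        d = agenda.pop()
        if isinstance(d, (Text, Break)):
            width -= len(d.s)
        elif isinstance(d, Concat):
            agenda.extend(reversed(d.children))
    return width >= 0

doc = Concat([Text("f("), Text("x"), Break(", "), Text("y"), Text(")")])
print(fits(doc, 10))  # True: "f(x, y)" is 7 characters wide
print(fits(doc, 5))   # False
```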
- if (_sparse(doc) - and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])): + if (_sparse(doc) and _fits(doc, width - k)): agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) else: agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 2fa9b2b37aa4..dd91097fcf98 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -31,6 +31,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import ffi from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src import tree_util as tree_util_internal @@ -50,7 +51,8 @@ from jax._src.numpy.array_methods import ( _array_operators, _set_array_base_attributes, _IndexUpdateHelper) from jax._src.sharding_impls import ( - NamedSharding, PmapSharding, physical_sharding, logical_sharding) + NamedSharding, PmapSharding, SingleDeviceSharding, physical_sharding, + logical_sharding) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip @@ -64,6 +66,13 @@ UINT_DTYPES = { 8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} +if hasattr(gpu_prng, "registrations"): + for platform, targets in gpu_prng.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + # -- PRNG implementation interface class PRNGImpl(NamedTuple): @@ -105,7 +114,7 @@ def pprint(self): ])))) -prngs = {} +prngs: dict[str, PRNGImpl] = {} def register_prng(impl: PRNGImpl): if impl.name in prngs: @@ -148,7 +157,7 @@ class behave like an array whose base elements are keys, hiding the # device_buffer, device_buffers, __cuda_interface__() _impl: PRNGImpl - _base_array: typing.Array + _base_array: jax.Array _consumed: bool | np.ndarray # Used in jax.experimental.key_reuse. _source_info: None | source_info_util.SourceInfo = None @@ -156,8 +165,13 @@ def __init__(self, impl, key_data: Any): assert not isinstance(key_data, core.Tracer) _check_prng_key_data(impl, key_data) self._impl = impl - self._base_array = key_data self._consumed = False # TODO(jakevdp): default to True here? 
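Alongside the eager device-put of NumPy key data in `PRNGKeyArray.__init__`, this diff adds an `nbytes` property (defined as `itemsize * size` a little further down). A small check of what that reports; the printed values assume the default `threefry2x32` implementation, whose keys are backed by two `uint32` words.

```python
import jax

k = jax.random.key(0)          # new-style scalar key array
print(k.shape, k.size)         # () 1
print(k.itemsize, k.nbytes)    # 8 8  (2 x uint32 per key, assuming threefry2x32)

kk = jax.random.split(k, 4)
print(kk.shape, kk.nbytes)     # (4,) 32
```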
+ if isinstance(key_data, np.ndarray): + aval = core.get_aval(key_data) + device = pxla.get_default_device() + key_data = pxla.batched_device_put(aval, SingleDeviceSharding(device), + [key_data], [device], committed=False) + self._base_array = key_data def block_until_ready(self): _ = self._base_array.block_until_ready() @@ -168,9 +182,8 @@ def copy_to_host_async(self): @property def aval(self): - logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding') - else None) - return keys_shaped_array(self._impl, self.shape, logical_sharding) + vma = self._base_array.aval.vma + return keys_shaped_array(self._impl, self.shape, self.sharding, vma) @property def shape(self): @@ -188,6 +201,10 @@ def ndim(self): def dtype(self): return KeyTy(self._impl) + @property + def nbytes(self): + return self.itemsize * self.size + @property def itemsize(self): return self.dtype.itemsize @@ -321,8 +338,8 @@ def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray return random_seed(seed, impl=impl) -def keys_shaped_array(impl, shape, sharding): - aval = core.ShapedArray(shape, KeyTy(impl)) +def keys_shaped_array(impl, shape, sharding, vma): + aval = core.ShapedArray(shape, KeyTy(impl), vma=vma) return core.update_aval_with_sharding(aval, sharding) def base_arr_shape_to_keys_shape(impl, base_arr_shape): @@ -415,7 +432,6 @@ def device_put_sharded(vals, aval, sharding, devices): @staticmethod def device_put_replicated(val, aval, sharding, devices): physical_aval = core.physical_aval(aval) - assert len(xla.aval_to_xla_shapes(physical_aval)) == 1 physical_buf = random_unwrap(val) phys_sharding = physical_sharding(aval, sharding) physical_result = pxla.batched_device_put( @@ -542,7 +558,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray: @random_seed_p.def_abstract_eval def random_seed_abstract_eval(seeds_aval, *, impl): - return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding) + return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, + seeds_aval.vma) @random_seed_p.def_impl def random_seed_impl(seeds, *, impl): @@ -577,7 +594,7 @@ def random_split_abstract_eval(keys_aval, *, shape): # don't choose None here? 
new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape), - keys_aval.sharding.with_spec(new_spec)) + keys_aval.sharding.with_spec(new_spec), keys_aval.vma) @random_split_p.def_impl def random_split_impl(keys, *, shape): @@ -603,7 +620,9 @@ def random_split_lowering(ctx, keys, *, shape): def random_fold_in(keys, msgs): - return random_fold_in_p.bind(keys, jnp.asarray(msgs)) + msgs = jnp.asarray(msgs) + keys, msgs = core.standard_insert_pvary(keys, msgs) + return random_fold_in_p.bind(keys, msgs) random_fold_in_p = core.Primitive('random_fold_in') ad.defjvp_zero(random_fold_in_p) @@ -613,7 +632,10 @@ def random_fold_in(keys, msgs): def random_fold_in_abstract_eval(keys_aval, msgs_aval): shape = lax_internal.broadcasting_shape_rule( 'random_fold_in', keys_aval, msgs_aval) - return core.ShapedArray(shape, keys_aval.dtype) + sharding = lax_internal.broadcasting_sharding_rule( + 'random_fold_in', keys_aval, msgs_aval) + vma = core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval) + return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding, vma=vma) @random_fold_in_p.def_impl def random_fold_in_impl(keys, msgs): @@ -651,7 +673,7 @@ def random_bits(keys, bit_width, shape): def random_bits_abstract_eval(keys_aval, *, bit_width, shape): out_shape = (*keys_aval.shape, *shape) out_dtype = dtypes.dtype(f'uint{bit_width}') - return core.ShapedArray(out_shape, out_dtype) + return core.ShapedArray(out_shape, out_dtype, vma=keys_aval.vma) @random_bits_p.def_impl def random_bits_impl(keys, *, bit_width, shape): @@ -708,7 +730,7 @@ def random_wrap(base_arr, *, impl): def random_wrap_abstract_eval(base_arr_aval, *, impl): shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape) sharding = logical_sharding(shape, KeyTy(impl), base_arr_aval.sharding) - return keys_shaped_array(impl, shape, sharding) + return keys_shaped_array(impl, shape, sharding, base_arr_aval.vma) @random_wrap_p.def_impl def random_wrap_impl(base_arr, *, impl): @@ -902,7 +924,7 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): multiple_results=True) -def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2): +def _threefry2x32_gpu_lowering_rule(ctx, k1, k2, x1, x2, *, target_name_prefix): if not config.threefry_gpu_kernel_lowering.value: # back to default lowering return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2) @@ -917,23 +939,11 @@ def _broadcast(x, aval): return mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=range(rank - len(aval.shape), rank)) - out_len = reduce(op.mul, aval_out.shape, 1) - if not core.is_constant_dim(out_len): - length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len]) - length = mlir.hlo.convert( - ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)), - length) - output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape) - else: - length = int(out_len) # will be passed statically - output_shape = None - - return lowering_func( - (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), - (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length, - output_shape, - False, # forward_compatibility_mode - ) + sub_ctx = ctx.replace(avals_in=(aval_out,) * 4) + rule = ffi.ffi_lowering( + f"{target_name_prefix}_threefry2x32_ffi") + return rule(sub_ctx, _broadcast(k1, k1_aval), _broadcast(k2, k2_aval), + _broadcast(x1, x1_aval), _broadcast(x2, x2_aval)) threefry2x32_p = core.Primitive("threefry2x32") @@ -947,11 +957,11 @@ def _broadcast(x, aval): 
threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='hip'), platform='rocm') diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index f06933f57e22..912c90182977 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -33,6 +33,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version _profiler_server: xla_client.profiler.ProfilerServer | None = None @@ -271,7 +272,6 @@ class TraceAnnotation(xla_client.profiler.TraceMe): This will cause a "my_label" event to show up on the trace timeline if the event occurs while the process is being traced. """ - pass class StepTraceAnnotation(TraceAnnotation): @@ -332,7 +332,6 @@ def annotate_function(func: Callable, name: str | None = None, def wrapper(*args, **kwargs): with TraceAnnotation(name, **decorator_kwargs): return func(*args, **kwargs) - return wrapper return wrapper @@ -426,6 +425,10 @@ def trace(cls, runner: PGLEProfiler | None): else: options = xla_client.profiler.ProfileOptions() options.enable_hlo_proto = True + + # ToDo(patrios): Remove when jaxlib version is updated to 0.5.4. + if jaxlib_version > (0, 5, 3): + options.raise_error_on_start_failure = True runner.current_session = xla_client.profiler.ProfilerSession(options) try: diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 455a3b98cce2..3b1e24bc9c50 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -14,6 +14,7 @@ from functools import partial import operator +from typing import Any, TypeAlias from jax._src import api from jax._src import config @@ -32,7 +33,7 @@ EPS = 1e-4 -def _dtype(x): +def _dtype(x: Any) -> np.dtype: if hasattr(x, 'dtype'): return x.dtype elif type(x) in _dtypes.python_scalar_dtypes: @@ -40,20 +41,27 @@ def _dtype(x): else: return np.asarray(x).dtype +ToleranceDict: TypeAlias = dict[np.dtype, int | float] -_default_tolerance = { +_default_tolerance: ToleranceDict = { _dtypes.float0: 0, np.dtype(np.bool_): 0, + np.dtype(_dtypes.int2): 0, np.dtype(_dtypes.int4): 0, np.dtype(np.int8): 0, np.dtype(np.int16): 0, np.dtype(np.int32): 0, np.dtype(np.int64): 0, + np.dtype(_dtypes.uint2): 0, np.dtype(_dtypes.uint4): 0, np.dtype(np.uint8): 0, np.dtype(np.uint16): 0, np.dtype(np.uint32): 0, np.dtype(np.uint64): 0, + np.dtype(_dtypes.float4_e2m1fn): 1e0, + np.dtype(_dtypes.float8_e3m4): 1e-1, + np.dtype(_dtypes.float8_e4m3): 1e-1, + np.dtype(_dtypes.float8_e8m0fnu): 1e0, np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -67,16 +75,15 @@ def _dtype(x): np.dtype(np.complex128): 1e-15, } -if _dtypes.int2 is not None: - assert _dtypes.uint2 is not None - _default_tolerance[np.dtype(_dtypes.int2)] = 0 - _default_tolerance[np.dtype(_dtypes.uint2)] = 0 - def default_tolerance(): return _default_tolerance -default_gradient_tolerance = { +default_gradient_tolerance: ToleranceDict = { + np.dtype(_dtypes.float4_e2m1fn): 1e0, + np.dtype(_dtypes.float8_e3m4): 1e-1, + np.dtype(_dtypes.float8_e4m3): 1e-1, + np.dtype(_dtypes.float8_e8m0fnu): 1e0, np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, 
np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -90,21 +97,8 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } -# TODO: make this unconditional when ml_dtypes>=0.5.0 is required -if _dtypes.float8_e3m4 is not None: - _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 - default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 -if _dtypes.float8_e4m3 is not None: - _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 - default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 -if _dtypes.float8_e8m0fnu is not None: - _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 - default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 -if _dtypes.float4_e2m1fn is not None: - _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 - default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 - -def is_python_scalar(val): + +def is_python_scalar(val: Any) -> bool: return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): @@ -113,6 +107,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): return custom_float_dtypes = [ + _dtypes.float4_e2m1fn, + _dtypes.float8_e8m0fnu, + _dtypes.float8_e3m4, + _dtypes.float8_e4m3, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, @@ -121,15 +119,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.bfloat16, ] - if _dtypes.float8_e4m3 is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e4m3) - if _dtypes.float8_e3m4 is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e3m4) - if _dtypes.float8_e8m0fnu is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu) - if _dtypes.float4_e2m1fn is not None: - custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn) - def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) @@ -151,7 +140,8 @@ def maybe_upcast(x): # value errors. It should not do that. np.testing.assert_allclose(a, b, **kw, err_msg=err_msg) -def tolerance(dtype, tol=None): + +def tolerance(dtype: np.dtype, tol: int | float | ToleranceDict | None = None) -> int | float: tol = {} if tol is None else tol if not isinstance(tol, dict): return tol diff --git a/jax/_src/random.py b/jax/_src/random.py index 094268c65825..b29c1dca7b08 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -41,6 +41,8 @@ from jax._src.lax import lax as lax_internal from jax._src.numpy.lax_numpy import _convert_and_clip_integer from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact +from jax._src.pjit import auto_axes +from jax._src.sharding_impls import canonicalize_sharding from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import canonicalize_axis @@ -223,7 +225,7 @@ def PRNGKey(seed: int | ArrayLike, *, This function produces old-style legacy PRNG keys, which are arrays of dtype ``uint32``. For more, see the note in the `PRNG keys - `_ + `_ section. When possible, :func:`jax.random.key` is recommended for use instead. 
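The docstring link fix above sits in the note that distinguishes legacy `uint32` PRNG keys from new-style typed key arrays. For reference, the two constructors and their observable difference:

```python
import jax

legacy = jax.random.PRNGKey(0)   # legacy key: a plain uint32 array of shape (2,)
new = jax.random.key(0)          # new-style key: shape-() array with an opaque key dtype
print(legacy.dtype, legacy.shape)                                         # uint32 (2,)
print(jax.dtypes.issubdtype(new.dtype, jax.dtypes.prng_key), new.shape)   # True ()

# Both forms are accepted by the samplers:
print(jax.random.normal(new, (3,)))
print(jax.random.normal(legacy, (3,)))
```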
@@ -346,9 +348,19 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None: raise ValueError(msg.format(name, shape_, shape)) +def maybe_auto_axes(f, out_shardings, **hoist_kwargs): + f_ = partial(f, **hoist_kwargs) + if out_shardings is None: + return f_ + else: + return auto_axes(f_, out_shardings=out_shardings) + + def bits(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeUInt | None = None) -> Array: + dtype: DTypeLikeUInt | None = None, + *, + out_sharding=None) -> Array: """Sample uniform bits in the form of unsigned integers. Args: @@ -371,15 +383,19 @@ def bits(key: ArrayLike, f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "bits") bit_width = dtype.itemsize * 8 - return _random_bits(key, bit_width, shape) + return maybe_auto_axes(_random_bits, out_sharding, + bit_width=bit_width, shape=shape)(key) def uniform(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0., - maxval: RealArray = 1.) -> Array: + maxval: RealArray = 1., + *, + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -397,15 +413,17 @@ def uniform(key: ArrayLike, key, _ = _check_prng_key("uniform", key) dtypes.check_user_dtype_supported(dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "uniform") if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _uniform(key, shape, dtype, minval, maxval) + return maybe_auto_axes(_uniform, out_sharding, + shape=shape,dtype=dtype)(key, minval, maxval) -@partial(jit, static_argnums=(1, 2)) -def _uniform(key, shape, dtype, minval, maxval) -> Array: +@partial(jit, static_argnums=(3, 4)) +def _uniform(key, minval, maxval, shape, dtype) -> Array: _check_shape("uniform", shape) if not jnp.issubdtype(dtype, np.floating): raise TypeError("uniform only accepts floating point dtypes.") @@ -449,7 +467,9 @@ def randint(key: ArrayLike, shape: Shape, minval: IntegerArray, maxval: IntegerArray, - dtype: DTypeLikeInt = int) -> Array: + dtype: DTypeLikeInt = int, + *, + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -469,10 +489,12 @@ def randint(key: ArrayLike, dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) - return _randint(key, shape, minval, maxval, dtype) + out_sharding = canonicalize_sharding(out_sharding, "randint") + return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype)( + key, minval, maxval) -@partial(jit, static_argnums=(1, 4)) -def _randint(key, shape, minval, maxval, dtype) -> Array: +@partial(jit, static_argnums=(3, 4)) +def _randint(key, minval, maxval, shape, dtype) -> Array: _check_shape("randint", shape, np.shape(minval), np.shape(maxval)) if not jnp.issubdtype(dtype, np.integer): raise TypeError(f"randint only accepts integer dtypes, got {dtype}") @@ -537,7 +559,9 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, - independent: bool = False) -> Array: + independent: bool = False, + *, + out_sharding=None) -> Array: """Returns a randomly permuted array or range. 
Args: @@ -554,11 +578,17 @@ def permutation(key: ArrayLike, key, _ = _check_prng_key("permutation", key) check_arraylike("permutation", x) axis = canonicalize_axis(axis, np.ndim(x) or 1) + out_sharding = canonicalize_sharding(out_sharding, "permutation") if not np.ndim(x): if not np.issubdtype(lax.dtype(x), np.integer): raise TypeError("x must be an integer or at least 1-dimensional") - r = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()') - return _shuffle(key, jnp.arange(r), axis) + r = core.concrete_or_error(int, x, "argument x of jax.random.permutation()") + return maybe_auto_axes(lambda key: _shuffle(key, jnp.arange(r), axis), + out_sharding)(key) + return maybe_auto_axes( + _permutation, out_sharding, axis=axis, independent=independent)(key, x) + +def _permutation(key, x, axis, independent): if independent or np.ndim(x) == 1: return _shuffle(key, x, axis) ind = _shuffle(key, jnp.arange(x.shape[axis]), 0) # type: ignore[union-attr] @@ -680,7 +710,9 @@ def choice(key: ArrayLike, def normal(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat = float, + *, + out_sharding=None) -> Array: r"""Sample standard normal random values with given shape and float dtype. The values are returned according to the probability density function: @@ -702,12 +734,13 @@ def normal(key: ArrayLike, """ key, _ = _check_prng_key("normal", key) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "normal") dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _normal(key, shape, dtype) + return maybe_auto_axes(_normal, out_sharding, shape=shape, dtype=dtype)(key) @partial(jit, static_argnums=(1, 2)) def _normal(key, shape, dtype) -> Array: @@ -818,7 +851,8 @@ def truncated_normal(key: ArrayLike, lower: RealArray, upper: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat = float, + *, out_sharding=None) -> Array: r"""Sample truncated standard normal random values with given shape and dtype. The values are returned according to the probability density function: @@ -849,12 +883,14 @@ def truncated_normal(key: ArrayLike, if shape is not None: shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("truncated_normal", key) + out_sharding = canonicalize_sharding(out_sharding, "truncated_normal") dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `truncated_normal` must be a float " f"dtype, got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _truncated_normal(key, lower, upper, shape, dtype) + return maybe_auto_axes(_truncated_normal, out_sharding, + shape=shape, dtype=dtype)(key, lower, upper) @partial(jit, static_argnums=(3, 4)) def _truncated_normal(key, lower, upper, shape, dtype) -> Array: @@ -1537,7 +1573,8 @@ def gumbel(key: ArrayLike, def _gumbel(key, shape, dtype, mode) -> Array: _check_shape("gumbel", shape) if mode == "high": - high, low = _uniform(key, (2,) + shape, dtype, minval=0., maxval=1.) + high, low = _uniform(key, minval=0., maxval=1., + shape=(2,) + shape, dtype=dtype) # TODO(parkers): The condition is to protect against rounding up but # we should be able to add safely with the right addition operation. 
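The changes in this file thread a new `out_sharding` argument through `bits`, `uniform`, `randint`, `permutation`, `normal`, and `truncated_normal`, hoisting the sampler under `auto_axes` whenever a sharding is given. The sketch below shows one plausible way to use it; it assumes the explicit-axes mesh workflow (`jax.make_mesh(..., axis_types=...)` together with `jax.sharding.use_mesh`) and a `PartitionSpec` value for `out_sharding`, and it uses a one-device mesh only so the snippet stays runnable on a single device. Treat it as illustrative rather than authoritative.

```python
import jax
from jax.sharding import AxisType, PartitionSpec as P

# Assumption: a 1-device explicit-axes mesh keeps this runnable on a single CPU/GPU.
mesh = jax.make_mesh((1,), ("x",), axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
    key = jax.random.key(0)
    # Ask the sampler to produce output already laid out along the "x" mesh axis.
    x = jax.random.normal(key, (8, 4), out_sharding=P("x", None))
    print(x.sharding)
```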
x = jnp.where(high >= 0.5, high, @@ -1545,7 +1582,8 @@ def _gumbel(key, shape, dtype, mode) -> Array: return -jnp.log(-jnp.log1p(-x)) else: return -jnp.log(-jnp.log( - _uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) + _uniform(key, minval=jnp.finfo(dtype).tiny, maxval=1., + shape=shape, dtype=dtype))) def categorical( @@ -1568,8 +1606,8 @@ def categorical( shape: Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with ``np.delete(logits.shape, axis)``. The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``. - replace: If True, perform sampling without replacement. Default (False) is to - perform sampling with replacement. + replace: If True (default), perform sampling with replacement. If False, perform + sampling without replacement. Returns: A random array with int dtype and shape given by ``shape`` if ``shape`` diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 9917cbaa0b12..55961607b252 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -2182,3 +2182,64 @@ def hilbert(n: int) -> Array: """ a = lax.broadcasted_iota(jnp.float64, (n, 1), 0) return 1/(a + a.T + 1) + +@partial(jit, static_argnames=("n", "kind",)) +def pascal(n: int, kind: str | None = None) -> Array: + r"""Create a Pascal matrix approximation of order n. + + JAX implementation of :func:`scipy.linalg.pascal`. + + The elements of the Pascal matrix approximate the binomial coefficents. This + implementation is not exact as JAX does not support exact factorials. + + Args: + n: the size of the matrix to create. + kind: (optional) must be one of ``lower``, ``upper``, or ``symmetric`` (default). + + Returns: + A Pascal matrix of shape ``(n, n)`` + + Examples: + >>> with jnp.printoptions(precision=3): + ... print(jax.scipy.linalg.pascal(3, kind="lower")) + ... print(jax.scipy.linalg.pascal(4, kind="upper")) + ... print(jax.scipy.linalg.pascal(5)) + [[1. 0. 0.] + [1. 1. 0.] + [1. 2. 1.]] + [[1. 1. 1. 1.] + [0. 1. 2. 3.] + [0. 0. 1. 3.] + [0. 0. 0. 1.]] + [[ 1. 1. 1. 1. 1.] + [ 1. 2. 3. 4. 5.] + [ 1. 3. 6. 10. 15.] + [ 1. 4. 10. 20. 35.] + [ 1. 5. 15. 35. 
70.]] + """ + if kind is None: + kind = "symmetric" + + valid_kind = ["symmetric", "lower", "upper"] + + if kind not in valid_kind: + raise ValueError(f"Expected kind to be on of: {valid_kind}; got {kind}") + + a = jnp.arange(n, dtype=jnp.float32) + + L_n = _binom(a[:, None], a[None, :]) + + if kind == "lower": + return L_n + + if kind == "upper": + return L_n.T + + return jnp.dot(L_n, L_n.T) + +@jit +def _binom(n, k): + a = lax.lgamma(n + 1.0) + b = lax.lgamma(n - k + 1.0) + c = lax.lgamma(k + 1.0) + return lax.exp(a - b - c) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 2bbf913783e3..2b79e09c49e1 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -33,14 +33,12 @@ from jax._src import xla_bridge as xb from jax._src import mesh_utils from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir.dialects import sdy from jax._src.named_sharding import ( # noqa: F401 SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO, - ParsedPartitionSpec, _check_unique_resources, NamedSharding, UNSPECIFIED, + _check_unique_resources, NamedSharding, UNSPECIFIED, ArrayMapping, ArrayMappingOrAutoOrUnspecified, get_array_mapping, - array_mapping_to_axis_resources, get_single_pspec, preprocess, - named_sharding_to_xla_hlo_sharding) + array_mapping_to_axis_resources, named_sharding_to_xla_hlo_sharding) from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec @@ -103,10 +101,13 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): dim_shardings, used_axes = [], [] # type: ignore for d in sdy_sharding.dimension_shardings: # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open? - dim_shardings.append(SdyDimSharding(axes=[], is_closed=False) - if not d.axes and d.is_closed else d) + dim_shardings.append(SdyDimSharding(axes=[], is_open=True) + if not d.axes and not d.is_open else d) used_axes.extend(d.axes) remaining_axes = set(mesh.axis_names) - set(used_axes) + # Sort wrt mesh axis names so order is deterministic and doesn't hang in + # McJAX. + remaining_axes = [n for n in mesh.axis_names if n in remaining_axes] replicated_axes = tuple(r for r in remaining_axes if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings, @@ -184,7 +185,7 @@ def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return replicated_hlo_sharding def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True) + sdy_dim_sharding = [SdyDimSharding(axes=[], is_open=False) for _ in range(num_dimensions)] return SdyArraySharding(None, sdy_dim_sharding) @@ -882,8 +883,7 @@ def parse_flatten_op_sharding( return out elif hlo_sharding.is_replicated(): return [PartitionSpec()] - elif (xla_extension_version >= 319 and hlo_sharding.is_maximal() - and mesh.size == 1): + elif hlo_sharding.is_maximal() and mesh.size == 1: return [PartitionSpec()] elif hlo_sharding.is_tiled(): mesh_shape = mesh.shape @@ -898,7 +898,11 @@ def parse_flatten_op_sharding( while dim_size > 1: axis = next(mesh_axis) axis_size = mesh_shape[axis] - assert dim_size % axis_size == 0 + if dim_size % axis_size != 0: + raise ValueError( + f'{shape=} is incompatible with {mesh_shape=}: ' + f'{dim_size=} is not divisible by {axis_size=}.' 
+ ) dim_size //= axis_size dim_partitions.append(axis) partitions.append(tuple(dim_partitions)) @@ -1382,10 +1386,8 @@ def use_mesh(mesh: mesh_lib.Mesh): if not isinstance(mesh, mesh_lib.Mesh): raise ValueError( f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") - - # TODO(yashkatariya): Enable this. - # if not core.trace_state_clean(): - # raise ValueError('`use_mesh` can only be used outside of `jax.jit`') + if not core.trace_state_clean(): + raise ValueError('`use_mesh` can only be used outside of `jax.jit`') with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh): yield @@ -1410,13 +1412,16 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: @contextlib.contextmanager def use_concrete_mesh(mesh: mesh_lib.Mesh | None): + if not core.trace_state_clean(): + raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.') + with _internal_use_concrete_mesh(mesh): + yield + +@contextlib.contextmanager +def _internal_use_concrete_mesh(mesh: mesh_lib.Mesh | None): if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): raise ValueError( f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") - # TODO(yashkatariya): Enable this. - # if not core.trace_state_clean(): - # raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.') - prev_val = config.device_context.swap_local(mesh) try: yield diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 19cd0822aa58..b813037a3204 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -61,15 +61,23 @@ CompilerOptions = dict[str, Union[str, bool]] -# -- Internal protocols +# -- Internal types -class Executable(Protocol): - """Protocol for executables, which a user-facing ``Compiled`` encapsulates.""" + +class Executable: + + def xla_extension_executable(self) -> xc.LoadedExecutable: + raise NotImplementedError( + "compiled executable carries no loaded XLA executable. It may be " + f"that {type(self)} defines an incomplete implementation.") def call(self, *args_flat) -> Sequence[Any]: """Execute on the flat list of arguments, returning flat outputs.""" - # TODO(frostig): improve annotation (sequences of arrays/buffers) - raise NotImplementedError + raise NotImplementedError("compiled executable does not support invocation") + + def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: + """Optionally constructs a fast c++ dispatcher.""" + return None def input_shardings(self) -> Sequence[jax.sharding.Sharding]: """Flat sequence of input shardings. @@ -77,7 +85,8 @@ def input_shardings(self) -> Sequence[jax.sharding.Sharding]: May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. """ - raise NotImplementedError + raise NotImplementedError( + "compiled executable carries no input sharding information") def output_shardings(self) -> Sequence[jax.sharding.Sharding]: """Flat sequence of output shardings. @@ -85,13 +94,16 @@ def output_shardings(self) -> Sequence[jax.sharding.Sharding]: May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. 
""" - raise NotImplementedError + raise NotImplementedError( + "compiled executable carries no output sharding information") def input_layouts(self): - raise NotImplementedError + raise NotImplementedError( + "compiled executable carries no input layout information") def output_layouts(self): - raise NotImplementedError + raise NotImplementedError( + "compiled executable carries no input layout information") def as_text(self) -> str: """A human-readable text representation of this executable. @@ -102,139 +114,6 @@ def as_text(self) -> str: May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. """ - raise NotImplementedError - - def cost_analysis(self) -> Any: - """A summary of execution cost estimates. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - def memory_analysis(self) -> Any: - """A summary of estimated memory requirements. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - def runtime_executable(self) -> Any: - """An arbitrary object representation of this executable. - - Intended for debugging purposes. This need not be a valid nor reliable - serialization. It is relayed directly to external callers, with no - guarantee on type, structure, or consistency across invocations. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend or - compiler. - """ - raise NotImplementedError - - def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: - """Optionally constructs a fast c++ dispatcher.""" - return None - - -class Lowering(Protocol): - """Protocol for lowerings, which a user-facing ``Lowered`` encapsulates.""" - - def compile( - self, compiler_options: CompilerOptions | None = None) -> Executable: - """Compile and return a corresponding ``Executable``.""" - raise NotImplementedError - - def as_text(self, dialect: str | None = None, *, - debug_info: bool = False) -> str: - """A human-readable text representation of this lowering. - - Intended for visualization and debugging purposes. This need not be a valid - nor reliable serialization. It is relayed directly to external callers. - """ - raise NotImplementedError - - def compiler_ir(self, dialect: str | None = None) -> Any: - """An arbitrary object representation of this lowering. - - Intended for debugging purposes. This need not be a valid nor reliable - serialization. It is relayed directly to external callers, with no - guarantee on type, structure, or consistency across invocations. 
- - May raise ``NotImplementedError`` if unavailable, e.g. based on backend or - compiler. - - Args: - dialect: Optional string specifying a representation dialect - (e.g. "stablehlo") - """ - raise NotImplementedError - - def cost_analysis(self) -> Any: - """A summary of execution cost estimates. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - This function estimates execution cost in the absence of compiler - optimizations, which may drastically affect the cost. For execution cost - estimates after optimizations, compile this lowering and see - ``Compiled.cost_analysis``. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - -# -- Internal adapters from XLA-related objects to the above protocols - -class XlaExecutable(Executable): - - def xla_extension_executable(self) -> xc.LoadedExecutable: - raise NotImplementedError("must override") - - def call(self, *args_flat) -> Sequence[Any]: - raise NotImplementedError("must override") - - def input_shardings(self) -> Sequence[jax.sharding.Sharding]: - raise NotImplementedError( - "compiled executable carries no input sharding information") - - def output_shardings(self) -> Sequence[jax.sharding.Sharding]: - raise NotImplementedError( - "compiled executable carries no output sharding information") - - def input_layouts(self): - raise NotImplementedError( - "compiled executable carries no input layout information") - - def output_layouts(self): - raise NotImplementedError( - "compiled executable carries no input layout information") - - def as_text(self) -> str: xla_ext_exe = self.xla_extension_executable() err_msg = ("text view unsupported on current XLA backend: " f"{type(xla_ext_exe)}") @@ -249,7 +128,19 @@ def as_text(self) -> str: else: raise - def cost_analysis(self) -> dict[str, float]: + def cost_analysis(self) -> Any: + """A summary of execution cost estimates. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ xla_ext_exe = self.xla_extension_executable() if hasattr(xla_ext_exe, "cost_analysis"): @@ -273,6 +164,18 @@ def cost_analysis(self) -> dict[str, float]: ) def memory_analysis(self) -> Any: + """A summary of estimated memory requirements. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + May raise ``NotImplementedError`` if unavailable, e.g. 
based on backend, + compiler, or runtime. + """ xla_ext_exe = self.xla_extension_executable() err_msg = ("memory analysis unsupported on current XLA backend: " f"{type(xla_ext_exe)}") @@ -288,11 +191,19 @@ def memory_analysis(self) -> Any: raise def runtime_executable(self) -> Any: + """An arbitrary object representation of this executable. + + Intended for debugging purposes. This need not be a valid nor reliable + serialization. It is relayed directly to external callers, with no + guarantee on type, structure, or consistency across invocations. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend or + compiler. + """ return self.xla_extension_executable() -class XlaLowering(Lowering): - """Adapts our various internal XLA-backed computations into a ``Lowering``.""" +class Lowering: compile_args: dict[str, Any] @@ -306,15 +217,23 @@ def hlo(self) -> xc.XlaComputation: def stablehlo(self) -> ir.Module: """Return a StableHLO representation of this computation.""" - raise NotImplementedError("must override") + raise NotImplementedError( + f"stablehlo unsupported on XLA computation: {type(self)}") def compile( self, compiler_options: CompilerOptions | None = None) -> Executable: - raise NotImplementedError("must override") + """Compile and return a corresponding ``Executable``.""" + raise NotImplementedError( + f"compilation unsupported on XLA computation: {type(self)}") def as_text(self, dialect: str | None = None, *, debug_info: bool = False) -> str: + """A human-readable text representation of this lowering. + + Intended for visualization and debugging purposes. This need not be a valid + nor reliable serialization. It is relayed directly to external callers. + """ if dialect is None: dialect = "stablehlo" if dialect == "stablehlo": @@ -328,6 +247,19 @@ def as_text(self, dialect: str | None = None, raise ValueError(f"unknown dialect: {dialect}") def compiler_ir(self, dialect: str | None = None) -> Any: + """An arbitrary object representation of this lowering. + + Intended for debugging purposes. This need not be a valid nor reliable + serialization. It is relayed directly to external callers, with no + guarantee on type, structure, or consistency across invocations. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend or + compiler. + + Args: + dialect: Optional string specifying a representation dialect + (e.g. "stablehlo") + """ if dialect is None: dialect = "stablehlo" if dialect == "stablehlo": @@ -337,8 +269,26 @@ def compiler_ir(self, dialect: str | None = None) -> Any: else: raise ValueError(f"unknown dialect: {dialect}") - def cost_analysis(self) -> dict[str, float]: - raise NotImplementedError("must override") + def cost_analysis(self) -> Any: + """A summary of execution cost estimates. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + This function estimates execution cost in the absence of compiler + optimizations, which may drastically affect the cost. For execution cost + estimates after optimizations, compile this lowering and see + ``Compiled.cost_analysis``. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime.
+ """ + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") # -- Public-facing API, plus helpers @@ -488,7 +438,7 @@ def runtime_executable(self) -> Any | None: @property def input_shardings(self): # PyTree[sharding.Sharding] - shardings_flat = self._executable.input_shardings() + shardings_flat = self._executable._in_shardings # Some input shardings got DCE'd if self.in_tree.num_leaves > len(shardings_flat): iter_shardings_flat = iter(shardings_flat) @@ -498,13 +448,14 @@ def input_shardings(self): # PyTree[sharding.Sharding] @property def output_shardings(self): # PyTree[sharding.Sharding] - shardings_flat = self._executable.output_shardings() + shardings_flat = self._executable._out_shardings return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error @property def input_layouts(self): - layouts_flat = self._executable.input_layouts() - assert all(isinstance(l, Layout) for l in layouts_flat) + dll_flat = self._executable._xla_in_layouts + layouts_flat = [Layout(l, s) + for l, s in zip(dll_flat, self._executable._in_shardings)] # Some input layouts got DCE'd if self.in_tree.num_leaves > len(layouts_flat): iter_layouts_flat = iter(layouts_flat) @@ -514,8 +465,9 @@ def input_layouts(self): @property def output_layouts(self): - layouts_flat = self._executable.output_layouts() - assert all(isinstance(l, Layout) for l in layouts_flat) + dll_flat = self._executable._xla_out_layouts + layouts_flat = [Layout(l, s) + for l, s in zip(dll_flat, self._executable._out_shardings)] return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error @staticmethod @@ -593,14 +545,14 @@ class Lowered(Stage): lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). """ __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs"] - _lowering: XlaLowering + _lowering: Lowering args_info: Any # PyTree of ArgInfo out_tree: tree_util.PyTreeDef _no_kwargs: bool def __init__( self, - lowering: XlaLowering, + lowering: Lowering, args_info, # PyTree of ArgInfo out_tree: tree_util.PyTreeDef, no_kwargs: bool = False): @@ -612,7 +564,7 @@ def __init__( @classmethod def from_flat_info(cls, - lowering: XlaLowering, + lowering: Lowering, in_tree: tree_util.PyTreeDef, in_avals, donate_argnums: tuple[int, ...], diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 7ab77d5b1c37..615fa862bf31 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -275,33 +275,97 @@ def _maybe_convert_to_dynamic_slice( return starts, sizes, squeeze_dims -def _convert_to_array_indexer(indexer: indexing.NDIndexer - ) -> tuple[int | Array, ...]: - # This is the general gather case. We need to create the gather arrays. - is_integer_indexer, _, integer_indexer = ( - indexing.unpack_ndindexer(indexer) +# In this code, indexing is handled in three ways: `slice`, `dynamic_slice`, and +# gather. For the gather case, the goal is to create a gather array, which means +# that we need to convert all other types of indexers into integer array +# indexers. This is done by looping over all indexers and checking if they are +# not integer array indexers, and if not, performing the conversion. However, +# during this process, the indexing semantics may change. Specifically, +# according to the indexing rules of NumPy, when there are integer array +# indexers separated by other indexers, the axes corresponding to the integer +# array indexers need to be moved to the front. 
After we convert all other +# indexers to integer array indexers, the distinction between integer array +# indexers and other types of indexers is lost. As a result, it becomes +# impossible to determine which axes should be moved to the front. In this case, +# we need to transpose the target array before the gather operation. We also +# need to transpose the target array back after the gather operation, if it is +# used in subsequent computations. +def _maybe_transpose_before_gather( + indexer: indexing.NDIndexer +) -> tuple[int, ...] | None: + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) + + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) ) - total_shape = indexer.get_indexer_shape() - int_indexer_shape = indexer.int_indexer_shape - slice_shape = total_shape[len(int_indexer_shape):] - slice_dims = tuple( - i + len(int_indexer_shape) for i in range(len(slice_shape)) + if int_indexers_contiguous: + return None # no transpose needed + + int_indexer_idxs: list[int] = [] + non_int_indexer_idxs: list[int] = [] + for i, is_int_index in enumerate(is_int_indexing): + (int_indexer_idxs if is_int_index else non_int_indexer_idxs).append(i) + transpose_order = (*int_indexer_idxs, *non_int_indexer_idxs) + return transpose_order + + +def _perform_transpose_before_gather( + target_arr: Array, + indexer: indexing.NDIndexer, + transpose_order: tuple[int, ...], +) -> tuple[Array, indexing.NDIndexer]: + new_target_arr = target_arr.transpose(transpose_order) + reordered_indices = tuple(indexer.indices[i] for i in transpose_order) + new_indexer = indexing.NDIndexer( + indices=reordered_indices, + shape=indexer.shape, + int_indexer_shape=indexer.int_indexer_shape, ) - slice_dim_iter = iter(slice_dims) - slice_indexer: list[Array] = [] - for idx, is_int_index in zip(indexer.indices, is_integer_indexer): - if not is_int_index: - assert isinstance(idx, indexing.Slice) - slice_indices = lax.broadcasted_iota( - np.dtype("int32"), total_shape, next(slice_dim_iter) - ) * idx.stride + idx.start - slice_indexer.append(slice_indices) - integer_indexer = tuple( - lax.expand_dims(idx, (-1,)) for idx in integer_indexer + return new_target_arr, new_indexer + + +def _convert_to_gather_arrays(indexer: indexing.NDIndexer) -> tuple[Array, ...]: + # This is the general gather case. We need to create the gather arrays. 
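+  # For example (standard NumPy/JAX indexing semantics, shown here only for
+  # reference): with x of shape (2, 3, 4, 5) and idx = np.zeros(7, np.int32),
+  # x[:, idx, idx, :] has shape (2, 7, 5) because the integer indexers are
+  # contiguous, whereas x[idx, :, idx, :] has shape (7, 3, 5) because the
+  # separated integer-indexer axes are moved to the front.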
+ total_shape = indexer.get_indexer_shape() + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) + + if any(is_int_indexing): + n_idxers = len(indexer.indices) + int_indexer_shape = indexer.int_indexer_shape + n_int_indexers = sum(1 for p in is_int_indexing if p) + last_int_index_idx = n_idxers - 1 - is_int_indexing[::-1].index(True) + n_slice_index_dims_after_int = n_idxers - last_int_index_idx - 1 + + def get_idx_in_shape_after_indexing(i): + if not any(is_int_indexing): + return i + + if i < n_idxers - n_slice_index_dims_after_int - n_int_indexers: + return i + if i < n_idxers - n_slice_index_dims_after_int: + raise ValueError + return i - n_int_indexers + len(int_indexer_shape) + + arrs = [] + for i, idxer in enumerate(indexer.indices): + if isinstance(idxer, indexing.Slice): + idx_in_shape_after_indexing = get_idx_in_shape_after_indexing(i) + arr = ( + lax.iota(np.int32, total_shape[idx_in_shape_after_indexing]) + * idxer.stride + + idxer.start ) - continue - assert next(slice_dim_iter, None) is None - return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer)) + diff = len(total_shape) - idx_in_shape_after_indexing - 1 + arr = arr.reshape(arr.shape + (1,) * diff) + arrs.append(arr) + elif isinstance(idxer, (np.ndarray, Array)): + diff = n_idxers - 1 - last_int_index_idx + arr = idxer.reshape(idxer.shape + (1,) * diff) + arrs.append(arr) + else: + raise ValueError(f"Invalid type of idxer: {type(idxer).__name__}") + + return tuple(arrs) @register_discharge_rule(get_p) @@ -313,20 +377,8 @@ def _get_discharge_rule( y = _get_discharge(x, idx, tree) return (None,) * (len(idx) + 1), y -def _prepend_gather(x, indexer): - # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy. - return x[None][(np.array(0, 'int32'), *indexer)] -def _prepend_scatter(x, indexer, val, *, add=False): - # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy. - # However, since this is scatter, we need to remove the 1-sized dimension - # we added at the front. - if add: - return x[None].at[(0, *indexer)].add(val)[0] - return x[None].at[(0, *indexer)].set(val)[0] - - -def _index_array(x, indexer): +def _index_array(x, indexer: indexing.NDIndexer): if _is_trivial_indexer(indexer): return x # Try the three APIs in the following order: `lax.slice`, @@ -336,13 +388,16 @@ def _index_array(x, indexer): # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. - elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice + elif maybe_dynamic_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_dynamic_slice y = lax_slicing.dynamic_slice(x, starts, sizes) x = lax.squeeze(y, squeeze_dims) else: - indexer = _convert_to_array_indexer(indexer) - x = x[None][(np.array(0, "int32"), *indexer)] + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order) + arrays = _convert_to_gather_arrays(indexer) + x = x[arrays] return x @@ -367,53 +422,79 @@ def transform_array(x, transforms): def transform_swap_array(x, transforms, val): if transforms is None: transforms = [] - result = x - result_val = val - # Compute updated "val" (result). - _results = [x] + + # Will hold the value read from `x` before the swap, and will have the same + # shape as `val`. 
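+  # (Illustrative example: if x has shape (4, 5) and the single transform is
+  # the indexer [1:3, :], the read phase below narrows new_val to shape
+  # (2, 5), which matches the shape of val.)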
+ new_val = x + # List of intermediate results by transforming `x`. + intermediates = [x] + + # Read phase (forward loop) for transform in transforms: match transform: case indexing.NDIndexer(): indexer = transform if _is_trivial_indexer(indexer): - _results.append(_results[-1]) + intermediates.append(intermediates[-1]) continue # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice - result_old = lax_slicing.dynamic_slice(result, starts, sizes) - result = lax.squeeze(result_old, squeeze_dims) + new_val = lax.squeeze( + lax_slicing.dynamic_slice(new_val, starts, sizes), squeeze_dims + ) else: - indexer = _convert_to_array_indexer(indexer) - result = _prepend_gather(result, indexer) - _results.append(result) + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + new_val, indexer = _perform_transpose_before_gather( + new_val, indexer, transpose_order + ) + arrays = _convert_to_gather_arrays(indexer) + new_val = new_val[arrays] + # Here, we don't need to transpose `new_val` back because it now holds + # the result of the indexing, and is no longer the original array that + # was indexed into. + intermediates.append(new_val) case RefBitcaster(): - _results.append(bitcast(result, transform.dtype)) + intermediates.append(bitcast(new_val, transform.dtype)) case RefReshaper(): - _results.append(result.reshape(transform.shape)) + intermediates.append(new_val.reshape(transform.shape)) case _: raise NotImplementedError(f"Unsupported transform: {transform}") - # Compute updated "x" (result_val) - for i, transform in reversed(list(enumerate(transforms))): + # Will hold the final state of the `x` after `val` has been written to the + # transformed location, and will have the same shape as `x`. 
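+  # (Continuing the illustrative example: the write phase scatters the
+  # (2, 5)-shaped val back into rows 1:3 of the (4, 5)-shaped x, so new_x
+  # ends up with the same (4, 5) shape as x.)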
+ new_x = val + + # Write phase (reversed loop) + for intermediate, transform in reversed(list(zip(intermediates[:-1], transforms))): if isinstance(transform, indexing.NDIndexer): indexer = transform if _is_trivial_indexer(indexer): continue if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, _, squeeze_dims = maybe_slice - result_val = lax.expand_dims(result_val, squeeze_dims) - result_val = lax_slicing.dynamic_update_slice( - _results[i], result_val, starts + new_x = lax_slicing.dynamic_update_slice( + intermediate, lax.expand_dims(new_x, squeeze_dims), starts ) else: - indexer = _convert_to_array_indexer(indexer) - result_val = _prepend_scatter(_results[i], indexer, result_val) + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + intermediate, indexer = _perform_transpose_before_gather( + intermediate, indexer, transpose_order + ) + arrays = _convert_to_gather_arrays(indexer) + new_x = intermediate.at[arrays].set(new_x) # pytype: disable=attribute-error + if transpose_order is not None: + transpose_order_inversed = np.argsort(transpose_order) + new_x = new_x.transpose(transpose_order_inversed) else: raise NotImplementedError(f"Unsupported transform: {transform}") - return result, result_val + + return new_val, new_x + def _get_discharge(x, idx, tree): transforms = tree_util.tree_unflatten(tree, idx) @@ -446,8 +527,10 @@ def _addupdate_discharge(x, val, idx, tree): if len(transforms) > 1: raise NotImplementedError("Only single indexer is supported.") indexer = transforms[0] + if _is_trivial_indexer(indexer): return x + val + # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end.
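A note on the `np.argsort(transpose_order)` step used in the write phase above and again in `_addupdate_discharge` below: for a permutation, `argsort` returns its inverse, so transposing by the result restores the original axis order. A minimal self-contained check of that identity (illustrative only; the array shape and permutation are arbitrary choices):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
order = (2, 0, 1)
inverse = np.argsort(order)  # argsort of a permutation yields its inverse
assert np.array_equal(x.transpose(order).transpose(inverse), x)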
@@ -457,8 +540,17 @@ def _addupdate_discharge(x, val, idx, tree): val = lax.expand_dims(val, squeeze_dims) y = lax_slicing.dynamic_update_slice(x, x_old + val, starts) return y - indexer = _convert_to_array_indexer(indexer) - return _prepend_scatter(x, indexer, val, add=True) + + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order) + arrays = _convert_to_gather_arrays(indexer) + x = x.at[arrays].add(val) + if transpose_order is not None: + transpose_order_inversed = np.argsort(transpose_order) + x = x.transpose(transpose_order_inversed) + return x + @weakref_lru_cache def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr): diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 4b627c1cd581..e6e6b8a5ee25 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -20,6 +20,7 @@ from typing import Any, Sequence, Union from jax._src import core +from jax._src import pretty_printer as pp from jax._src import tree_util from jax._src.typing import Array from jax._src.util import merge_lists @@ -78,6 +79,30 @@ def from_slice(cls, slc: slice, size: int) -> Slice: return cls(start, size, step) +def _pp_slice(context: core.JaxprPpContext, dim, slc: Slice) -> str: + start, size = slc.start, slc.size + if isinstance(start, core.Var): + start_str = core.pp_var(start, context) + size_str = ( + core.pp_var(size, context) if isinstance(size, core.Var) else str(size) + ) + return f"{start_str}:{start_str}+{size_str}" + else: + start_str = str(start) + if start == 0: + start_str = "" + if isinstance(size, core.Var): + size_str = core.pp_var(size, context) + if start_str: + return f"{start_str}:{start_str}+{size_str}" + else: + return f":{size_str}" + else: + end = start + size + end_str = "" if end == dim else str(end) + return f"{start_str}:{end_str}" + + def dslice( start: int | Array | None, size: int | Array | None = None, @@ -247,11 +272,21 @@ def from_indices_shape(cls, indices, shape) -> NDIndexer: return cls(indices, shape, int_indexer_shape, validate=True) def get_indexer_shape(self) -> tuple[int | Array, ...]: - _, slice_indexers, _ = unpack_ndindexer(self) - slice_shape = [s.size for s in slice_indexers] - # In NDIndexers, the int_indexer_shape is *always* at the front of the - # result. 
- return (*self.int_indexer_shape, *slice_shape) + is_int_indexing, slice_indexers, _ = unpack_ndindexer(self) + + slice_shape = tuple(s.size for s in slice_indexers) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + if not int_indexers_contiguous: + return self.int_indexer_shape + slice_shape + + has_int_indexers = any(is_int_indexing) + if has_int_indexers: + pos = is_int_indexing.index(True) + return slice_shape[:pos] + self.int_indexer_shape + slice_shape[pos:] + + return slice_shape def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]: del shape # Unused @@ -282,3 +317,12 @@ def transform_sharding(self, sharding): f"along unsharded axes, but ref of shape {self.shape} " f"was sliced on axis {i}, which is sharded like {s}") return sharding + + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + indices = [] + for idx, dim in zip(self.indices, self.shape): + if isinstance(idx, Slice): + indices.append(_pp_slice(context, dim, idx)) + else: + indices.append(core.pp_var(idx, context)) # type: ignore + return pp.concat([pp.text("["), pp.text(",".join(indices)), pp.text("]")]) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 6f7570a5f3cd..1237da57f217 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -18,9 +18,12 @@ import types from typing import Any, Union +import numpy as np + from jax._src import ad_util from jax._src import core from jax._src import dispatch +from jax._src import dtypes from jax._src import pretty_printer as pp from jax._src import traceback_util from jax._src import tree_util @@ -34,15 +37,12 @@ AbstractRef, AccumEffect, ReadEffect, - RefBitcaster, - RefReshaper, Transform, TransformedRef, WriteEffect, ) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip -import numpy as np ## General utilities @@ -144,10 +144,25 @@ def ref_swap( _function_name: str = "ref_swap", ) -> Array: """Sets a `Ref`'s value and returns the original value.""" + if hasattr(ref_or_view, 'dtype'): + value = _maybe_implicit_cast(ref_or_view.dtype, value) ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name) flat_transforms, tree = tree_util.tree_flatten(transforms) return swap_p.bind(ref, value, *flat_transforms, tree=tree) +# TODO(slebedev,mattjj): replace with special handling of Python numeric types: +# if (isinstance(value, (int, float, complex)) and +# value == np.array(value, dtype).item()): return cast +def _maybe_implicit_cast(dtype, value): + aval = core.typeof(value) + if (aval.weak_type and + (dtypes.issubdtype(dtype, np.floating) and + dtypes.issubdtype(aval.dtype, np.floating)) or + (dtypes.issubdtype(dtype, np.integer) and + dtypes.issubdtype(aval.dtype, np.integer))): + return lax.convert_element_type(value, dtype) + return value + def ref_set( ref_or_view: AbstractRef | TransformedRef, @@ -248,7 +263,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef, f"Expected shape: {expected_out_shape}. " f"Value shape: {val_aval.shape}. " f"Transforms: {transforms}. ") - if expected_out_dtype != val_aval.dtype and not val_aval.weak_type: + if expected_out_dtype != val_aval.dtype: raise ValueError( "Invalid dtype for `swap`. " f"Ref dtype: {expected_out_dtype}. 
" @@ -297,70 +312,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL, foreground=pp.Color.GREEN) -def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice - ) -> str: - start, size = slc.start, slc.size - if isinstance(start, core.Var): - start_str = core.pp_var(start, context) - size_str = ( - core.pp_var(size, context) - if isinstance(size, core.Var) - else str(size) - ) - return f'{start_str}:{start_str}+{size_str}' - else: - start_str = str(start) - if start == 0: - start_str = '' - if isinstance(size, core.Var): - size_str = core.pp_var(size, context) - if start_str: - return f'{start_str}:{start_str}+{size_str}' - else: - return f':{size_str}' - else: - end = start + size - end_str = '' if end == dim else str(end) - return f'{start_str}:{end_str}' - -def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer - ) -> pp.Doc: - indices = [] - for idx, dim in zip(indexer.indices, indexer.shape): - if isinstance(idx, indexing.Slice): - indices.append(_pp_slice(context, dim, idx)) - else: - indices.append(core.pp_var(idx, context)) # type: ignore - return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")]) - - -def pp_bitcaster( - context: core.JaxprPpContext, bitcaster: RefBitcaster -) -> pp.Doc: - del context - return pp.text( - f"[bitcast({bitcaster.dtype}[{','.join(str(d) for d in bitcaster.shape)}])]" - ) - - -def pp_reshaper(context: core.JaxprPpContext, reshaper: RefReshaper) -> pp.Doc: - del context - return pp.text( - f"[reshape({reshaper.dtype}[{','.join(str(d) for d in reshaper.shape)}])]" - ) - - -def pp_transform(context: core.JaxprPpContext, transform: Transform) -> pp.Doc: - match transform: - case indexing.NDIndexer(): - return pp_indexer(context, transform) - case RefBitcaster(): - return pp_bitcaster(context, transform) - case RefReshaper(): - return pp_reshaper(context, transform) - case _: - return pp.text(f"[{transform}]") - def _pp_transforms( context: core.JaxprPpContext, @@ -369,7 +320,7 @@ def _pp_transforms( if not transforms: return pp.text("[...]") return pp.concat( - [pp_transform(context, transform) for transform in transforms] + [transform.pretty_print(context) for transform in transforms] ) @@ -503,11 +454,52 @@ def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn): ## get/swap/addupdate batching rules -def _batch_indexer(indexer: indexing.NDIndexer, dims, - axis_size: int, - ref_shape: tuple[int, ...], - ref_dim: int | batching.NotMapped, - idx_is_batched: bool) -> indexing.NDIndexer: +def _batch_indexer( + indexer: indexing.NDIndexer, + dims, + axis_size: int, + ref_shape: tuple[int, ...], + ref_dim: int | batching.NotMapped, + idx_is_batched: bool, +) -> indexing.NDIndexer: + """Converts a batched indexer into an unbatched one. + + This function handles the complexity of `vmap`-style batching where either the + `ref` being indexed, the indexer, or both may have batched dimensions. The + goal is to produce a new indexer that acts as if applied in a batched context, + but without actual batching, enabling downstream code to process it as usual. + + If any index in `indexer` is batched, all array indexers are normalized. If + the array indexer contains a batched dimension, the dimension is moved to the + front (axis 0). If the array indexer not batched, it is broadcasted to include + a batch dimension at the front. This is to guarantee that all array indexers + are still of the same shape. 
+ + Slices are passed through unchanged unless they contain dynamic elements and + are themselves batched, which is currently unsupported. + + If `ref` is batched (`ref_dim` is not `NotMapped`), we simulate per-example + indexing by inserting a new iota array at the position corresponding to + `ref_dim` in the indexer. + + It is worth noting that if the array indexers in the original indexer are + contiguous, but become non-contiguous in the new indexer due to the insertion + of the iota, the dimensions corresponding to the array indexers will be moved + to the front in the indexing result. The batched dimension will be at axis 0, + while the dimensions corresponding to the array indexers in the original + indexer will start from axis 1. This behavior would cause a mismatch between + the original indexer and the new indexer. Callers must take this behavior into + account and properly transpose the arrays involved to avoid this mismatch. + + Args: + indexer: An `NDIndexer` that indexes into `ref`. + dims: A pytree with the same structure as `indexer`, indicating which + dimension (if any) is batched for each array indexer. + axis_size: Size of the batch dimension. + ref_shape: Shape of `ref`. + ref_dim: The dimension of `ref` that is batched (if any). + idx_is_batched: Whether any index in the `indexer` is batched. + """ indices = indexer.indices indices_dims = dims.indices new_indices: list[Array | indexing.Slice | int] = [] @@ -559,9 +551,9 @@ def _batch_indexer(indexer: indexing.NDIndexer, dims, if ref_dim is not batching.not_mapped: iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0) new_indices.insert(ref_dim, iota) - return indexing.NDIndexer(tuple(new_indices), ref_shape, - new_integer_indexer_shape, - validate=True) + return indexing.NDIndexer( + tuple(new_indices), ref_shape, new_integer_indexer_shape, validate=True + ) def _get_vmap(batched_args, batched_dims, *, tree): axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) @@ -576,11 +568,42 @@ def _get_vmap(batched_args, batched_dims, *, tree): if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers - indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) - flat_indexers, tree = tree_util.tree_flatten(indexers) - return get_p.bind(ref, *flat_indexers, tree=tree), 0 + flat_indexers, tree = tree_util.tree_flatten(new_indexers) + + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) + new_int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) + ) + + out = get_p.bind(ref, *flat_indexers, tree=tree) + if not int_indexers_contiguous: # will always be moved to the front + out_bdim = 0 + else: # originally not going to be moved to the front + if new_int_indexers_contiguous: # now not going to be moved to the front + out_bdim = is_new_int_indexing.index(True) + else: # now going to be moved to the front + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(out.shape))) + transpose_order = ( + transpose_order[0], + 
*transpose_order[array_indexer_len:array_indexer_len+original_pos], + *transpose_order[1:array_indexer_len], + *transpose_order[array_indexer_len+original_pos:], + ) + + out = lax.transpose(out, transpose_order) + out_bdim = 0 + return out, out_bdim batching.primitive_batchers[get_p] = _get_vmap def _swap_vmap(batched_args, batched_dims, *, tree): @@ -598,15 +621,59 @@ def _swap_vmap(batched_args, batched_dims, *, tree): if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers - indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) - flat_indexers, tree = tree_util.tree_flatten(indexers) - if (ref_is_batched or idx_is_batched) and not val_is_batched: - val = batching.broadcast(val, axis_size, 0) - if val_is_batched: - val = batching.moveaxis(val, val_dim, 0) - return swap_p.bind(ref, val, *flat_indexers, tree=tree), 0 + flat_indexers, tree = tree_util.tree_flatten(new_indexers) + + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) + new_int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) + ) + + if not new_int_indexers_contiguous: # will be moved to the front + batched_dim_in_result = 0 + else: + batched_dim_in_result = is_new_int_indexing.index(True) + 0 + + if not val_is_batched: + if ref_is_batched or idx_is_batched: + val = batching.broadcast(val, axis_size, batched_dim_in_result) + else: + val = batching.moveaxis(val, val_dim, batched_dim_in_result) + + transpose_order_inversed = None + + # Originally not going to be moved to the front, but now going to be moved to + # the front. + if int_indexers_contiguous and not new_int_indexers_contiguous: + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(val.shape))) + transpose_order = ( + transpose_order[0], + *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)], + *transpose_order[1:1+original_pos], + *transpose_order[(1+original_pos)+(array_indexer_len-1):], + ) + val = val.transpose(transpose_order) + transpose_order_inversed = np.argsort(transpose_order) + + out = swap_p.bind(ref, val, *flat_indexers, tree=tree) + + # `val` should not be transposed, but we needed to transpose it to match + # `swap_p`. As a result, the output of `swap_p` is also transposed. Now we + # need to transpose it back. 
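+  # (For instance, with an original indexer like ref[:, idx, :] and the ref
+  # batched along axis 0, the inserted iota makes the integer indexers
+  # non-contiguous, so the indexed axes of the gather result move to the
+  # front; the transposes here and above undo that reordering. This example
+  # is illustrative only.)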
+ if transpose_order_inversed is not None: + out = out.transpose(transpose_order_inversed) + + return out, batched_dim_in_result batching.primitive_batchers[swap_p] = _swap_vmap def _addupdate_vmap(batched_args, batched_dims, *, tree): @@ -624,14 +691,47 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree): if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers - indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) - flat_indexers, tree = tree_util.tree_flatten(indexers) - if (ref_is_batched or idx_is_batched) and not val_is_batched: - val = batching.broadcast(val, axis_size, 0) - if val_is_batched: - val = batching.moveaxis(val, val_dim, 0) + flat_indexers, tree = tree_util.tree_flatten(new_indexers) + + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) + new_int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) + ) + + if not new_int_indexers_contiguous: # will be moved to the front + batched_dim_in_result = 0 + else: + batched_dim_in_result = is_new_int_indexing.index(True) + + if not val_is_batched: + if ref_is_batched or idx_is_batched: + val = batching.broadcast(val, axis_size, batched_dim_in_result) + else: + val = batching.moveaxis(val, val_dim, batched_dim_in_result) + + # Originally not going to be moved to the front, but now going to be moved to + # the front. + if int_indexers_contiguous and not new_int_indexers_contiguous: + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(val.shape))) + transpose_order = ( + transpose_order[0], + *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)], + *transpose_order[1:1+original_pos], + *transpose_order[(1+original_pos)+(array_indexer_len-1):], + ) + val = val.transpose(transpose_order) + return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), [] batching.primitive_batchers[addupdate_p] = _addupdate_vmap diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 057242f4c1ac..af5bec49a6d5 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -125,6 +125,10 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused. + return pp.text(f"{{bitcast({self.dtype}{list(self.shape)})}}") + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -178,6 +182,10 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused.
+ return pp.text(f"{{reshape({self.dtype}{list(self.shape)})}}") + class Transform(Protocol): @@ -205,6 +213,9 @@ def transform_sharding(self, sharding): if all(p is None for p in sharding.spec): return sharding # no explicit axes raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{{self}}}") + @dataclasses.dataclass class RefIndexer: @@ -266,6 +277,9 @@ def dtype(self): assert dtype is not None return dtype + ndim = property(lambda self: len(self.shape)) + size = property(lambda self: math.prod(self.shape)) + @property def at(self) -> RefIndexer: return RefIndexer(self) @@ -330,6 +344,12 @@ def update(self, inner_aval=None): ndim = property(lambda self: len(self.shape)) size = property(lambda self: math.prod(self.shape)) + def _len(self, ignored_tracer) -> int: + try: + return self.shape[0] + except IndexError as err: + raise TypeError("len() of unsized object") from err # same as numpy error + @property def shape(self): try: @@ -357,6 +377,15 @@ def sharding(self): f"`Ref{{{self.inner_aval.str_short()}}} has no `sharding`." ) from None + @property + def vma(self): + try: + return self.inner_aval.vma # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `vma`." + ) from None + @core.aval_property def at(self): return RefIndexer(self) @@ -427,7 +456,7 @@ def shaped_array_ref( shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef: return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type)) -def _shard_ref(mesh, auto, names, ref_aval: AbstractRef): +def _shard_ref(mesh, auto, check_rep, names, ref_aval: AbstractRef): del mesh if names: # Can't actually shard a ref, can only close over it. @@ -435,7 +464,7 @@ def _shard_ref(mesh, auto, names, ref_aval: AbstractRef): return ref_aval core.shard_aval_handlers[AbstractRef] = _shard_ref -def _unshard_ref(mesh, names, ref_aval: AbstractRef): +def _unshard_ref(mesh, check_rep, names, ref_aval: AbstractRef): del mesh if names: # Can't actually shard a ref, can only close over it. diff --git a/jax/_src/test_loader.py b/jax/_src/test_loader.py new file mode 100644 index 000000000000..8f97cea1e7bc --- /dev/null +++ b/jax/_src/test_loader.py @@ -0,0 +1,222 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Contains a custom unittest loader and test suite. + +Implements: +- A test filter based on the JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS + environment variables. +- A test suite that runs tests in parallel using threads if JAX_TEST_NUM_THREADS + is >= 1. +- Test decorators that mark a test case or test class as thread-hostile. 
+""" + +from __future__ import annotations + +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +import logging +import os +import re +import threading +import time +import unittest + +from absl.testing import absltest +from jax._src import config +from jax._src import test_warning_util +from jax._src import util + +logger = logging.getLogger(__name__) + + +_TEST_TARGETS = config.string_flag( + 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), + 'Regular expression specifying which tests to run, called via re.search on ' + 'the test name. If empty or unspecified, run all tests.' +) + +_EXCLUDE_TEST_TARGETS = config.string_flag( + 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), + 'Regular expression specifying which tests NOT to run, called via re.search ' + 'on the test name. If empty or unspecified, run all tests.' +) + +TEST_NUM_THREADS = config.int_flag( + 'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')), + help='Number of threads to use for running tests. 0 means run everything ' + 'in the main thread. Using > 1 thread is experimental.' +) + +# We use a reader-writer lock to protect test execution. Tests that may run in +# parallel acquire a read lock; tests that are not thread-safe acquire a write +# lock. +_test_rwlock = util.Mutex() + +def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): + if getattr(test.__class__, "thread_hostile", False): + _test_rwlock.writer_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.writer_unlock() + else: + _test_rwlock.reader_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.reader_unlock() + + +@contextmanager +def thread_unsafe_test(): + """Decorator for tests that are not thread-safe. + + Note: this decorator (naturally) only applies to what it wraps, not to, say, + code in separate setUp() or tearDown() methods. + """ + if TEST_NUM_THREADS.value <= 0: + yield + return + + _test_rwlock.assert_reader_held() + _test_rwlock.reader_unlock() + _test_rwlock.writer_lock() + try: + yield + finally: + _test_rwlock.writer_unlock() + _test_rwlock.reader_lock() + + +def thread_unsafe_test_class(): + """Decorator that marks a TestCase class as thread-hostile.""" + def f(klass): + assert issubclass(klass, unittest.TestCase), type(klass) + klass.thread_hostile = True + return klass + return f + + +class ThreadSafeTestResult: + """ + Wraps a TestResult to make it thread safe. + + We do this by accumulating API calls and applying them in a batch under a + lock at the conclusion of each test case. + + We duck type instead of inheriting from TestResult because we aren't actually + a perfect implementation of TestResult, and would rather get a loud error + for things we haven't implemented. + """ + def __init__(self, lock: threading.Lock, result: unittest.TestResult): + self.lock = lock + self.test_result = result + self.actions: list[Callable[[], None]] = [] + + def startTest(self, test: unittest.TestCase): + logger.info("Test start: %s", test.id()) + self.start_time = time.time() + + def stopTest(self, test: unittest.TestCase): + logger.info("Test stop: %s", test.id()) + stop_time = time.time() + with self.lock: + # If test_result is an ABSL _TextAndXMLTestResult we override how it gets + # the time. This affects the timing that shows up in the XML output + # consumed by CI. 
+ time_getter = getattr(self.test_result, "time_getter", None) + try: + self.test_result.time_getter = lambda: self.start_time + self.test_result.startTest(test) + for callback in self.actions: + callback() + self.test_result.time_getter = lambda: stop_time + self.test_result.stopTest(test) + finally: + if time_getter is not None: + self.test_result.time_getter = time_getter + + def addSuccess(self, test: unittest.TestCase): + self.actions.append(lambda: self.test_result.addSuccess(test)) + + def addSkip(self, test: unittest.TestCase, reason: str): + self.actions.append(lambda: self.test_result.addSkip(test, reason)) + + def addError(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addError(test, err)) + + def addFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addFailure(test, err)) + + def addExpectedFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addExpectedFailure(test, err)) + + def addDuration(self, test: unittest.TestCase, elapsed): + self.actions.append(lambda: self.test_result.addDuration(test, elapsed)) + + +class JaxTestSuite(unittest.TestSuite): + """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1. + + Caution: this test suite does not run setUpClass or setUpModule methods if + thread parallelism is enabled. + """ + + def __init__(self, suite: unittest.TestSuite): + super().__init__(list(suite)) + + def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult: + if TEST_NUM_THREADS.value <= 0: + return super().run(result) + + test_warning_util.install_threadsafe_warning_handlers() + + executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) + lock = threading.Lock() + futures = [] + + def run_test(test): + """Recursively runs tests in a test suite or test case.""" + if isinstance(test, unittest.TestSuite): + for subtest in test: + run_test(subtest) + else: + test_result = ThreadSafeTestResult(lock, result) + futures.append(executor.submit(_run_one_test, test, test_result)) + + with executor: + run_test(self) + for future in futures: + future.result() + + return result + + +class JaxTestLoader(absltest.TestLoader): + suiteClass = JaxTestSuite + + def getTestCaseNames(self, testCaseClass): + names = super().getTestCaseNames(testCaseClass) + if _TEST_TARGETS.value: + pattern = re.compile(_TEST_TARGETS.value) + names = [name for name in names + if pattern.search(f"{testCaseClass.__name__}.{name}")] + if _EXCLUDE_TEST_TARGETS.value: + pattern = re.compile(_EXCLUDE_TEST_TARGETS.value) + names = [name for name in names + if not pattern.search(f"{testCaseClass.__name__}.{name}")] + return names diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c55dc2a560e0..caff1c73145b 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -17,7 +17,6 @@ import collections from collections.abc import Callable, Generator, Iterable, Sequence -from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack, contextmanager import datetime import functools @@ -32,12 +31,10 @@ import tempfile import textwrap import threading -import time from typing import Any, TextIO import unittest import zlib -from absl.testing import absltest from absl.testing import parameterized import jax from jax import lax @@ -51,21 +48,29 @@ from jax._src import lib as _jaxlib from jax._src import monitoring from jax._src import test_warning_util +from jax._src.typing import ArrayLike, DTypeLike from jax._src import xla_bridge 
from jax._src import util from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir +from jax._src.lib import jaxlib_extension_version +from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, - check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance) + check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance, ToleranceDict) +from jax._src.test_loader import thread_unsafe_test as thread_unsafe_test +from jax._src.test_loader import thread_unsafe_test_class as thread_unsafe_test_class +from jax._src.test_loader import JaxTestLoader as JaxTestLoader +from jax._src.test_loader import TEST_NUM_THREADS as TEST_NUM_THREADS from jax._src.util import unzip2 from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten import numpy as np import numpy.random as npr + # This submodule includes private test utilities that are not exported to # jax.test_util. Functionality appearing here is for internal use only, and # may be changed or removed at any time and without any deprecation cycle. @@ -89,22 +94,12 @@ 'sampling process is terminated.' ) -_SKIP_SLOW_TESTS = config.bool_flag( +SKIP_SLOW_TESTS = config.bool_flag( 'jax_skip_slow_tests', config.bool_env('JAX_SKIP_SLOW_TESTS', False), help='Skip tests marked as slow (> 5 sec).' ) -_TEST_TARGETS = config.string_flag( - 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), - 'Regular expression specifying which tests to run, called via re.search on ' - 'the test name. If empty or unspecified, run all tests.' -) -_EXCLUDE_TEST_TARGETS = config.string_flag( - 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), - 'Regular expression specifying which tests NOT to run, called via re.search ' - 'on the test name. If empty or unspecified, run all tests.' -) TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag( 'jax_test_with_persistent_compilation_cache', config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False), @@ -118,11 +113,6 @@ 'deterministic, interactive'), ) -TEST_NUM_THREADS = config.int_flag( - 'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')), - help='Number of threads to use for running tests. 0 means run everything ' - 'in the main thread. Using > 1 thread is experimental.' -) # We sanitize test names to ensure they work with "unitttest -k" and # "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k @@ -131,10 +121,10 @@ def sanitize_test_name(s: str) -> str: return kSanitizeNameRE.sub("_", s) -def num_float_bits(dtype): +def num_float_bits(dtype: DTypeLike) -> int: return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits -def to_default_dtype(arr): +def to_default_dtype(arr: ArrayLike) -> np.ndarray: """Convert a value to an array with JAX's default dtype. This is generally used for type conversions of values returned by numpy functions, @@ -145,7 +135,7 @@ def to_default_dtype(arr): dtype = _dtypes._default_types.get(arr.dtype.kind) return arr.astype(_dtypes.canonicalize_dtype(dtype)) if dtype else arr -def with_jax_dtype_defaults(func, use_defaults=True): +def with_jax_dtype_defaults(func: Callable[..., Any], use_defaults: bool = True): """Return a version of a function with outputs that match JAX's default dtypes. 
This is generally used to wrap numpy functions within tests, in order to make @@ -168,7 +158,7 @@ def wrapped(*args, **kwargs): return tree_map(f, result, use_defaults) return wrapped -def is_sequence(x): +def is_sequence(x: Any) -> bool: try: iter(x) except TypeError: @@ -176,14 +166,16 @@ def is_sequence(x): else: return True -def _normalize_tolerance(tol): +def _normalize_tolerance(tol: int | float | ToleranceDict | None) -> ToleranceDict: tol = tol or 0 if isinstance(tol, dict): return {np.dtype(k): v for k, v in tol.items()} else: return dict.fromkeys(_default_tolerance, tol) -def join_tolerance(tol1, tol2): +def join_tolerance( + tol1: int | float | ToleranceDict | None, + tol2: int | float | ToleranceDict | None) -> ToleranceDict: tol1 = _normalize_tolerance(tol1) tol2 = _normalize_tolerance(tol2) out = tol1 @@ -192,7 +184,7 @@ def join_tolerance(tol1, tol2): return out -def check_eq(xs, ys, err_msg=''): +def check_eq(xs: Any, ys: Any, err_msg: str = '') -> None: assert_close = partial(_assert_numpy_allclose, err_msg=err_msg) tree_all(tree_map(assert_close, xs, ys)) @@ -373,10 +365,13 @@ def device_under_test(): def supported_dtypes(): if device_under_test() == "tpu": - types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, - np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, + _dtypes.bfloat16, np.float16, np.float32, np.complex64, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e5m2} + if jaxlib_extension_version < 327: + types -= {_dtypes.int4, _dtypes.uint4} elif device_under_test() == "gpu": types = {np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, @@ -386,10 +381,12 @@ def supported_dtypes(): elif device_under_test() == "METAL": types = {np.int32, np.uint32, np.float32} else: - types = {np.bool_, np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, np.int64, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64, _dtypes.bfloat16, np.float16, np.float32, np.float64, np.complex64, np.complex128} + if jaxlib_extension_version < 327: + types -= {_dtypes.int4, _dtypes.uint4} if not config.enable_x64.value: types -= {np.uint64, np.int64, np.float64, np.complex128} return types @@ -428,14 +425,22 @@ def pjrt_c_api_version_at_least(major_version: int, minor_version: int): return True return pjrt_c_api_versions >= (major_version, minor_version) +def stablehlo_version_at_least(required_version: str): + plugin_version = xla_bridge.backend_stablehlo_version() + if plugin_version is None: + return True + return hlo.get_smaller_version( + ".".join(map(str, plugin_version)), required_version + ) == plugin_version + def get_tpu_version() -> int: if device_under_test() != "tpu": raise ValueError("Device is not TPU") kind = jax.devices()[0].device_kind - if kind.endswith(' lite'): - kind = kind[:-len(' lite')] - assert kind[:-1] == "TPU v", kind - return int(kind[-1]) + match = re.match(r"TPU[^\d]*(\d+)", kind) + if match is None: + raise ValueError(f"Device kind {kind} is not supported") + return int(match.group(1)) def is_device_tpu_at_least(version: int) -> bool: if device_under_test() != "tpu": @@ -1044,165 +1049,6 @@ def sample_product(*args, **kw): """ return parameterized.parameters(*sample_product_testcases(*args, **kw)) -# We use a reader-writer lock to protect test execution. 
Tests that may run in -# parallel acquire a read lock; tests that are not thread-safe acquire a write -# lock. -_test_rwlock = util.Mutex() - -def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): - if getattr(test.__class__, "thread_hostile", False): - _test_rwlock.writer_lock() - try: - test(result) # type: ignore - finally: - _test_rwlock.writer_unlock() - else: - _test_rwlock.reader_lock() - try: - test(result) # type: ignore - finally: - _test_rwlock.reader_unlock() - - -@contextmanager -def thread_unsafe_test(): - """Decorator for tests that are not thread-safe. - - Note: this decorator (naturally) only applies to what it wraps, not to, say, - code in separate setUp() or tearDown() methods. - """ - if TEST_NUM_THREADS.value <= 0: - yield - return - - _test_rwlock.assert_reader_held() - _test_rwlock.reader_unlock() - _test_rwlock.writer_lock() - try: - yield - finally: - _test_rwlock.writer_unlock() - _test_rwlock.reader_lock() - - -def thread_unsafe_test_class(): - "Decorator that marks a TestCase class as thread-hostile." - def f(klass): - assert issubclass(klass, unittest.TestCase), type(klass) - klass.thread_hostile = True - return klass - return f - - -class ThreadSafeTestResult: - """ - Wraps a TestResult to make it thread safe. - - We do this by accumulating API calls and applying them in a batch under a - lock at the conclusion of each test case. - - We duck type instead of inheriting from TestResult because we aren't actually - a perfect implementation of TestResult, and would rather get a loud error - for things we haven't implemented. - """ - def __init__(self, lock: threading.Lock, result: unittest.TestResult): - self.lock = lock - self.test_result = result - self.actions: list[Callable] = [] - - def startTest(self, test: unittest.TestCase): - del test - self.start_time = time.time() - - def stopTest(self, test: unittest.TestCase): - stop_time = time.time() - with self.lock: - # If test_result is an ABSL _TextAndXMLTestResult we override how it gets - # the time. This affects the timing that shows up in the XML output - # consumed by CI. - time_getter = getattr(self.test_result, "time_getter", None) - try: - self.test_result.time_getter = lambda: self.start_time - self.test_result.startTest(test) - for callback in self.actions: - callback() - self.test_result.time_getter = lambda: stop_time - self.test_result.stopTest(test) - finally: - if time_getter is not None: - self.test_result.time_getter = time_getter - - def addSuccess(self, test: unittest.TestCase): - self.actions.append(lambda: self.test_result.addSuccess(test)) - - def addSkip(self, test: unittest.TestCase, reason: str): - self.actions.append(lambda: self.test_result.addSkip(test, reason)) - - def addError(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addError(test, err)) - - def addFailure(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addFailure(test, err)) - - def addExpectedFailure(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addExpectedFailure(test, err)) - - def addDuration(self, test: unittest.TestCase, elapsed): - self.actions.append(lambda: self.test_result.addDuration(test, elapsed)) - - -class JaxTestSuite(unittest.TestSuite): - """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1. - - Caution: this test suite does not run setUpClass or setUpModule methods if - thread parallelism is enabled. 
- """ - - def __init__(self, suite: unittest.TestSuite): - super().__init__(list(suite)) - - def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult: - if TEST_NUM_THREADS.value <= 0: - return super().run(result) - - test_warning_util.install_threadsafe_warning_handlers() - - executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) - lock = threading.Lock() - futures = [] - - def run_test(test): - "Recursively runs tests in a test suite or test case." - if isinstance(test, unittest.TestSuite): - for subtest in test: - run_test(subtest) - else: - test_result = ThreadSafeTestResult(lock, result) - futures.append(executor.submit(_run_one_test, test, test_result)) - - with executor: - run_test(self) - for future in futures: - future.result() - - return result - - -class JaxTestLoader(absltest.TestLoader): - suiteClass = JaxTestSuite - - def getTestCaseNames(self, testCaseClass): - names = super().getTestCaseNames(testCaseClass) - if _TEST_TARGETS.value: - pattern = re.compile(_TEST_TARGETS.value) - names = [name for name in names - if pattern.search(f"{testCaseClass.__name__}.{name}")] - if _EXCLUDE_TEST_TARGETS.value: - pattern = re.compile(_EXCLUDE_TEST_TARGETS.value) - names = [name for name in names - if not pattern.search(f"{testCaseClass.__name__}.{name}")] - return names - def with_config(**kwds): """Test case decorator for subclasses of JaxTestCase""" @@ -1348,15 +1194,15 @@ def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str): else: return self.assertWarnsRegex(DeprecationWarning, message) - def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', + def assertArraysEqual(self, actual, desired, *, check_dtypes=True, err_msg='', allow_object_dtype=False, verbose=True): """Assert that x and y arrays are exactly equal.""" if check_dtypes: - self.assertDtypesMatch(x, y) - x = np.asarray(x) - y = np.asarray(y) + self.assertDtypesMatch(actual, desired) + actual = np.asarray(actual) + desired = np.asarray(desired) - if (not allow_object_dtype) and (x.dtype == object or y.dtype == object): + if (not allow_object_dtype) and (actual.dtype == object or desired.dtype == object): # See https://github.com/jax-ml/jax/issues/17867 raise TypeError( "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. 
" @@ -1366,57 +1212,57 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', # Work around https://github.com/numpy/numpy/issues/18992 with np.errstate(over='ignore'): - np.testing.assert_array_equal(x, y, err_msg=err_msg, + np.testing.assert_array_equal(actual, desired, err_msg=err_msg, verbose=verbose) - def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None, + def assertArraysAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, err_msg=''): - """Assert that x and y are close (up to numerical tolerances).""" - self.assertEqual(x.shape, y.shape) - atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) - rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) + """Assert that actual and desired are close (up to numerical tolerances).""" + self.assertEqual(actual.shape, desired.shape) + atol = max(tolerance(_dtype(actual), atol), tolerance(_dtype(desired), atol)) + rtol = max(tolerance(_dtype(actual), rtol), tolerance(_dtype(desired), rtol)) - _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg) + _assert_numpy_allclose(actual, desired, atol=atol, rtol=rtol, err_msg=err_msg) if check_dtypes: - self.assertDtypesMatch(x, y) + self.assertDtypesMatch(actual, desired) - def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True): + def assertDtypesMatch(self, actual, desired, *, canonicalize_dtypes=True): if not config.enable_x64.value and canonicalize_dtypes: - self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True), - _dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True)) + self.assertEqual(_dtypes.canonicalize_dtype(_dtype(actual), allow_extended_dtype=True), + _dtypes.canonicalize_dtype(_dtype(desired), allow_extended_dtype=True)) else: - self.assertEqual(_dtype(x), _dtype(y)) + self.assertEqual(_dtype(actual), _dtype(desired)) - def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None, + def assertAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, canonicalize_dtypes=True, err_msg=''): - """Assert that x and y, either arrays or nested tuples/lists, are close.""" - if isinstance(x, dict): - self.assertIsInstance(y, dict) - self.assertEqual(set(x.keys()), set(y.keys())) - for k in x.keys(): - self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol, + """Assert that actual and desired, either arrays or nested tuples/lists, are close.""" + if isinstance(actual, dict): + self.assertIsInstance(desired, dict) + self.assertEqual(set(actual.keys()), set(desired.keys())) + for k in actual.keys(): + self.assertAllClose(actual[k], desired[k], check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif is_sequence(x) and not hasattr(x, '__array__'): - self.assertTrue(is_sequence(y) and not hasattr(y, '__array__')) - self.assertEqual(len(x), len(y)) - for x_elt, y_elt in zip(x, y): - self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol, + elif is_sequence(actual) and not hasattr(actual, '__array__'): + self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__')) + self.assertEqual(len(actual), len(desired)) + for actual_elt, desired_elt in zip(actual, desired): + self.assertAllClose(actual_elt, desired_elt, check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif hasattr(x, '__array__') or np.isscalar(x): - self.assertTrue(hasattr(y, '__array__') or np.isscalar(y)) + elif 
hasattr(actual, '__array__') or np.isscalar(actual): + self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired)) if check_dtypes: - self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes) - x = np.asarray(x) - y = np.asarray(y) - self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol, + self.assertDtypesMatch(actual, desired, canonicalize_dtypes=canonicalize_dtypes) + actual = np.asarray(actual) + desired = np.asarray(desired) + self.assertArraysAllClose(actual, desired, check_dtypes=False, atol=atol, rtol=rtol, err_msg=err_msg) - elif x == y: + elif actual == desired: return else: - raise TypeError((type(x), type(y))) + raise TypeError((type(actual), type(desired))) def assertMultiLineStrippedEqual(self, expected, what): """Asserts two strings are equal, after dedenting and stripping each line.""" @@ -1431,7 +1277,6 @@ def assertMultiLineStrippedEqual(self, expected, what): self.assertMultiLineEqual(expected_clean, what_clean, msg=f"Found\n{what}\nExpecting\n{expected}") - @contextmanager def assertNoWarnings(self): with test_warning_util.raise_on_warnings(): @@ -1501,9 +1346,9 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes, + self.assertAllClose(monitored_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) args = args_maker() @@ -1514,7 +1359,7 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, @@ -1523,7 +1368,7 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, args = args_maker() lax_ans = lax_op(*args) numpy_ans = numpy_reference_op(*args) - self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes, + self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol, canonicalize_dtypes=canonicalize_dtypes) @@ -1630,15 +1475,11 @@ def custom_floats(self): _dtypes.float8_e4m3fnuz, _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz, + _dtypes.float8_e3m4, + _dtypes.float8_e4m3, + _dtypes.float8_e8m0fnu, + _dtypes.float4_e2m1fn, ] - if _dtypes.float8_e3m4 is not None: - float_dtypes += [_dtypes.float8_e3m4] - if _dtypes.float8_e4m3 is not None: - float_dtypes += [_dtypes.float8_e4m3] - if _dtypes.float8_e8m0fnu is not None: - float_dtypes += [_dtypes.float8_e8m0fnu] - if _dtypes.float4_e2m1fn is not None: - float_dtypes += [_dtypes.float4_e2m1fn] return self.supported(float_dtypes) @_cached_property diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 4089e047f8b0..32236bb6ae90 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -32,6 +32,7 @@ from jax._src import config from jax._src import core from jax._src import sharding_impls +from jax._src.cloud_tpu_init import is_cloud_tpu_older_than from jax._src.interpreters import mlir from jax._src.lib import tpu from jax._src.lib import xla_client @@ -63,8 +64,22 @@ ) -# This tracks the latest Mosaic IR version with a monthly delay. 
-FWD_COMPAT_IR_VERSION = 3 +# Controls the IR serialization version. Upon incrementing the +# default version in jaxlib/mosaic/dialect/tpu/transforms/serde.cc we must +# continue to use the old serialization version when in forward compatibility +# mode: for 1 month when exporting, or when using an old cloud TPU. +# +# This can be achieved by adding: +#   if ctx.is_forward_compat() or is_cloud_tpu_older_than(<release date>): +#     return <old IR version> +#   return None +# +# We should also add a TODO to remove the conditional one month later. +def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None: +  # TODO(jevinjiang): remove the forward compatibility check after 2025-05-05. +  if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 4, 5): +    return 3 +  return None tpu_custom_call_p = core.Primitive("tpu_custom_call") @@ -128,6 +143,7 @@ class CustomCallBackendConfig: serialization_format: int | None internal_scratch_in_bytes: int | None output_memory_spaces: tuple[MemorySpace | None, ...] | None + disable_bounds_checks: bool # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -178,6 +194,9 @@ def to_json(self) -> bytes: color = memory_space.color if memory_space is not None else -1 config.write(str(color).encode("ascii")) config.write(b"]") + if self.disable_bounds_checks: + config.write(b', "disable_bounds_checks": ') + config.write(str(self.disable_bounds_checks).lower().encode("ascii")) config.write(b"}") # End of custom_call_config. if self.device_type is not None: config.write(b', "device_type": ') @@ -484,13 +503,9 @@ def _lower_mosaic_module_to_asm( module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True - # TODO(apaszke): Remove once the minimum jaxlib version is at least 0.4.37. - if jax.version._version_as_tuple(jax.lib.__version__) < (0, 4, 37): - target_version = "" - else: - target_version = ( - f"target-version={ir_version}" if ir_version is not None else "" - ) + target_version = ( + f"target-version={ir_version}" if ir_version is not None else "" + ) try: pipeline = PassManager.parse( "builtin.module(mosaic-serde{serialize=true " + target_version + "})" ) @@ -562,6 +577,7 @@ def _lower_to_custom_call_config( output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, kernel_name: str | None = None, ir_version: int | None = None, + disable_bounds_checks: bool = False, ) -> CustomCallBackendConfig: lowered_module_asm, ( has_communication, @@ -590,6 +606,7 @@ def _lower_to_custom_call_config( needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, output_memory_spaces=output_memory_spaces, + disable_bounds_checks=disable_bounds_checks, ) @@ -609,6 +626,7 @@ def _lowered_to_custom_call_config( needs_layout_passes: bool, device_type: str | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + disable_bounds_checks: bool = False, ): if has_custom_barrier: if collective_id is None: @@ -639,6 +657,7 @@ def _lowered_to_custom_call_config( serialization_format, internal_scratch_in_bytes, output_memory_spaces, + disable_bounds_checks, ) return config @@ -661,6 +680,7 @@ def lower_module_to_custom_call( serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] 
| None, device_type: str | None, + disable_bounds_checks: bool = False, ) -> Sequence[ir.Value]: config = _lower_to_custom_call_config( module, @@ -675,7 +695,8 @@ def lower_module_to_custom_call( serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, - ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None, + ir_version=get_ir_version(ctx), + disable_bounds_checks=disable_bounds_checks, ) return _tpu_custom_call_lowering( ctx, @@ -704,6 +725,7 @@ def as_tpu_kernel( has_side_effects: bool = False, serialization_format: int | None = 1, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + disable_bounds_checks: bool = False, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" device_type = _get_device_type(module) @@ -720,6 +742,7 @@ def as_tpu_kernel( serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, + disable_bounds_checks=disable_bounds_checks, ) return _as_jax_callable( config, @@ -749,6 +772,7 @@ def lowered_as_tpu_kernel( input_output_aliases: tuple[tuple[int, int], ...] = (), serialization_format: int | None = None, internal_scratch_in_bytes: int | None = None, + disable_bounds_checks: bool = False, ) -> Callable[..., Any]: lowered_module_asm = lowered_module.operation.get_asm( binary=True, enable_debug_info=True @@ -767,6 +791,7 @@ def lowered_as_tpu_kernel( has_communication=has_communication, needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, + disable_bounds_checks=disable_bounds_checks, ) return _as_jax_callable( config, diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index d66cbb912a99..cde9e4a30f99 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -56,8 +56,10 @@ def _path_starts_with(path: str, path_prefix: str) -> bool: return False def include_frame(f: types.FrameType) -> bool: - return not any(_path_starts_with(f.f_code.co_filename, path) - for path in _exclude_paths) + return include_filename(f.f_code.co_filename) + +def include_filename(filename: str) -> bool: + return not any(_path_starts_with(filename, path) for path in _exclude_paths) # When scanning stack traces, we might encounter frames from cpython that are # removed from printed stack traces, such as frames from parts of importlib. We diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 6c7e15a042e5..b73d84b330de 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -21,7 +21,7 @@ from functools import partial import operator as op import textwrap -from typing import Any, NamedTuple, TypeVar, overload +from typing import Any, TypeVar, overload from jax._src import traceback_util from jax._src.lib import pytree @@ -362,6 +362,8 @@ def tree_map(f: Callable[..., Any], def build_tree(treedef: PyTreeDef, xs: Any) -> Any: """Build a treedef from a nested iterable structure + DEPRECATED: Use :func:`jax.tree.unflatten` instead. + Args: treedef: the PyTreeDef structure to build. 
xs: nested iterables matching the arity as the treedef @@ -376,13 +378,6 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any: >>> import jax >>> tree = [(1, 2), {'a': 3, 'b': 4}] >>> treedef = jax.tree.structure(tree) - - Both ``build_tree`` and :func:`jax.tree_util.tree_unflatten` can reconstruct - the tree from new values, but ``build_tree`` takes these values in terms of - a nested rather than flat structure: - - >>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]]) - [(10, 11), {'a': 12, 'b': 13}] >>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13]) [(10, 11), {'a': 12, 'b': 13}] """ @@ -767,42 +762,6 @@ def _simple_entrystr(key: KeyEntry) -> str: return str(key) -# TODO(ivyzheng): remove this after another jaxlib release. -class _RegistryWithKeypathsEntry(NamedTuple): - flatten_with_keys: Callable[..., Any] - unflatten_func: Callable[..., Any] - - -def _register_keypaths( - ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]] -) -> None: - def flatten_with_keys(xs): - children, treedef = _registry[ty].to_iter(xs) - return list(zip(handler(xs), children)), treedef - if ty in _registry: - _registry_with_keypaths[ty] = _RegistryWithKeypathsEntry( - flatten_with_keys, _registry[ty].from_iter - ) - -_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {} - -_register_keypaths( - tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) -) -_register_keypaths( - list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) -) -_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs))) - -_register_keypaths( - collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys()) -) - -_register_keypaths( - collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys()) -) - - @export def register_pytree_with_keys( nodetype: type[T], @@ -872,9 +831,6 @@ def flatten_func_impl(tree): register_pytree_node( nodetype, flatten_func, unflatten_func, flatten_with_keys ) - _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( - flatten_with_keys, unflatten_func - ) @export @@ -1067,11 +1023,6 @@ def register_dataclass( msg += f" Unexpected fields: {unexpected}." raise ValueError(msg) - def flatten_with_keys(x): - meta = tuple(getattr(x, name) for name in meta_fields) - data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields) - return data, meta - def unflatten_func(meta, data): meta_args = tuple(zip(meta_fields, meta)) data_args = tuple(zip(data_fields, data)) @@ -1087,9 +1038,6 @@ def flatten_func(x): none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) - _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( - flatten_with_keys, unflatten_func - ) return nodetype diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 010841b45dd2..ee2422dd2d73 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -47,7 +47,19 @@ @typing.runtime_checkable class SupportsDType(Protocol): @property - def dtype(self) -> DType: ... + def dtype(self, /) -> DType: ... + +class SupportsShape(Protocol): + @property + def shape(self, /) -> tuple[int, ...]: ... + +class SupportsSize(Protocol): + @property + def size(self, /) -> int: ... + +class SupportsNdim(Protocol): + @property + def ndim(self, /) -> int: ... # DTypeLike is meant to annotate inputs to np.dtype that return # a valid JAX dtype. 
It's different than numpy.typing.DTypeLike diff --git a/jax/_src/util.py b/jax/_src/util.py index 0e28aea04b5a..e551c654b005 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -27,7 +27,7 @@ import numpy as np from jax._src import config -from jax._src.lib import xla_client as xc +from jax._src.lib import weakref_lru_cache as _weakref_lru_cache from jax._src.lib import utils as jaxlib_utils logger = logging.getLogger(__name__) @@ -108,11 +108,7 @@ def foreach(f, *args): return None else: - # TODO(phawkins): remove after jaxlib 0.5.2 is the minimum. - if hasattr(jaxlib_utils, 'foreach'): - foreach = jaxlib_utils.foreach - else: - foreach = safe_map + foreach = jaxlib_utils.foreach def unzip2(xys: Iterable[tuple[T1, T2]] @@ -244,61 +240,8 @@ def curry(f): """ return wraps(f)(partial(partial, f)) -# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum. toposort: Callable[[Iterable[Any]], list[Any]] -if hasattr(jaxlib_utils, "topological_sort"): - toposort = partial(jaxlib_utils.topological_sort, "parents") -else: - - def toposort(end_nodes): - if not end_nodes: - return [] - end_nodes = _remove_duplicates(end_nodes) - - child_counts = {} - stack = list(end_nodes) - while stack: - node = stack.pop() - if id(node) in child_counts: - child_counts[id(node)] += 1 - else: - child_counts[id(node)] = 1 - stack.extend(node.parents) - for node in end_nodes: - child_counts[id(node)] -= 1 - - sorted_nodes = [] - childless_nodes = [ - node for node in end_nodes if child_counts[id(node)] == 0 - ] - assert childless_nodes - while childless_nodes: - node = childless_nodes.pop() - sorted_nodes.append(node) - for parent in node.parents: - if child_counts[id(parent)] == 1: - childless_nodes.append(parent) - else: - child_counts[id(parent)] -= 1 - sorted_nodes = sorted_nodes[::-1] - - check_toposort(sorted_nodes) - return sorted_nodes - - def check_toposort(nodes): - visited = set() - for node in nodes: - assert all(id(parent) in visited for parent in node.parents) - visited.add(id(node)) - - def _remove_duplicates(node_list): - seen = set() - out = [] - for n in node_list: - if id(n) not in seen: - seen.add(id(n)) - out.append(n) - return out +toposort = partial(jaxlib_utils.topological_sort, "parents") def split_merge(predicate, xs): @@ -320,7 +263,6 @@ def merge(new_lhs, new_rhs): return lhs, rhs, merge - def _ignore(): return None @@ -362,8 +304,9 @@ def weakref_lru_cache(call: Callable, maxsize=2048, behave similar to `functools.lru_cache`. """ global _weakref_lru_caches - cached_call = xc.weakref_lru_cache( - config.trace_context if trace_context_in_key else _ignore, call, maxsize) + cached_call = _weakref_lru_cache.weakref_lru_cache( + config.trace_context if trace_context_in_key else _ignore, call, maxsize + ) _weakref_lru_caches.add(cached_call) return cached_call @@ -475,8 +418,9 @@ def wrapper(fun: T) -> T: else docstr.format(fun=name, doc=doc, **kwargs)) fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__) fun.__wrapped__ = wrapped - finally: - return fun + except Exception: + pass + return fun return wrapper @@ -497,10 +441,6 @@ def tuple_update(t, idx, val): assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx+1:] -def tuple_replace(tupl, index, item): - # unlike tuple_update, works with negative indices as well - return tupl[:index] + (item,) + tupl[index:][1:] - class HashableFunction: """Decouples function equality and hash from its identity. 
@@ -554,13 +494,8 @@ def __eq__(self, other): self.args == other.args and self.kwargs == other.kwargs) def __hash__(self): - return hash( - ( - self.f.__code__, - self.args, - tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])), - ), - ) + kwargs = tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])) + return hash((self.f.__code__, self.args, kwargs)) def __call__(self, *args, **kwargs): return self.f(*self.args, *args, **self.kwargs, **kwargs) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index be96deab81d8..5fb42c333605 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -32,7 +32,7 @@ import platform as py_platform import threading import traceback -from typing import Any, Union +from typing import Any, Sequence, Union import warnings from jax._src import config @@ -60,6 +60,9 @@ XlaBackend = xla_client.Client +# The platforms in this set will force forward compatibility for lowering. +FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS: set[str] = set() + MIN_COMPUTE_CAPABILITY = 52 _DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo' @@ -86,13 +89,13 @@ 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -_MOCK_NUM_GPU_PROCESSES = config.int_flag( +MOCK_NUM_GPU_PROCESSES = config.int_flag( name="mock_num_gpu_processes", default=0, help="Mock number of JAX processes in GPU client. Value zero turns " "off mocking.", ) -_MOCK_GPU_TOPOLOGY = config.string_flag( +MOCK_GPU_TOPOLOGY = config.string_flag( name="jax_mock_gpu_topology", default="", help='Mock multi-host GPU topology in GPU client. The value should ' @@ -131,7 +134,7 @@ def _log_warning(): warnings.warn( f'TPU backend initialization is taking more than {timer_secs} seconds. ' 'Did you run your code on all TPU hosts? ' - 'See https://jax.readthedocs.io/en/latest/multi_process.html ' + 'See https://docs.jax.dev/en/latest/multi_process.html ' 'for more information.') # Will log a warning after `timer_secs`. @@ -287,7 +290,7 @@ def _check_cuda_compute_capability(devices_to_check): f"Device {idx} has CUDA compute capability {compute_cap/10} which is " "lower than the minimum supported compute capability " f"{MIN_COMPUTE_CAPABILITY/10}. 
See " - "https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu for " + "https://docs.jax.dev/en/latest/installation.html#nvidia-gpu for " "more details", RuntimeWarning ) @@ -429,7 +432,7 @@ def _version_check(name: str, f'following issues with CUDA components:\n' f'{join_str.join(errors)}') -def _get_num_nodes_from_gpu_topology(topology: str) -> int: +def get_num_nodes_from_gpu_topology(topology: str) -> int: try: slices_str, hosts_per_slice_str, _ = topology.split("x", 2) return int(slices_str) * int(hosts_per_slice_str) @@ -438,69 +441,6 @@ def _get_num_nodes_from_gpu_topology(topology: str) -> int: '" x x ' '".') -def make_gpu_client( - *, platform_name: str, visible_devices_flag: config.Flag[str] -) -> xla_client.Client: - visible_devices = visible_devices_flag.value - allowed_devices = None - if visible_devices != "all": - allowed_devices = {int(x) for x in visible_devices.split(",")} - - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) - - use_mock_gpu_client = mock_num_gpu_processes > 0 - num_nodes = (mock_num_gpu_processes if use_mock_gpu_client - else distributed.global_state.num_processes) - - if platform_name == "cuda": - if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): - _check_cuda_versions() - else: - print('Skipped CUDA versions constraints check due to the ' - 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') - - devices_to_check = ( - allowed_devices - if allowed_devices - else range(cuda_versions.cuda_device_count()) - ) - _check_cuda_compute_capability(devices_to_check) - - return xla_client.make_gpu_client( - distributed_client=distributed.global_state.client, - node_id=distributed.global_state.process_id, - num_nodes=num_nodes, - platform_name=platform_name, - allowed_devices=allowed_devices, - mock=use_mock_gpu_client, - ) - - -if hasattr(xla_client, "make_gpu_client"): - register_backend_factory( - "cuda", - partial( - make_gpu_client, - platform_name="cuda", - visible_devices_flag=CUDA_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - register_backend_factory( - "rocm", - partial( - make_gpu_client, - platform_name="rocm", - visible_devices_flag=_ROCM_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - - if hasattr(xla_client, "make_tpu_client"): # TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism, # and then fail loudly on initialization failure. 
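As an illustrative aside (not part of the patch): the now-public `get_num_nodes_from_gpu_topology` in the hunk above parses the mock GPU topology string, whose assumed form is "<slices> x <hosts-per-slice> x <devices-per-host>", and multiplies the first two fields to obtain the mocked process count. A minimal sketch, assuming that convention:

```python
# Minimal sketch, assuming the "AxBxC" topology convention shown above.
from jax._src import xla_bridge

# 2 slices * 4 hosts per slice = 8 mocked processes; the per-host device
# count (the trailing "8") is parsed but not used for the node count.
assert xla_bridge.get_num_nodes_from_gpu_topology("2x4x8") == 8
```

As the following `_options_from_jax_configs` hunk shows, this value feeds `options['num_nodes']` and enables `enable_mock_nccl` when a mock topology is configured.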
@@ -649,9 +589,9 @@ def _options_from_jax_configs(plugin_name): else _ROCM_VISIBLE_DEVICES.value) if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + mock_gpu_topology = MOCK_GPU_TOPOLOGY.value or None + mock_num_processes = (get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else MOCK_NUM_GPU_PROCESSES.value) options['enable_mock_nccl'] = mock_num_processes > 0 if mock_num_processes > 0: options['num_nodes'] = mock_num_processes @@ -696,6 +636,8 @@ def factory(): 'node_id': distributed.global_state.process_id, 'num_nodes': distributed.global_state.num_processes, } + if (slice_index := distributed.global_state.slice_index) is not None: + distribute_options['slice_index'] = slice_index if options is not None: distribute_options.update(updated_options) return xla_client.make_c_api_client( @@ -957,7 +899,7 @@ def _suggest_missing_backends(): warning_msg += ( "This may be due to JAX pre-allocating too much device " "memory, leaving too little for CUDA library initialization. See " - "https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html " + "https://docs.jax.dev/en/latest/gpu_memory_allocation.html " "for more details and potential workarounds." ) warning_msg += "(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)" @@ -1144,6 +1086,16 @@ def backend_xla_version(platform=None) -> int | None: backend = get_backend(platform) return getattr(backend, "xla_version", None) +def backend_stablehlo_version(platform=None) -> Sequence[int] | None: +  """Returns the StableHLO version of the backend. + +  Returns None if the backend does not use PJRT C API or does not have +  stablehlo_current_version in the plugin attributes. This method can be used to +  skip features that are not available before a certain stablehlo_current_version +  if the backend is a plugin and uses stablehlo_current_version. 
+ """ + backend = get_backend(platform) + return getattr(backend, "stablehlo_current_version", None) @lru_cache def local_devices(process_index: int | None = None, diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py index 91895b4e7851..77c0e2ff9910 100644 --- a/jax/_src/xla_metadata.py +++ b/jax/_src/xla_metadata.py @@ -24,6 +24,8 @@ class XlaMetadata: __slots__ = ['val', 'hash'] + val: dict[str, Any] + def __init__(self, val): self.val = val self.hash = hash(tuple(sorted(self.val.items()))) @@ -35,14 +37,19 @@ def __eq__(self, other): return other is not None and self.val == other.val +def filter_nones(d: dict) -> dict: + return {k: v for k, v in d.items() if v is not None} + + def update_metadata(a, b: dict[str, Any]): if not b: return a if a is None or a is config_ext.unset: - return XlaMetadata(b) - val = a.val.copy() + val = {} + else: + val = a.val.copy() val.update(b) - return XlaMetadata(val) + return XlaMetadata(filter_nones(val)) def current_xla_metadata(): diff --git a/jax/collect_profile.py b/jax/collect_profile.py index d1309e0c5bca..b355816772a1 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -91,7 +91,7 @@ def collect_profile(port: int, duration_in_ms: int, host: str, in root_trace_folder.iterdir()] latest_folder = max(trace_folders, key=os.path.getmtime) xplane = next(latest_folder.glob("*.xplane.pb")) - result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer^", {}) + result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer", {}) with gzip.open(str(latest_folder / "remote.trace.json.gz"), "wb") as fp: fp.write(result.encode("utf-8")) diff --git a/jax/core.py b/jax/core.py index 3fd7af440d4a..50b50d935024 100644 --- a/jax/core.py +++ b/jax/core.py @@ -81,154 +81,79 @@ from jax._src import core as _src_core _deprecations = { - # Added 2024-12-16 - "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.ClosedJaxpr), - "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Jaxpr), - "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.JaxprEqn), - "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Literal), - "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Primitive), - "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Token), - "Var": ("jax.core.Var is deprecated. 
Use jax.extend.core.Var instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Var), # Added 2024-12-11 "axis_frame": ("jax.core.axis_frame is deprecated.", _src_core.axis_frame), "AxisName": ("jax.core.AxisName is deprecated.", _src_core.AxisName), - "AxisSize": ("jax.core.AxisSize is deprecated.", _src_core.AxisSize), "ConcretizationTypeError": ("jax.core.ConcretizationTypeError is deprecated; " "use jax.errors.ConcretizationTypeError.", _src_core.ConcretizationTypeError), - "EvalTrace": ("jax.core.EvalTrace is deprecated.", _src_core.EvalTrace), - "InDBIdx": ("jax.core.InDBIdx is deprecated.", _src_core.InDBIdx), - "InputType": ("jax.core.InputType is deprecated.", _src_core.InputType), - "MapPrimitive": ("jax.core.MapPrimitive is deprecated.", _src_core.MapPrimitive), - "OpaqueTraceState": ("jax.core.OpaqueTraceState is deprecated.", _src_core.OpaqueTraceState), - "OutDBIdx": ("jax.core.OutDBIdx is deprecated.", _src_core.OutDBIdx), - "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING is deprecated.", - _src_core.TRACER_LEAK_DEBUGGER_WARNING), "call_p": ("jax.core.call_p is deprecated. Use jax.extend.core.primitives.call_p", _src_core.call_p), "closed_call_p": ("jax.core.closed_call_p is deprecated. Use jax.extend.core.primitives.closed_call_p", _src_core.closed_call_p), - "concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.abstractify), - "dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents), - "escaped_tracer_error": ("jax.core.escaped_tracer_error is deprecated.", - _src_core.escaped_tracer_error), - "extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.", - _src_core.extend_axis_env_nd), "get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval), - "get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent), - "join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects), - "leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.", - _src_core.leaked_tracer_error), - "maybe_find_leaked_tracers": ("jax.core.maybe_find_leaked_tracers is deprecated.", - _src_core.maybe_find_leaked_tracers), - "raise_to_shaped_mappings": ("jax.core.raise_to_shaped_mappings is deprecated." - " It is unused as of jax v0.4.36.", - _src_core.raise_to_shaped_mappings), - "reset_trace_state": ("jax.core.reset_trace_state is deprecated.", - _src_core.reset_trace_state), - "str_eqn_compact": ("jax.core.str_eqn_compact is deprecated.", _src_core.str_eqn_compact), - "substitute_vars_in_output_ty": ("jax.core.substitute_vars_in_output_ty is deprecated.", - _src_core.substitute_vars_in_output_ty), "trace_state_clean": ("jax.core.trace_state_clean is deprecated.", _src_core.trace_state_clean), "typecheck": ("jax.core.typecheck is deprecated.", _src_core.typecheck), - "typecompat": ("jax.core.typecompat is deprecated.", _src_core.typecompat), "typematch": ("jax.core.typematch is deprecated.", _src_core.typematch), - "used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr is deprecated.", - _src_core.used_axis_names_jaxpr), # Added 2024-12-10 - "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.full_lower), - "jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.jaxpr_as_fun), - "lattice_join": ("jax.core.lattice_join is deprecated. 
It is a no-op as of JAX v0.4.36.", - _src_core.lattice_join), - "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.raise_to_shaped), - # Finalized 2024-12-11; remove after 2025-3-11 - "check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None), - "check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None), - "check_valid_jaxtype": ( - ("jax.core.check_valid_jaxtype was removed in JAX v0.4.38. Instead, you can manually" - " raise an error if core.valid_jaxtype() returns False."), - None), - "non_negative_dim": ( - "jax.core.non_negative_dim was removed in JAX v0.4.38. Use max_dim(..., 0).", None, - ), - # Finalized 2024-09-25; remove after 2024-12-25 - "pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None), - "pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None), - "pp_eqn_rules": ("jax.core.pp_eqn_rules was removed in JAX v0.4.34.", None), - "pp_eqns": ("jax.core.pp_eqns was removed in JAX v0.4.34.", None), - "pp_jaxpr": ("jax.core.pp_jaxpr was removed in JAX v0.4.34.", None), - "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range was removed in JAX v0.4.34.", None), - "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton was removed in JAX v0.4.34.", None), - "pp_jaxprs": ("jax.core.pp_jaxprs was removed in JAX v0.4.34.", None), - "pp_kv_pair": ("jax.core.pp_kv_pair was removed in JAX v0.4.34.", None), - "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), - "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), - "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), + "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", None), + "jaxpr_as_fun": ("jax.core.jaxpr_as_fun was removed in JAX v0.6.0. Use jax.extend.core.jaxpr_as_fun instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", + None), + "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", None), + # Finalized 2025-03-25 for JAX v0.6.0; remove after 2025-06-25 + "AxisSize": ("jax.core.AxisSize was removed in JAX v0.6.0.", None), + "ClosedJaxpr": ("jax.core.ClosedJaxpr was removed in JAX v0.6.0. Use jax.extend.core.ClosedJaxpr instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "EvalTrace": ("jax.core.EvalTrace was removed in JAX v0.6.0.", None), + "InDBIdx": ("jax.core.InDBIdx was removed in JAX v0.6.0.", None), + "InputType": ("jax.core.InputType was removed in JAX v0.6.0.", None), + "Jaxpr": ("jax.core.Jaxpr was removed in JAX v0.6.0. Use jax.extend.core.Jaxpr instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "JaxprEqn": ("jax.core.JaxprEqn was removed in JAX v0.6.0. Use jax.extend.core.JaxprEqn instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "Literal": ("jax.core.Literal was removed in JAX v0.6.0. Use jax.extend.core.Literal instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "MapPrimitive": ("jax.core.MapPrimitive was removed in JAX v0.6.0.", None), + "OpaqueTraceState": ("jax.core.OpaqueTraceState was removed in JAX v0.6.0.", None), + "OutDBIdx": ("jax.core.OutDBIdx was removed in JAX v0.6.0.", None), + "Primitive": ("jax.core.Primitive was removed in JAX v0.6.0. 
Use jax.extend.core.Primitive instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "Token": ("jax.core.Token was removed in JAX v0.6.0. Use jax.extend.core.Token instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING was removed in JAX v0.6.0.", None), + "Var": ("jax.core.Var was removed in JAX v0.6.0. Use jax.extend.core.Var instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "concrete_aval": ("jax.core.concrete_aval was removed in JAX v0.6.0.", None), + "dedup_referents": ("jax.core.dedup_referents was removed in JAX v0.6.0.", None), + "escaped_tracer_error": ("jax.core.escaped_tracer_error was removed in JAX v0.6.0.", None), + "extend_axis_env_nd": ("jax.core.extend_axis_env_nd was removed in JAX v0.6.0.", None), + "get_referent": ("jax.core.get_referent was removed in JAX v0.6.0.", None), + "join_effects": ("jax.core.join_effects was removed in JAX v0.6.0.", None), + "leaked_tracer_error": ("jax.core.leaked_tracer_error was removed in JAX v0.6.0.", None), + "maybe_find_leaked_tracers": ("jax.core.maybe_find_leaked_tracers was removed in JAX v0.6.0.", None), + "raise_to_shaped": ("jax.core.raise_to_shaped was removed in JAX v0.6.0. It is a no-op as of JAX v0.4.36.", None), + "raise_to_shaped_mappings": ("jax.core.raise_to_shaped_mappings was removed in JAX v0.6.0." + " It is unused as of jax v0.4.36.", None), + "reset_trace_state": ("jax.core.reset_trace_state was removed in JAX v0.6.0.", None), + "str_eqn_compact": ("jax.core.str_eqn_compact was removed in JAX v0.6.0.", None), + "substitute_vars_in_output_ty": ("jax.core.substitute_vars_in_output_ty was removed in JAX v0.6.0.", None), + "typecompat": ("jax.core.typecompat was removed in JAX v0.6.0.", None), + "used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr was removed in JAX v0.6.0.", None), } import typing if typing.TYPE_CHECKING: AxisName = _src_core.AxisName - AxisSize = _src_core.AxisSize - ClosedJaxpr = _src_core.ClosedJaxpr ConcretizationTypeError = _src_core.ConcretizationTypeError - EvalTrace = _src_core.EvalTrace - InDBIdx = _src_core.InDBIdx - InputType = _src_core.InputType - Jaxpr = _src_core.Jaxpr - JaxprEqn = _src_core.JaxprEqn - Literal = _src_core.Literal - MapPrimitive = _src_core.MapPrimitive - OpaqueTraceState = _src_core.OpaqueTraceState - OutDBIdx = _src_core.OutDBIdx - Primitive = _src_core.Primitive - Token = _src_core.Token - TRACER_LEAK_DEBUGGER_WARNING = _src_core.TRACER_LEAK_DEBUGGER_WARNING - Var = _src_core.Var axis_frame = _src_core.axis_frame call_p = _src_core.call_p closed_call_p = _src_core.closed_call_p - concrete_aval = _src_core.abstractify - dedup_referents = _src_core.dedup_referents - escaped_tracer_error = _src_core.escaped_tracer_error - extend_axis_env_nd = _src_core.extend_axis_env_nd - full_lower = _src_core.full_lower get_type = _src_core.get_aval - get_referent = _src_core.get_referent - jaxpr_as_fun = _src_core.jaxpr_as_fun - join_effects = _src_core.join_effects - lattice_join = _src_core.lattice_join - leaked_tracer_error = _src_core.leaked_tracer_error - maybe_find_leaked_tracers = _src_core.maybe_find_leaked_tracers - raise_to_shaped = _src_core.raise_to_shaped - raise_to_shaped_mappings = _src_core.raise_to_shaped_mappings - reset_trace_state = _src_core.reset_trace_state - str_eqn_compact = _src_core.str_eqn_compact - substitute_vars_in_output_ty = 
_src_core.substitute_vars_in_output_ty trace_state_clean = _src_core.trace_state_clean typecheck = _src_core.typecheck - typecompat = _src_core.typecompat typematch = _src_core.typematch - used_axis_names_jaxpr = _src_core.used_axis_names_jaxpr else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/dlpack.py b/jax/dlpack.py index a65496ec0cbf..d008608fc356 100644 --- a/jax/dlpack.py +++ b/jax/dlpack.py @@ -12,8 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import jax._src.dlpack +import jax._src.deprecations + from jax._src.dlpack import ( - to_dlpack as to_dlpack, from_dlpack as from_dlpack, SUPPORTED_DTYPES as SUPPORTED_DTYPES, ) + +_deprecations = { + "to_dlpack": ( + ( + "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0. Please use the newer DLPack API based on" + " __dlpack__ and __dlpack_device__ instead. Typically, you can pass" + " a JAX array directly to the `from_dlpack` function of another" + " framework without using `to_dlpack`." + ), + jax._src.dlpack.to_dlpack, + ), +} + + +import typing as _typing + +if _typing.TYPE_CHECKING: + to_dlpack = jax._src.dlpack.to_dlpack +else: + __getattr__ = jax._src.deprecations.deprecation_getattr( + __name__, _deprecations + ) +del _typing diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 375d058d0edc..6c37635df1b0 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -19,6 +19,10 @@ enable_x64 as enable_x64, disable_x64 as disable_x64, ) +from jax._src.api import ( + saved_input_vjp as saved_input_vjp, + si_vjp as si_vjp +) from jax._src.callback import ( io_callback as io_callback ) diff --git a/jax/experimental/array_serialization/BUILD b/jax/experimental/array_serialization/BUILD index ab1ee3fd393e..84e5b9300912 100644 --- a/jax/experimental/array_serialization/BUILD +++ b/jax/experimental/array_serialization/BUILD @@ -45,7 +45,7 @@ jax_multiplatform_test( name = "serialization_test", srcs = ["serialization_test.py"], enable_configs = [ - "tpu_v3_2x2", + "tpu_v3_x4", ], deps = [ "//jax:experimental", diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 9f4539fc63c8..280c2f58b348 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -26,7 +26,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src import array -from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding +from jax._src.sharding_impls import NamedSharding, GSPMDSharding, SingleDeviceSharding from jax.sharding import PartitionSpec as P from jax.experimental.array_serialization import serialization from jax.experimental.layout import Layout, DeviceLocalLayout as DLL @@ -620,7 +620,7 @@ def test_deserialization_with_int4(self): ckpt_dir = pathlib.Path(self.create_tempdir('test_ckpt').full_path) # Run serialization. 
- sharding = jax.sharding.GSPMDSharding.get_replicated(jax.devices()) + sharding = GSPMDSharding.get_replicated(jax.devices()) tspecs = jax.tree_util.tree_map( serialization.get_tensorstore_spec, [ckpt_dir] ) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 4e1dc4b8f493..0d40938a85c4 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -16,6 +16,7 @@ from typing import Any, Callable +import jax from jax._src import core from jax._src import source_info_util from jax._src import api_util @@ -32,19 +33,31 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +Array = Any JaxVal = Any Pytree = Any +ReadWrite = pe.ReadWrite +Append = pe.Append + register = api_util.register_class_with_attrs +dne_sentinel = pe.dne_sentinel -def jax_getattr(obj: Any, attr: str): +def jax_getattr(obj: Any, attr: str) -> Pytree: with core.take_current_trace() as t: return t.process_getattr(obj, attr) -def jax_setattr(obj: Any, attr: str, val: Pytree): +def jax_setattr(obj: Any, attr: str, val: Pytree) -> None: with core.take_current_trace() as t: return t.process_setattr(obj, attr, val) +def jax_appendattr(obj: Any, attr: str, val: Array) -> None: + return jax_extendattr(obj, attr, jax.numpy.expand_dims(val, 0)) + +def jax_extendattr(obj: Any, attr: str, val: Array) -> None: + with core.take_current_trace() as t: + return t.process_extendattr(obj, attr, val) + def _getattr_impl(_, obj, attr): return getattr(obj, attr) core.EvalTrace.process_getattr = _getattr_impl @@ -53,6 +66,25 @@ def _setattr_impl(_, obj, attr, val): setattr(obj, attr, val) core.EvalTrace.process_setattr = _setattr_impl +def _extendattr_impl(_, obj, attr, val): + cur = getattr(obj, attr, dne_sentinel) + if cur is dne_sentinel: + new = val + else: + _check_append_type_agreement(obj, attr, core.typeof(cur), core.typeof(val)) + new = jax.numpy.concatenate([cur, val]) + setattr(obj, attr, new) +core.EvalTrace.process_extendattr = _extendattr_impl + +def _check_append_type_agreement(_, attr, curtype, valtype): + expected = core.mapped_aval(curtype.shape[0], 0, curtype) + got = core.mapped_aval(valtype.shape[0], 0, valtype) + if not core.typematch(expected, got): + raise TypeError( + f"can only append to attr {attr} with values of trailing shape " + f"{expected.str_short()}, but appendattr got value of type " + f"{valtype.str_short()} which has trailing shape {got.str_short()}.") + def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): frame = trace.frame @@ -64,13 +96,16 @@ def new_tracer(x): frame.tracers.append(tracer) return tracer - if (obj, attr) not in frame.attrs_tracked: - init_val = getattr(obj, attr) + if (obj, attr, Append) in frame.attrs_tracked: + raise TypeError(f"can't read/write to append-only attr {attr}") + + if (obj, attr, ReadWrite) not in frame.attrs_tracked: + init_val = getattr(obj, attr, dne_sentinel) frame.attrs_inits.append(init_val) init_vals, init_tree = tree_flatten(init_val) tracers = map(new_tracer, init_vals) setattr(obj, attr, tree_unflatten(init_tree, tracers)) - frame.attrs_tracked.append((obj, attr)) + frame.attrs_tracked.append((obj, attr, ReadWrite)) pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked def _getattr_staging(trace, obj, attr): @@ -83,6 +118,27 @@ def _setattr_staging(trace, obj, attr, val): setattr(obj, attr, val) pe.DynamicJaxprTrace.process_setattr = _setattr_staging +def _extendattr_staging(trace, obj, attr, val): + frame = trace.frame + + if (obj, attr, ReadWrite) in frame.attrs_tracked: + raise TypeError("can't append 
to read/write-only attr {attr}") + + first_write = (obj, attr, Append) not in frame.attrs_tracked + init_val = getattr(obj, attr, dne_sentinel) + if init_val is not dne_sentinel: + _check_append_type_agreement(obj, attr, core.typeof(init_val), core.typeof(val)) + if first_write: + frame.attrs_inits.append(init_val) + frame.attrs_tracked.append((obj, attr, Append)) + tracer = val + else: + assert init_val is not dne_sentinel + with core.set_current_trace(trace): + tracer = jax.numpy.concatenate([init_val, val]) + setattr(obj, attr, tracer) +pe.DynamicJaxprTrace.process_extendattr = _extendattr_staging + def jvp(f, primals, tangents, attr_tangents): attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index b855bba48abb..81db9b965e7c 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -28,6 +28,15 @@ def colocated_cpu_devices( devices: Sequence[jax.Device], ) -> Sequence[jax.Device]: """Finds CPU devices colocated with the given devices.""" + if not isinstance(devices, tuple): + devices = tuple(devices) + return _colocated_cpu_devices_cached(devices) + + +@jax._src.util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_devices_cached( + devices: tuple[jax.Device, ...], +) -> Sequence[jax.Device]: cpu_devices_by_colocation_id = collections.defaultdict(list) for device in devices[0].client._get_all_devices(): # pylint: disable=protected-access if device.device_kind == "cpu": diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index effca1fe77b7..b7188d9da7ad 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -279,7 +279,7 @@ def _make_async_execution_fun( ) -@jax.util.cache(max_size=None) +@jax._src.util.cache(max_size=None) def _get_specialized_func( info: FunctionInfo, specialization: Specialization ) -> Callable[..., Any]: diff --git a/jax/experimental/colocated_python/obj.py b/jax/experimental/colocated_python/obj.py index b1e7a0b1eade..d7d40e88f925 100644 --- a/jax/experimental/colocated_python/obj.py +++ b/jax/experimental/colocated_python/obj.py @@ -58,7 +58,7 @@ def pop_instance(self, uid: int) -> set[jax.Device]: SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry() -@jax.util.cache(max_size=4096) +@jax._src.util.cache(max_size=4096) def _update_instance_devices( uid: int, shardings: tuple[jax.sharding.Sharding, ...] ) -> None: diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py index 1ca29ab12660..a8a62d78359f 100644 --- a/jax/experimental/colocated_python/serialization.py +++ b/jax/experimental/colocated_python/serialization.py @@ -35,7 +35,7 @@ DeviceList = xc.DeviceList -@jax.util.cache(max_size=None) +@jax._src.util.cache(max_size=None) def _get_cpu_device_map() -> dict[int, jax.Device]: """Returns a map from a device id to a matching device.""" cpu_device_map: dict[int, jax.Device] = {} diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 7d60f62e230f..4dcd9f66a961 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -16,7 +16,7 @@ .. warning:: The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the - `new JAX external callbacks `_ + `new JAX external callbacks `_ See https://github.com/jax-ml/jax/issues/20385. 
""" diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 0d827fbcc7a5..cb1c97bc7b7c 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -138,7 +138,7 @@ f_tf_graph = tf.function(f_tf, autograph=False) ``` Note that when using the default native serialization, the target JAX function -must be jittable (see [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)). +must be jittable (see [JAX - The Sharp Bits](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)). In the native serialization mode, under TensorFlow eager the whole JAX function executes as one op. @@ -461,7 +461,7 @@ presence of shape polymorphism, some dimensions may be dimension variables. The `polymorphic_shapes` parameter must be either `None`, or a pytree of shape specifiers corresponding to the pytree of arguments. (A value `None` for `polymorphic_shapes` is equivalent to a list of `None`. -See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).) +See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).) A shape specifier is combined with a `TensorSpec` as follows: * A shape specifier of `None` means that the shape is given @@ -568,6 +568,7 @@ because the shape abstraction that JAX tracing uses is given by the actual arguments are more specific and would actually work. Also, + ```python jax2tf.convert(lambda x: jnp.matmul(x, x), polymorphic_shapes=["(v, 4)"])(np.ones((4, 4))) @@ -808,6 +809,7 @@ TypeError: add got incompatible shapes for broadcasting: (a,), (floordiv(b, 2),) ``` You can fix this by adding a constraint: + ```python jax2tf.convert(lambda x, y: x + y[:y.shape[0] // 2], polymorphic_shapes=("a", "b"), @@ -826,12 +828,12 @@ For example, the following code will fail because `a1` and `a2` use different scopes (created by `export.symbolic_shape`): -````python +```python a1, = export.symbolic_shape("a,") a2, = export.symbolic_shape("a,", constraints=("a >= 8",)) a1 + a2 -```` +``` The symbolic expressions that originate from a single call to `export.symbolic_shape` share a scope and @@ -1024,7 +1026,7 @@ always behaves like the JAX function. JAX interprets the type of Python scalars differently based on `JAX_ENABLE_X64` flag. (See -[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).) +[JAX - The Sharp Bits: Double (64bit) precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).) In the default configuration, the flag is unset, and JAX interprets Python constants as 32-bit, e.g., the type of `3.14` is `float32`. This is also what @@ -1086,7 +1088,7 @@ Applies to both native and non-native serialization. `jax2tf` can lower functions with arguments and results that are nested collections (tuples, lists, dictionaries) of numeric values or JAX arrays -([pytrees](https://jax.readthedocs.io/en/latest/pytrees.html)). The +([pytrees](https://docs.jax.dev/en/latest/pytrees.html)). The resulting TensorFlow function will take the same kind of arguments except the leaves can be numeric values or TensorFlow tensors (`tf.Tensor`, `tf.TensorSpec`, `tf.Variable`). @@ -1285,7 +1287,7 @@ per PRNG operation. 
The "unsafe" part is that it doesn't guarantee determinism across JAX/XLA versions, and the quality of random streams it generates from different keys is less well understood. Nevertheless, this should be fine for most inference/serving cases. -See more details in the [JAX PRNG documentation](https://jax.readthedocs.io/en/latest/jax.random.html?highlight=unsafe_rbg#advanced-rng-configuration). +See more details in the [JAX PRNG documentation](https://docs.jax.dev/en/latest/jax.random.html?highlight=unsafe_rbg#advanced-rng-configuration). ### SavedModel supports only first-order gradients diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 98c1c20cd6e5..bb2af54025bc 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -41,11 +41,14 @@ from jax._src import effects from jax._src import util from jax._src.lib import xla_client +from jax._src.lib import xla_extension as _xla +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo from jax.experimental.jax2tf import jax2tf as jax2tf_internal from jax._src.interpreters import mlir +import ml_dtypes import numpy as np import tensorflow as tf @@ -345,8 +348,7 @@ def _arg_jax_to_tf(arg_jax): if (isinstance(arg_jax, jax.Array) and list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES): - arg_dlpack = jax.dlpack.to_dlpack(arg_jax) - return tf.experimental.dlpack.from_dlpack(arg_dlpack) + return tf.experimental.dlpack.from_dlpack(arg_jax.__dlpack__()) # The following avoids copies to the host on CPU, always for Array # and even for ndarray if they are sufficiently aligned. # TODO(necula): on TPU this copies to the host! 
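An illustrative aside (not part of the patch): the `_arg_jax_to_tf` change above replaces the deprecated `jax.dlpack.to_dlpack` round-trip with the standard `__dlpack__` protocol, matching the `jax/dlpack.py` deprecation earlier in this patch. A minimal sketch of the new-style interop, assuming a TensorFlow install with `tf.experimental.dlpack` available:

```python
# Minimal sketch, assuming TensorFlow with DLPack support is installed.
import jax.numpy as jnp
import tensorflow as tf

x = jnp.arange(4, dtype=jnp.float32)

# Deprecated path: capsule = jax.dlpack.to_dlpack(x)
# New path: let the array produce its own DLPack capsule via __dlpack__,
# exactly as the call_tf change above does.
t = tf.experimental.dlpack.from_dlpack(x.__dlpack__())
# t is a tf.Tensor of shape (4,), typically created without a host copy
# when producer and consumer share a device.
```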
@@ -468,6 +470,47 @@ def is_fully_known_shape(s): call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval) +def _mlir_type_to_numpy_dtype(type: ir.Type) -> np.dtype: + """Converts an MLIR scalar type to a NumPy dtype.""" + + if ir.IntegerType.isinstance(type): + type = ir.IntegerType(type) + width = type.width + if width == 1: + return np.dtype(np.bool_) + elif width == 8: + return np.dtype(np.uint8 if type.is_unsigned else np.int8) + elif width == 16: + return np.dtype(np.uint16 if type.is_unsigned else np.int16) + elif width == 32: + return np.dtype(np.uint32 if type.is_unsigned else np.int32) + elif width == 64: + return np.dtype(np.uint64 if type.is_unsigned else np.int64) + else: + raise ValueError(f"Unsupported integer width: {width}") + + elif ir.F16Type.isinstance(type): + return np.dtype(np.float16) + elif ir.F32Type.isinstance(type): + return np.dtype(np.float32) + elif ir.F64Type.isinstance(type): + return np.dtype(np.float64) + elif ir.BF16Type.isinstance(type): + return np.dtype(ml_dtypes.bfloat16) + + elif ir.ComplexType.isinstance(type): + element_type = ir.ComplexType(type).element_type + if ir.F32Type.isinstance(element_type): + return np.dtype(np.complex64) + elif ir.F64Type.isinstance(element_type): + return np.dtype(np.complex128) + else: + raise ValueError(f"Unsupported complex element type: {element_type}") + + else: + raise TypeError(f"Unsupported MLIR type for NumPy conversion: {type}") + + def _call_tf_lowering( ctx: mlir.LoweringRuleContext, *args_op, @@ -555,33 +598,12 @@ def convert_to_spec(x): "\n\nCaught TensorFlow exception: " + str(e)) raise ValueError(msg) from e - xla_comp = xla_client.XlaComputation(func_tf_hlo) - - # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode - def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: - if not res_shape.is_static(): - msg = ("Compiled TensorFlow function has dynamic output shape " + - f"{res_shape}. call_tf can used " + - "in a staged context (under jax.jit, lax.scan, etc.) only with " + - "compilable functions with static output shapes. " + - "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") - raise ValueError(msg) - - res_dtype = res_shape.numpy_dtype() - jax_res_dtype = dtypes.canonicalize_dtype(res_dtype) - return core.ShapedArray(res_shape.dimensions(), jax_res_dtype) - - result_shape = xla_comp.program_shape().result_shape() - if not result_shape.is_tuple(): - # TF does not wrap singletons as tuples, but JAX expects tuples because - # call_tf is a multiple_results primitive. 
- result_shapes = (result_shape,) + if jaxlib_extension_version >= 324: + stablehlo = _xla.mlir.hlo_to_stablehlo(func_tf_hlo) else: - result_shapes = result_shape.tuple_shapes() # type: ignore - - result_avals = tuple(map(canonical_res_aval, result_shapes)) - - submodule = mlir.xla_computation_to_mlir_module(xla_comp) + xla_comp = xla_client.XlaComputation(func_tf_hlo) + stablehlo = _xla.mlir.xla_computation_to_mlir_module(xla_comp) + submodule = ir.Module.parse(stablehlo) symtab = ir.SymbolTable(submodule.operation) callee_result_types = symtab["main"].type.results fn = mlir.merge_mlir_modules(ctx.module_context.module, @@ -600,10 +622,26 @@ def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: ) outputs = [] - for op, res_aval, res_shape in zip(flat_results, result_avals, - result_shapes): - if res_aval.dtype != res_shape.numpy_dtype(): - op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result + for op, res_type in zip(flat_results, callee_result_types): + if not res_type.has_static_shape: + msg = ( + "Compiled TensorFlow function has dynamic output shape " + + f"{res_type}. call_tf can used in a staged context (under jax.jit," + " lax.scan, etc.) only with compilable functions with static" + " output shapes. See" + " https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + " for a discussion." + ) + raise ValueError(msg) + + res_dtype = _mlir_type_to_numpy_dtype(res_type.element_type) + # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode + jax_res_dtype = dtypes.canonicalize_dtype(res_dtype) + if res_dtype != jax_res_dtype: + op = hlo.ConvertOp( + mlir.aval_to_ir_type(core.ShapedArray(res_type.shape, jax_res_dtype)), + op, + ).result outputs.append(op) return outputs diff --git a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md index 24a1d62ee67e..af092a218805 100644 --- a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md +++ b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md @@ -24,7 +24,7 @@ partial support. For a detailed description of these XLA ops, please see the [XLA Operation Semantics documentation](https://www.tensorflow.org/xla/operation_semantics). -| XLA ops ([documentation](https://www.tensorflow.org/xla/operation_semantics)) | JAX primitive(s) ([documentation](https://jax.readthedocs.io/en/latest/jax.lax.html)) | Supported | +| XLA ops ([documentation](https://www.tensorflow.org/xla/operation_semantics)) | JAX primitive(s) ([documentation](https://docs.jax.dev/en/latest/jax.lax.html)) | Supported | | ------- | ---------------- | ------- | | XlaDot | `lax.dot_general` | Full | | XlaDynamicSlice | `lax.dynamic_slice` | Full | @@ -47,7 +47,7 @@ support and which not. ### XlaConv JAX convolutions are done using -[`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html). +[`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html). ``` lax.conv_general_dilated( @@ -88,7 +88,7 @@ instance, parallelization primitives `vmap` and `pmap` use gather to specify a batch dimension, and it is used for slices or multidimensional indexing as well, e.g. `x[0, 1]`, `x[:, :1]`, or `x[[0], [1]]`. 
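As a quick illustration of the point above (not part of the original doc), the gather lowering is visible directly in the jaxpr JAX produces for advanced indexing; the index arrays below stand in for the `x[[0], [1]]` example:

```python
# Illustrative only: advanced indexing like x[[0], [1]] is implemented via
# lax.gather, which is the primitive the XlaGather row above refers to.
import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)
idx = (jnp.array([0]), jnp.array([1]))
print(jax.make_jaxpr(lambda a: a[idx])(x))  # the printed jaxpr contains a `gather` equation
```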
-The signature of [`lax.gather`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html#jax.lax.gather) +The signature of [`lax.gather`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.gather.html#jax.lax.gather) is as follows: ``` @@ -128,7 +128,7 @@ All other cases of `lax.gather` are currently not supported. ### XlaReduceWindow -The signature of [`lax.reduce_window`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.reduce_window.html) +The signature of [`lax.reduce_window`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reduce_window.html) is as follows: ``` diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 7f98ce433815..4c2f35a95c57 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -272,7 +272,7 @@ def convert(fun_jax: Callable, should be `None` (monomorphic argument), or a Python object with the same pytree structure as the argument. See [how optional parameters are matched to - arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). + arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A shape specification for an array argument should be an object `PolyShape(dim0, dim1, ..., dimn)` @@ -1666,17 +1666,18 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray], tf_impl_with_avals[lax.integer_pow_p] = _integer_pow -tf_impl[lax.exp_p] = tf.math.exp -tf_impl[lax_internal.exp2_p] = lambda x: \ - tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x)) -tf_impl[lax.expm1_p] = tf.math.expm1 -tf_impl[lax.log_p] = tf.math.log -tf_impl[lax.log1p_p] = tf.math.log1p -tf_impl[lax.tan_p] = tf.math.tan -tf_impl[lax.tanh_p] = tf.math.tanh -tf_impl[lax.sin_p] = tf.math.sin +tf_impl[lax.exp_p] = lambda x, accuracy: tf.math.exp(x) +tf_impl[lax_internal.exp2_p] = lambda x, accuracy: tf.math.exp( + tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x) +) +tf_impl[lax.expm1_p] = lambda x, accuracy: tf.math.expm1(x) +tf_impl[lax.log_p] = lambda x, accuracy: tf.math.log(x) +tf_impl[lax.log1p_p] = lambda x, accuracy: tf.math.log1p(x) +tf_impl[lax.tan_p] = lambda x, accuracy: tf.math.tan(x) +tf_impl[lax.tanh_p] = lambda x, accuracy: tf.math.tanh(x) +tf_impl[lax.sin_p] = lambda x, accuracy: tf.math.sin(x) tf_impl[lax.sinh_p] = tf.math.sinh -tf_impl[lax.cos_p] = tf.math.cos +tf_impl[lax.cos_p] = lambda x, accuracy: tf.math.cos(x) tf_impl[lax.cosh_p] = tf.math.cosh tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( lax_internal.atan_impl, multiple_results=False) @@ -1706,11 +1707,11 @@ def _atan2(y, x, **kwargs): tf_impl[lax.asin_p] = tf.math.asin tf_impl[lax.acos_p] = tf.math.acos -tf_impl[lax.sqrt_p] = tf.math.sqrt +tf_impl[lax.sqrt_p] = lambda x, accuracy: tf.math.sqrt(x) tf_impl[lax.square_p] = tf.math.square -tf_impl[lax.rsqrt_p] = tf.math.rsqrt +tf_impl[lax.rsqrt_p] = lambda x, accuracy: tf.math.rsqrt(x) -def _cbrt(x): +def _cbrt(x, accuracy): return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3) tf_impl[lax.cbrt_p] = _cbrt @@ -2822,7 +2823,8 @@ def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval): multiple_results=False, extra_name_stack="random_gamma") -def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm) -> Sequence[TfVal]: +def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm, + out_sharding) -> Sequence[TfVal]: is_uint32_key = key.dtype == _to_tf_dtype(jnp.uint32) if is_uint32_key: key = tf.reshape(key, (2, 2)) @@ -3171,12 +3173,11 @@ 
def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal: lax_control_flow._scan_impl, extra_name_stack="scan") -tf_impl_with_avals[ad_checkpoint.remat_p] = \ - _convert_jax_impl(partial(ad_checkpoint.remat_expansion, - # TODO: jax2tf cannot discriminate by platform - is_gpu_platform=False), - multiple_results=True, - extra_name_stack="checkpoint") +tf_impl_with_avals[ad_checkpoint.remat_p] = _convert_jax_impl( + ad_checkpoint.remat_expansion, + multiple_results=True, + extra_name_stack="checkpoint", +) tf_impl[ad_checkpoint.name_p] = lambda x, *, name: x diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index bea2b76cb7cf..b40b1a6d5571 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -832,11 +832,7 @@ def f(x1): arg = np.array(3.) f_tf = jax2tf.convert(jax.grad(remat_f)) f_tf_hlo = self.TfToHlo(f_tf, arg) - if config.remat_opt_barrier.value: - self.assertRegex(f_tf_hlo, r"opt-barrier") - else: - self.assertRegex(f_tf_hlo, - r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin') + self.assertRegex(f_tf_hlo, r"opt-barrier") def test_remat_free_var(self): def f(x): diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 1ccd009f157c..74e4ddc8136d 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -172,8 +172,14 @@ def test_primitive_coverage(self): continue if p.name == "composite": continue + if p.name == "pvary": + continue + if p.name == "psum_invariant": + continue if p.name == "sharding_constraint": continue + if p.name == "dll_constraint": + continue if p.name == "mesh_cast": continue if p.name == "reshard": diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 09da97e8420a..17d03bc8c778 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -595,7 +595,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -604,7 +604,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -614,7 +614,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). 
" "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -623,7 +623,7 @@ def conv_and_run(*, arg_shape: core.Shape, "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." )), ]) def test_shape_constraints_errors(self, *, diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 15273f0fd02a..acf8885b0f98 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -76,7 +76,7 @@ from jax._src.util import unzip2, weakref_lru_cache, safe_zip -def jet(fun, primals, series): +def jet(fun, primals, series, **_): r"""Taylor-mode higher-order automatic differentiation. Args: @@ -405,11 +405,11 @@ def deriv_prop(prim, deriv, primals_in, series_in): lax.exp(lax.neg(lax.square(x))))) -def def_comp(prim, comp): +def def_comp(prim, comp, **kwargs): """ Define the jet rule for a primitive in terms of a composition of simpler primitives. """ - jet_rules[prim] = partial(jet, comp) + jet_rules[prim] = partial(jet, comp, **kwargs) def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) @@ -478,7 +478,7 @@ def _scale(k, j): def _scale2(k, j): return 1. / (fact(k - j) * fact(j)) -def _exp_taylor(primals_in, series_in): +def _exp_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -522,7 +522,7 @@ def _integer_pow_taylor(primals_in, series_in, *, y): jet_rules[lax.integer_pow_p] = _integer_pow_taylor -def _logistic_taylor(primals_in, series_in): +def _logistic_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -538,7 +538,7 @@ def _logistic_taylor(primals_in, series_in): jet_rules[lax.logistic_p] = _logistic_taylor -def _tanh_taylor(primals_in, series_in): +def _tanh_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [2*x] + [2 * series_ for series_ in series] @@ -548,7 +548,7 @@ def _tanh_taylor(primals_in, series_in): return 2 * primal_out - 1, series_out jet_rules[lax.tanh_p] = _tanh_taylor -def _log_taylor(primals_in, series_in): +def _log_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -590,7 +590,7 @@ def scale(k, j): return 1. 
/ (fact(k - j) * fact(j)) return primal_out, series_out jet_rules[lax.div_p] = _div_taylor_rule -def _sinusoidal_rule(sign, prims, primals_in, series_in): +def _sinusoidal_rule(sign, prims, primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -603,7 +603,7 @@ def _sinusoidal_rule(sign, prims, primals_in, series_in): return (s[0], s[1:]), (c[0], c[1:]) def _get_ind(f, ind): - return lambda *args: f(*args)[ind] + return lambda *args, **kwargs: f(*args, **kwargs)[ind] jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0) jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1) diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index ed9f8931938e..aa114a2803e8 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -14,5 +14,8 @@ from jax._src.layout import ( DeviceLocalLayout as DeviceLocalLayout, - Layout as Layout + Layout as Layout, +) +from jax._src.pjit import ( + with_dll_constraint as with_dll_constraint, ) diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 075e4e6eed48..58d20c331d5f 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index d004c7deb3df..c1275396036c 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -23,7 +23,7 @@ Barrier as Barrier, ClusterBarrier as ClusterBarrier, TMABarrier as TMABarrier, - ThreadSemantics as ThreadSemantics, + LoweringSemantics as LoweringSemantics, TMEM as TMEM, Union as Union, as_gpu_kernel as as_gpu_kernel, @@ -32,6 +32,7 @@ from .launch_context import ( LaunchContext as LaunchContext, MemRefTransform as MemRefTransform, + ReductionOp as ReductionOp, Rounding as Rounding, TileTransform as TileTransform, TransposeTransform as TransposeTransform, @@ -45,6 +46,10 @@ infer_layout as infer_layout, ) +from .layouts import ( + to_layout_attr as to_layout_attr, +) + from .transform_inference import ( infer_transforms as infer_transforms, ) @@ -52,9 +57,11 @@ from .fragmented_array import ( FragmentedArray as FragmentedArray, FragmentedLayout as FragmentedLayout, + TiledLayout as TiledLayout, WGMMA_LAYOUT as WGMMA_LAYOUT, WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, - WGMMARowFragLayout as WGMMARowFragLayout, + WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, + WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, optimization_barrier as optimization_barrier, @@ -65,6 +72,7 @@ DynamicSlice as DynamicSlice, Partition as Partition, Partition1D as Partition1D, + ThreadSubset as ThreadSubset, bitwidth as bitwidth, bytewidth as bytewidth, c as c, diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index b255893e2e2e..860b41e7e8e3 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -83,6 +83,15 @@ # Set this so that the custom call can find it os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) +if os.environ.get("MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH") is None: + try: + from nvidia import nvshmem + except ImportError: 
+ pass + else: + os.environ["MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"] = ( + os.path.join(nvshmem.__path__[0], 'lib/libnvshmem_device.bc') + ) mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p") mosaic_gpu_p.multiple_results = True @@ -103,7 +112,9 @@ def _mosaic_gpu_lowering_rule( module, out_types, input_output_aliases: tuple[tuple[int, int], ...] = (), + use_custom_barrier: bool = False, ): + assert len(args) == len(ctx.avals_in) assert len(out_types) == len(ctx.avals_out) module = _run_serde_pass( module, @@ -120,15 +131,35 @@ def _mosaic_gpu_lowering_rule( raise RuntimeError("Hash collision!") else: KNOWN_KERNELS[kernel_id] = module_asm - op = mlir.custom_call( - "mosaic_gpu", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=args, - operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], - result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], - backend_config=kernel_id + module_asm, - operand_output_aliases=dict(input_output_aliases), - ) + + if ctx.is_forward_compat(): + if use_custom_barrier: + raise ValueError("Barrier semaphore is not supported in forward compatibility mode. " + "Please, use 'export_ignore_forward_compatibility=True'.") + op = mlir.custom_call( + "mosaic_gpu", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], + backend_config=kernel_id + module, + operand_output_aliases=dict(input_output_aliases), + ) + else: + op = mlir.custom_call( + "mosaic_gpu_v2", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], + backend_config=dict( + kernel_hash=ir.StringAttr.get(kernel_id), + module=ir.StringAttr.get(module_asm), + use_custom_barrier=ir.BoolAttr.get(use_custom_barrier), + ), + operand_output_aliases=dict(input_output_aliases), + api_version=4, + ) return op.results @@ -178,7 +209,7 @@ def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: return math.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize -class ThreadSemantics(enum.Enum): +class LoweringSemantics(enum.Enum): """Semantics for the kernel's instruction stream.""" Lane = enum.auto() @@ -307,6 +338,8 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int: raise NotImplementedError("Misaligned barrier allocation") size += num_barriers * utils.MBARRIER_BYTES case TMEM(_): + # TODO(justinfu): This can trigger misaligned barrier allocations + # if TMEM is requested before barriers b/c it's not divisible by 8. 
size += 4 # i32 takes up 4 bytes case _: size += _count_buffer_bytes(l) @@ -446,6 +479,7 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) module = ir.Module.create() + dialect.register_dialect(module.context) attrs = module.operation.attributes attrs["sym_name"] = ir.StringAttr.get(module_name) if kernel_name is None: @@ -562,7 +596,7 @@ def as_gpu_kernel( module_name: str = "unknown", kernel_name: str | None = None, ir_version: int | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + thread_semantics: LoweringSemantics = LoweringSemantics.Lane, ): if isinstance(in_shape, list): in_shape = tuple(in_shape) @@ -576,7 +610,7 @@ def as_gpu_kernel( ) ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: + if thread_semantics == LoweringSemantics.Warpgroup and dialect is not None: # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error @@ -636,7 +670,7 @@ def as_torch_gpu_kernel( cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", kernel_name: str | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + lowering_semantics: LoweringSemantics = LoweringSemantics.Lane, ): try: import torch @@ -659,7 +693,7 @@ def as_torch_gpu_kernel( ) ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: + if lowering_semantics == LoweringSemantics.Warpgroup and dialect is not None: # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index fedde5a00887..1239a20ba865 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -14,13 +14,15 @@ """Lowering rules and pass for the MLIR Mosaic GPU dialect.""" -from collections.abc import Callable +from collections.abc import Callable, Iterable import dataclasses import functools import itertools +import math import operator from typing import Any, Sequence, Type, cast +from jax._src import lib as jaxlib from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir @@ -34,6 +36,8 @@ from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +from jax._src.util import safe_zip +from jax.experimental.mosaic.gpu import layouts as layouts_lib import numpy as np from . import fragmented_array as fa @@ -201,6 +205,38 @@ def _initialize_barrier_op_lowering_rule( barrier_base_ptr, initialize_barrier_op.barriers_ref.type), +# TODO(bchetioui): remove once minimum jaxlib >= 0.5.3. +OptimizationBarrierOp = getattr(mgpu, "OptimizationBarrierOp", None) + + +@_register_lowering(OptimizationBarrierOp) +def _optimization_barrier_op_lowering_rule( + _: LoweringContext, + op: OptimizationBarrierOp, +) -> Sequence[ir.Value]: + if not all(ir.VectorType.isinstance(operand.type) for operand in op.operands): + raise NotImplementedError( + f"Optimization barrier op {op} has non-vector operands." 
+ ) + + fragmented_arrays = [] + for operand, layout in safe_zip(op.operands, inference_utils.in_layouts(op)): + ty = ir.VectorType(operand.type) + is_signed = False if ir.IntegerType.isinstance(ty.element_type) else None + fragmented_arrays.append( + _fragmented_array_from_ir(operand, layout, is_signed=is_signed) + ) + + lowered_fragmented_arrays = fa.optimization_barrier(*fragmented_arrays) + if isinstance(lowered_fragmented_arrays, fa.FragmentedArray): + lowered_fragmented_arrays = [lowered_fragmented_arrays] + + return [ + _fragmented_array_to_ir(arr, result.type) + for arr, result in safe_zip(lowered_fragmented_arrays, op.results) + ] + + @_register_lowering(arith.ConstantOp) def _arith_constant_op_lowering_rule( _: LoweringContext, op: arith.ConstantOp @@ -320,7 +356,7 @@ def _vector_load_op_lowering_rule( ) ref_ty = ir.MemRefType(vector_load_op.base.type) _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) - transformed_ref = transform_memref(vector_load_op.base, transforms) + transformed_ref = reinterpret_smem_ref(vector_load_op.base, transforms) fragmented_array = fa.FragmentedArray.load_tiled( transformed_ref, swizzle=swizzle, @@ -362,7 +398,7 @@ def _vector_store_op_lowering_rule( ref_ty = ir.MemRefType(vector_store_op.base.type) _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) fragmented_array.store_tiled( - transform_memref(vector_store_op.base, transforms), swizzle + reinterpret_smem_ref(vector_store_op.base, transforms), swizzle ) elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or isinstance(fragmented_array.layout, fa.WGSplatFragLayout)): @@ -434,6 +470,15 @@ def _vector_reduction_op_lowering_rule( return [_fragmented_array_to_ir(result, op.result.type)] +# TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. +if jaxlib.version >= (0, 5, 4): + @_register_lowering(mgpu.LayoutCastOp) + def _mgpu_layout_cast_op_lowering_rule( + _: LoweringContext, layout_cast_op: mgpu.LayoutCastOp + ) -> Sequence[ir.Value]: + return [layout_cast_op.x] + + def swizzle_and_transforms_from_transforms_attr( transforms: ir.ArrayAttr, ) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]: @@ -475,32 +520,78 @@ def swizzle_and_transforms_from_transforms_attr( return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms) -def transform_memref( - mem_ref: ir.Value, transforms: tuple[launch_context.MemRefTransform, ...] +def _is_memref_transposed(mem_ref_type: ir.MemRefType) -> bool: + strides, _ = mem_ref_type.get_strides_and_offset() + prev_stride = math.inf + for stride in strides: + if stride > prev_stride: + return True + prev_stride = stride + return False + + +def reinterpret_smem_ref( + ref: ir.Value, + transforms: tuple[launch_context.MemRefTransform, ...], ) -> ir.Value: - """Reinterprets the memref to one where the shape is transformed as given.""" - if not transforms: - return mem_ref + """Applies transforms on the ref, and makes sure that their effect is + propagated appropriately on the strides. - mem_ref_type = ir.MemRefType(mem_ref.type) - if mem_ref_type.memory_space != ir.Attribute.parse( - "#gpu.address_space" - ): - raise ValueError(f"Only workgroup memory is supported but got {mem_ref}.") + This function is used any time we lower from a dialect SMEM ref (2D for wgmma) + with given transforms to a "physical" SMEM ref (4D for wgmma) that is fully + transformed and transposed as needed. 
+ """ + ref_ty = ir.MemRefType(ref.type) + transposed = _is_memref_transposed(ref_ty) + if not transforms and not transposed: + return ref + + if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): + raise ValueError(f"Only workgroup memory is supported but got {ref}.") + + shape = ref_ty.shape + if transposed: + if len(shape) != 2: + raise NotImplementedError( + f"Only 2D shapes can be transposed, but got {shape}" + ) + strides, _ = ref_ty.get_strides_and_offset() + if strides[0] != 1 or strides[1] != shape[0]: + raise NotImplementedError( + f"Only contiguous 2D memrefs can be transposed, but got {ref_ty}" + ) - shape = mem_ref_type.shape for t in transforms: - shape = t.transform_shape(shape) + shape = list(t.transform_shape(shape)) + + if transposed: + # The expected output is a transposed ref and `shape` is already transposed. + # We need to compute the correct strides to match the shape. + if len(shape) == 2: + minor_to_major_stride_order = (1, 0) + elif len(shape) == 4: + minor_to_major_stride_order = (2, 3, 0, 1) + else: + raise NotImplementedError( + f"Expected a 2D or 4D shape after transforms, but got {shape}" + ) + strides = [1]*len(shape) + for i in minor_to_major_stride_order[1:]: + strides[i] = strides[i-1] * shape[i-1] + layout = ir.StridedLayoutAttr.get(0, strides) + else: + layout = None - memref_new_type = ir.MemRefType.get( + new_ref_ty = ir.MemRefType.get( shape, - mem_ref_type.element_type, - memory_space=mem_ref_type.memory_space, + ref_ty.element_type, + memory_space=ref_ty.memory_space, + layout=layout, ) - ms = utils.WORKGROUP_NVPTX_ADDRESS_SPACE - ptr = utils.memref_ptr(mem_ref, memory_space=ms) - return utils.ptr_as_memref(ptr, memref_new_type, ptr_memory_space=ms) + ptr = utils.memref_ptr(ref, memory_space=ms) + ref = utils.ptr_as_memref(ptr, new_ref_ty, ptr_memory_space=ms) + return ref @_register_lowering(mgpu.AsyncLoadOp) @@ -525,10 +616,16 @@ def _mgpu_async_load_op_lowering_rule( v = idx if size < 0 else utils.DynamicSlice(idx, size) gmem_slice.append(v) + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosned with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatted here, as long as flattened size is a + # multiple of 16. + # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=load_op.source, - dst_ref=transform_memref(load_op.destination, transforms), + dst_ref=reinterpret_smem_ref(load_op.destination, transforms), gmem_slice=tuple(gmem_slice), barrier=barrier, arrive=False, @@ -561,9 +658,15 @@ def _mgpu_async_store_op_lowering_rule( v = idx if size < 0 else utils.DynamicSlice(idx, size) gmem_slice.append(v) + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosned with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatted here, as long as flattened size is a + # multiple of 16. + # TODO(dasenov): Add support for the remaining op properties. 
ctx.launch_context.async_copy( - src_ref=transform_memref(store_op.source, transforms), + src_ref=reinterpret_smem_ref(store_op.source, transforms), dst_ref=store_op.destination, gmem_slice=tuple(gmem_slice), swizzle=swizzle, @@ -761,9 +864,6 @@ def _bitcast_op_lowering_rule( def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp ) -> Sequence[ir.Value]: - if wgmma_op.transpose_a or wgmma_op.transpose_b: - raise ValueError("Transpose arguments are to be deleted.") - fa_layouts = ( *inference_utils.in_layouts(wgmma_op), *inference_utils.out_layouts(wgmma_op), @@ -796,7 +896,7 @@ def _mgpu_wgmma_op_lowering_rule( _check_transforms_and_swizzle_are_supported( ref_ty, b_transforms, b_swizzle, minimum_swizzle ) - b_operand = transform_memref(wgmma_op.b, b_transforms) + b_operand = reinterpret_smem_ref(wgmma_op.b, b_transforms) if ir.VectorType.isinstance(wgmma_op.a.type): a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout) @@ -813,7 +913,7 @@ def _mgpu_wgmma_op_lowering_rule( f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !=" f" {b_swizzle}" ) - a_operand = transform_memref(wgmma_op.a, a_transforms) + a_operand = reinterpret_smem_ref(wgmma_op.a, a_transforms) new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) @@ -850,13 +950,9 @@ def _mgpu_wait_op_lowering_rule( return [] -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. -SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) - - -@_register_lowering(SliceSMEMOp) +@_register_lowering(mgpu.SliceSMEMOp) def _mgpu_slice_smem_op_lowering_rule( - ctx: LoweringContext, op: SliceSMEMOp + ctx: LoweringContext, op: mgpu.SliceSMEMOp ) -> Sequence[ir.Value]: del ctx return [_slice_smem(op.result.type, op.offset)] @@ -872,6 +968,66 @@ def _slice_smem(result: ir.Type, offset: ir.Value): return memref.view(result, smem_base, offset, []) +# The metadata needed to recostruct a vector from its flattened representation. +_VectorTemplate = tuple[Sequence[int], fa.FragmentedLayout, ir.VectorType] + + +def _flatten_ir_values( + values: Sequence[ir.Value], fa_layouts: Iterable[ir.Attribute] +) -> tuple[Sequence[ir.Value], Sequence[_VectorTemplate | None]]: + """Flattens a sequence of values. + + Non-vector values are preserved as is. Vectors are mapped to fragmented + arrays and then flattened into per-register values. + + Args: + values: The sequence of values to flatten. + fa_layouts: The layouts of vectors in ``values``. + + Returns: + A tuple of (flattened values, templates). The templates are used to + reconstruct the vectors from the per-register values. 
+ """ + fa_layouts_it = iter(fa_layouts) + result = [] + templates = [] + for v in values: + if ir.VectorType.isinstance(v.type): + fa = _fragmented_array_from_ir(v, next(fa_layouts_it)) + result.extend(fa.registers.flat) + templates.append((fa.registers.shape, fa.layout, ir.VectorType(v.type))) + else: + result.append(v) + templates.append(None) + return result, templates + + +def _unflatten_ir_values( + flat_values: Sequence[ir.Value], templates: Sequence[_VectorTemplate | None] +) -> Sequence[ir.Value]: + """The inverse of ``_flatten_ir_values``.""" + result = [] + flat_values_it = iter(flat_values) + for template in templates: + if template is None: + result.append(next(flat_values_it)) + continue + registers_shape, layout, vec_type = template + value_registers = np.asarray( + [next(flat_values_it) for _ in range(math.prod(registers_shape))], + dtype=object, + ) + value = fa.FragmentedArray( + _registers=value_registers.reshape(registers_shape), + _layout=layout, + _is_signed=False + if ir.IntegerType.isinstance(vec_type.element_type) + else None, + ) + result.append(_fragmented_array_to_ir(value, vec_type)) + return result + + @_register_lowering(scf.ForOp) def _for_op_lowering_rule( ctx: LoweringContext, for_op: scf.ForOp @@ -884,60 +1040,22 @@ def _for_op_lowering_rule( yield_layouts = inference_utils.in_layouts(yield_op) if in_layouts != out_layouts or in_layouts != yield_layouts: raise ValueError("Layout mismatch") - fa_layouts = in_layouts - - fa_layouts_it = iter(fa_layouts) - arg_template = [ - (_fragmented_array_from_ir(arg, next(fa_layouts_it)), arg.type) - if ir.VectorType.isinstance(arg.type) - else (arg, arg.type) - for arg in for_op.initArgs - ] - def lower_carry(carry): - fa_layouts_it = iter(fa_layouts) - carry_with_fas = [ - _fragmented_array_from_ir(arg, next(fa_layouts_it)) - if ir.VectorType.isinstance(arg.type) - else arg - for arg in carry - ] - lowered_carry = [] - for c in carry_with_fas: - if isinstance(c, fa.FragmentedArray): - lowered_carry.extend(c.registers.flat) - else: - lowered_carry.append(c) - return lowered_carry - - def recreate_carry(lowered_carry): - recreated_carry = [] - arg_it = iter(lowered_carry) - for arg_value, arg_type in arg_template: - if isinstance(arg_value, fa.FragmentedArray): - carry_registers = np.asarray( - [next(arg_it) for _ in arg_value.registers.flat], dtype=object - ) - carry_registers = carry_registers.reshape(arg_value.registers.shape) - carry = fa.FragmentedArray( - _registers=carry_registers, - _layout=arg_value.layout, - _is_signed=arg_value.is_signed, - ) - recreated_carry.append(_fragmented_array_to_ir(carry, arg_type)) - else: - recreated_carry.append(next(arg_it)) - return recreated_carry + flat_init_args, args_template = _flatten_ir_values( + for_op.initArgs, in_layouts + ) new_for_op = scf.ForOp( for_op.lowerBound, for_op.upperBound, for_op.step, - lower_carry(for_op.initArgs), + flat_init_args, ) with ir.InsertionPoint(new_for_op.body): - recreated_carry = recreate_carry(new_for_op.body.arguments[1:]) + recreated_carry = _unflatten_ir_values( + new_for_op.body.arguments[1:], args_template + ) ops_to_lower = [] - for op in for_op.body: + for op in [*for_op.body]: if op == yield_op: continue mgpu.private_operation_remove_from_parent(op) @@ -952,16 +1070,80 @@ def recreate_carry(lowered_carry): ctx.lower_op(op) with ir.InsertionPoint(new_for_op.body): - new_yield_operands = lower_carry(yield_op.operands) + flat_operands, _ = _flatten_ir_values(yield_op.operands, in_layouts) yield_op.erase() - 
scf.yield_(new_yield_operands) - return recreate_carry(new_for_op.results) + scf.yield_(flat_operands) + + return _unflatten_ir_values(new_for_op.results, args_template) + + +def _infer_flat_result_types( + op: ir.OpView, out_layouts: Sequence[ir.Attribute] +) -> Sequence[ir.Type]: + result_types: list[ir.Type] = [] + out_layouts_it = iter(out_layouts) + for r in op.results: + if not ir.VectorType.isinstance(r.type): + result_types.append(r.type) + continue + vec_type = ir.VectorType(r.type) + layout = layouts_lib.from_layout_attr(next(out_layouts_it)) + result_types.extend( + [layout.registers_element_type(vec_type.element_type)] + * math.prod(layout.registers_shape(tuple(vec_type.shape))) + ) + return result_types + + +@_register_lowering(scf.IfOp) +def _if_op_lowering_rule( + ctx: LoweringContext, if_op: scf.IfOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(if_op): + return _traverse_op_lowering_rule(ctx, if_op) + + raise NotImplementedError + + +@_register_lowering(scf.IndexSwitchOp) +def _index_switch_op_lowering_rule( + ctx: LoweringContext, switch_op: scf.IndexSwitchOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(switch_op): + return _traverse_op_lowering_rule(ctx, switch_op) + + out_layouts = inference_utils.out_layouts(switch_op) + new_switch_op = scf.IndexSwitchOp( + _infer_flat_result_types(switch_op, out_layouts), + switch_op.arg, + switch_op.cases, + len(switch_op.regions) - 1, + ) + + results_template: Sequence[_VectorTemplate | None] = [] + for region, new_region in zip( + switch_op.regions, new_switch_op.regions, strict=True + ): + [block] = region.blocks + new_block = new_region.blocks.append() + with ir.InsertionPoint(new_block): + for op in [*block]: + if not isinstance(op, scf.YieldOp): + mgpu.private_operation_remove_from_parent(op) + mgpu.private_block_append_owned_operation(new_block, op) + ctx.lower_op(op) + continue + if inference_utils.in_layouts(op) != out_layouts: + raise ValueError("Layout mismatch") + flat_results, results_template = _flatten_ir_values( + op.operands, out_layouts + ) + scf.yield_(flat_results) + return _unflatten_ir_values(new_switch_op.results, results_template) @_register_lowering(func.FuncOp) @_register_lowering(gpu.LaunchOp) -@_register_lowering(scf.IfOp) # TODO(apaszke,bchetioui): Add a proper rule. -@_register_lowering(scf.IndexSwitchOp) # TODO(apaszke,bchetioui): Add a proper rule. 
def _traverse_op_lowering_rule( ctx: LoweringContext, op: ir.OpView ) -> MlirLoweringRuleResult: @@ -989,9 +1171,11 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]: sub_op.operation.regions[0].blocks[0] ): assert block_predicate is None - block_predicate = utils.single_thread_predicate(per_block=True) + block_predicate = utils.single_thread_predicate( + scope=utils.ThreadSubset.BLOCK + ) warpgroup_predicate = utils.single_thread_predicate( - per_block=False + scope=utils.ThreadSubset.WARPGROUP ) if block_predicate is None: diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index dc59dda3a6e5..57f30b8603c8 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -299,7 +299,7 @@ def kv_loop(kv_step, carry): scf.yield_([]) with ir.InsertionPoint(if_compute.else_block): nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) - with single_thread(per_block=False): + with single_thread(scope=ThreadSubset.WARPGROUP): k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) v_tr = TileTransform(tiling) kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) @@ -391,7 +391,7 @@ def only_wg(idx): kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) def kv_copy_init(slot, kv_seq_base): - with single_thread(per_block=False): + with single_thread(ThreadSubset.WARPGROUP): txcount = 2 * blocks.kv * head_dim * bytewidth(f16) barriers[slot].arrive_expect_tx(txcount) k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 5daed8416589..76f7d549cf55 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -202,6 +202,11 @@ def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: yield i - offset, e +@dataclasses.dataclass(frozen=True) +class Replicated: + times: int + + @dataclasses.dataclass(frozen=True) class TiledLayout: """A FragmentedArray layout derived from a tiling expression. @@ -247,8 +252,8 @@ class TiledLayout: by a single (logical) register. """ tiling: Tiling - warp_dim: int - lane_dims: tuple[int, ...] # major-to-minor + warp_dim: int | Replicated + lane_dims: tuple[int | Replicated, ...] 
# major-to-minor vector_dim: int def __post_init__(self): @@ -256,19 +261,34 @@ def __post_init__(self): raise ValueError("Tiling must have at least one tile") min_shape = self.tiling.tiles[0] min_tiled_shape = self.tiling.tile_shape(min_shape) - dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim} - if len(dims_set) != len(self.lane_dims) + 2: + dims_set = {*self.partitioned_lane_dims, self.vector_dim} + if partitions_warp_dim := not isinstance(self.warp_dim, Replicated): + dims_set.add(self.warp_dim) + if len(dims_set) != len(self.partitioned_lane_dims) + 1 + partitions_warp_dim: raise ValueError for d in dims_set: if d >= 0: raise ValueError("All dimensions must be negative") if d < -(len(min_tiled_shape) - len(min_shape)): raise ValueError("Dimension out of range") - if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: + if isinstance(self.warp_dim, Replicated): + if self.warp_dim.times != WARPS_IN_WARPGROUP: + raise ValueError + elif min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: raise ValueError - if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE: + lane_dims_prod = math.prod( + d.times if isinstance(d, Replicated) else min_tiled_shape[d] + for d in self.lane_dims + ) + if lane_dims_prod != WARP_SIZE: raise ValueError + @functools.cached_property + def partitioned_lane_dims(self) -> tuple[int, ...]: + return tuple( + d for d in self.lane_dims if not isinstance(d, Replicated) + ) + def thread_idxs(self, shape: tuple[int, ...]) -> Iterable[tuple[ir.Value, ...]]: # We first find the linear index and then divide by the shape to # get the index. @@ -319,11 +339,15 @@ def tiled_tiling_rank(self) -> int: def vector_length(self) -> int: return self.tiled_tiling_shape[self.vector_dim] + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vector_length,), t) + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) - tiled_shape[self.warp_dim] = 1 - for d in self.lane_dims: + if not isinstance(self.warp_dim, Replicated): + tiled_shape[self.warp_dim] = 1 + for d in self.partitioned_lane_dims: tiled_shape[d] = 1 tiled_shape[self.vector_dim] = 1 return tuple(tiled_shape) @@ -335,16 +359,20 @@ def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """ tiled_tiling = self.tiled_tiling_shape shape = list(shape) - shape[self.warp_dim] = WARPS_IN_WARPGROUP - for d in self.lane_dims: + if not isinstance(self.warp_dim, Replicated): + shape[self.warp_dim] = WARPS_IN_WARPGROUP + for d in self.partitioned_lane_dims: shape[d] = tiled_tiling[d] shape[self.vector_dim] = tiled_tiling[self.vector_dim] return self.tiling.untile_shape(tuple(shape)) - def lane_indices(self) -> tuple[ir.Value, ...]: + def _full_lane_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape = self.tiled_tiling_shape - lanes_shape = tuple(tiled_shape[d] for d in self.lane_dims) + lanes_shape = tuple( + d.times if isinstance(d, Replicated) else tiled_shape[d] + for d in self.lane_dims + ) assert math.prod(lanes_shape) == WARP_SIZE lane_strides = utils.get_contiguous_strides(lanes_shape) lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) @@ -352,20 +380,29 @@ def lane_indices(self) -> tuple[ir.Value, ...]: arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32)) for stride, size in zip(lane_strides, lanes_shape) ) + 
return lane_indices + + def lane_indices(self) -> tuple[ir.Value, ...]: + i32 = ir.IntegerType.get_signless(32) + tiled_shape = self.tiled_tiling_shape + lane_indices = self._full_lane_indices() full_indices = [arith.constant(i32, 0)] * len(tiled_shape) for d, i in zip(self.lane_dims, lane_indices): + if isinstance(d, Replicated): + continue full_indices[d] = i return tuple(full_indices) def warp_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape_rank = len(self.tiled_tiling_shape) - warp_idx = arith.remui( - arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), - c(WARPS_IN_WARPGROUP, i32), - ) indices = [arith.constant(i32, 0)] * tiled_shape_rank - indices[self.warp_dim] = warp_idx + if not isinstance(self.warp_dim, Replicated): + warp_idx = arith.remui( + arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), + c(WARPS_IN_WARPGROUP, i32), + ) + indices[self.warp_dim] = warp_idx return tuple(indices) @@ -382,28 +419,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): return WGMMA_LAYOUT -@dataclasses.dataclass(frozen=True) -class WGMMARowFragLayout: - """[m] matrix, where m % 64 == 0.""" - - def thread_idxs(self, shape): - index = ir.IndexType.get() - assert len(shape) == 1 - assert shape[0] % 64 == 0 - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) - warp_idx = arith.divui(tid_wg, c(32, index)) - lane_id = arith.remui(tid_wg, c(32, index)) - row_base = arith.addi( - arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index)) - ) - - for row_group in range(0, shape[0], 64): - for row_subgroup in (0, 8): - row = arith.addi(row_base, c(row_group + row_subgroup, index)) - yield (row,) - - @dataclasses.dataclass(frozen=True) class WGSplatFragLayout: """A fragmented array where all the values are equal represented as a register per thread. @@ -435,6 +450,14 @@ def can_broadcast_to(self, shape) -> bool: """ return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + def registers_element_type(self, t: ir.Type) -> ir.Type: + return t + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + del shape # Unused. 
+ return () + def thread_idxs(self, shape): assert shape == self.shape raise NotImplementedError @@ -469,6 +492,15 @@ def from_shaped_type(cls, shaped_ty: ir.Type): shape=tuple(shaped_ty.shape), vec_size=min(8 // bw, max_vec_size) ) + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vec_size,), t) + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + if shape != self.shape: + raise ValueError(f"Shape {shape} is not compatible with {self}") + return (math.prod(self.shape) // (WARPGROUP_SIZE * self.vec_size),) + def thread_idxs(self, shape): assert shape == self.shape index = ir.IndexType.get() @@ -497,14 +529,25 @@ def linear_thread_idxs(self): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | TiledLayout +FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | TiledLayout -WGMMA_ROW_LAYOUT = WGMMARowFragLayout() +WGMMA_COL_LAYOUT = TiledLayout( + Tiling(((8,), (2,))), + warp_dim=Replicated(4), + lane_dims=(Replicated(8), -2), + vector_dim=-1, +) +WGMMA_ROW_LAYOUT = TiledLayout( + Tiling(((64,), (16,), (8,), (1,))), + warp_dim=-4, + lane_dims=(-2, Replicated(4)), + vector_dim=-1, +) # The tiled layout is equivalent to one described here in PTX documentation: # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d -# In this layout, we partition the 64x8 tiles over 4 warpgroups into 16x8 tiles. +# In this layout, we partition the 64x8 tiles over 4 warps into 16x8 tiles. # Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit # of data that is split across a warp. Since 8*8 = 64, but a warp has only 32 # threads, we vectorize pairs of elements along columns. 
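The `Replicated` entries in the new `WGMMA_ROW_LAYOUT` can be read off from the old `WGMMARowFragLayout.thread_idxs` logic deleted above: within a warp, groups of four consecutive lanes hold the same row value. A small plain-Python sketch of that mapping (no GPU needed):

```python
# Plain-Python sketch of the row owned by each thread under WGMMA_ROW_LAYOUT.
# Groups of 4 consecutive lanes replicate the same value, which is what the
# Replicated(4) entry in the new layout's lane_dims encodes.
WARP_SIZE, WARPS_IN_WARPGROUP = 32, 4

def wgmma_row(thread_idx: int, row_group: int = 0, row_subgroup: int = 0) -> int:
    warp_idx, lane_id = divmod(thread_idx % (WARP_SIZE * WARPS_IN_WARPGROUP), WARP_SIZE)
    return row_group + row_subgroup + 16 * warp_idx + lane_id // 4

assert wgmma_row(0) == wgmma_row(3) == 0   # lanes 0..3 share row 0
assert wgmma_row(4) == 1                   # lanes 4..7 hold row 1
assert wgmma_row(32) == 16                 # warp 1 starts at row 16
```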
@@ -612,12 +655,6 @@ def __init__( ) match self.layout: - # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout - # Each element is a dtype scalar - case WGMMARowFragLayout(): - if _registers.ndim != 2 or _registers.shape[-1] != 2: - raise ValueError(f"Invalid register array shape: {_registers.shape}") - # Registers are flat case WGStridedFragLayout(shape): [reg_size] = ir.VectorType(_registers.flat[0].type).shape @@ -626,8 +663,8 @@ def __init__( != math.prod(_registers.shape) * WARPGROUP_SIZE * reg_size ): raise ValueError( - "Invalid register array shape: math.prod({_registers.shape}) *" - " {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" + f"Invalid register array shape: math.prod({_registers.shape}) *" + f" {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" ) # Just a single register @@ -674,59 +711,19 @@ def load_strided( vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) - @classmethod - def load_wgmma_row( - cls, - ref: ir.Value, - *, - is_signed: bool | None = None, - ): - if not ir.MemRefType.isinstance(ref.type): - raise TypeError(ref.type) - - ref_ty = ir.MemRefType(ref.type) - shape = tuple(ref_ty.shape) - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - - layout = WGMMARowFragLayout() - registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)] - registers = np.array(registers).reshape(-1, 2) - return cls(_registers=registers, _layout=layout, _is_signed=is_signed) - - @classmethod def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) match layout: - case WGMMARowFragLayout(): - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - reg_shape = (shape[0] // 64, 2) - case WGStridedFragLayout(vec_size=vec_size): - assert shape == layout.shape - elems = np.prod(shape) - reg_shape = (elems // (WARPGROUP_SIZE * vec_size),) - value = vector.splat(ir.VectorType.get((vec_size,), value.type), value) case WGSplatFragLayout(): - assert shape == layout.shape - reg_shape = () - case TiledLayout(): - value = vector.splat(ir.VectorType.get((layout.vector_length,), value.type), value) - reg_shape = layout.registers_shape(shape) + pass + case WGStridedFragLayout() | TiledLayout(): + value = vector.splat(layout.registers_element_type(value.type), value) case _: raise NotImplementedError(layout) return cls( - _registers=np.full(reg_shape, value, dtype=object), + _registers=np.full(layout.registers_shape(shape), value, dtype=object), _layout=layout, _is_signed=is_signed, ) @@ -734,9 +731,6 @@ def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): @property def shape(self): match self.layout: - case WGMMARowFragLayout(): - row_tiles = self.registers.shape[0] - return (row_tiles * 64,) case WGStridedFragLayout(shape): return shape case WGSplatFragLayout(shape=shape): @@ -752,7 +746,7 @@ def mlir_dtype(self): match self.layout: case WGStridedFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): return reg_ty case _: raise NotImplementedError @@ -1514,7 +1508,7 @@ def upcast_to_bf16(reg, high): case 
WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape upcast_ty = ir.VectorType.get(shape, larger_ty) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): upcast_ty = larger_ty case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1539,7 +1533,7 @@ def upcast_to_bf16(reg, high): case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape new_reg_ty = ir.VectorType.get(shape, new_dtype) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): new_reg_ty = new_dtype case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1594,7 +1588,7 @@ def reduce_sum(self, scratch: ir.Value | None = None): memref.store(warp_result, scratch, [warp_id]) utils.warpgroup_barrier() zero_index = c(0, index) - with mgpu.single_thread(per_block=False): + with mgpu.single_thread(scope=mgpu.ThreadSubset.WARPGROUP): scratch_vec = vector.load( ir.VectorType.get((4,), self.mlir_dtype), scratch, @@ -1638,9 +1632,9 @@ def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): i32 = ir.IntegerType.get_signless(32) row_tile_dim = self.registers.shape[0] row_subtile_dim = self.registers.shape[4] - new_regs = np.empty((row_tile_dim, row_subtile_dim), dtype=object) + new_regs = np.empty((row_tile_dim, 1, row_subtile_dim, 1, 1), dtype=object) assert self.registers.shape[-1] == 1 - for row_tile, row_subtile in np.ndindex(new_regs.shape): + for row_tile, row_subtile in np.ndindex(row_tile_dim, row_subtile_dim): # Reduce the registers owned by the current thread over n tiles reg_index = [0] * self.registers.ndim reg_index[0] = row_tile @@ -1671,7 +1665,9 @@ def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): nvvm.ShflKind.bfly, ) result = op(result, other_result) - new_regs[row_tile, row_subtile] = result + new_regs[row_tile, :, row_subtile] = vector.splat( + ir.VectorType.get((1,), self.mlir_dtype), result + ) return FragmentedArray( _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed ) @@ -1716,17 +1712,32 @@ def broadcast_minor(self, n): reg_shape = WGMMA_LAYOUT.registers_shape((self.shape[0], n)) new_regs = np.empty(reg_shape, dtype=object) dtype = self.mlir_dtype - for (row_tile, row_subtile), reg in np.ndenumerate(self.registers): + i0 = arith.constant(ir.IndexType.get(), 0) + for (row_tile, _, row_subtile, *__), reg in np.ndenumerate(self.registers): tile = [slice(None)] * len(new_regs.shape) tile[0] = row_tile tile[4] = row_subtile new_regs[tuple(tile)] = vector.splat( - ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), reg + ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), + vector.extractelement(reg, position=i0), ) return FragmentedArray( _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed ) + def broadcast_major(self, m): + if m % 64: + raise ValueError("Number of rows must be divisible by 64") + reg_shape = WGMMA_LAYOUT.registers_shape((m, self.shape[0])) + new_regs = np.empty(reg_shape, dtype=object) + for (col_tile, *_), reg in np.ndenumerate(self.registers): + tile = [slice(None)] * len(new_regs.shape) + tile[1] = col_tile + new_regs[tuple(tile)] = reg + return FragmentedArray( + _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed + ) + def select(self, on_true, on_false): if ( not ir.IntegerType.isinstance(self.mlir_dtype) @@ -1754,12 +1765,18 @@ def foreach( for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), 
np.ndindex(self.registers.shape), strict=True): reg = self.registers[reg_idx] assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape) - [elems] = ir.VectorType(reg.type).shape - for i in range(elems): - i = c(i, index) - val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if ir.VectorType.isinstance(reg.type): + [elems] = ir.VectorType(reg.type).shape + for i in range(elems): + i = c(i, index) + val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if create_array: + new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + else: + val = fn(reg, mlir_idx) if create_array: - new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + new_regs[reg_idx] = val + if create_array: return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) @@ -1771,30 +1788,42 @@ def _(val, idx): fmt_str = fmt.format(f"[{idx_fmt}]: {{}}") utils.debug_print(fmt_str, *idx, val, uniform=False) - def store_untiled(self, ref: ir.Value, *, vector_store: bool = True): + def store_untiled( + self, ref: ir.Value, *, swizzle: int = 16, optimized: bool = True + ): if not ir.MemRefType.isinstance(ref.type): raise ValueError(ref) - - def vs_unsupported(): - if not vector_store: - raise NotImplementedError( - f"Can't use non-vector stores with layout {self.layout}" - ) - match self.layout: - case WGMMARowFragLayout(): - self._store_untiled_wgmma_row(ref) case WGSplatFragLayout(): - vs_unsupported() + # All values are the same so swizzle does not affect anything here. self._store_untiled_splat(ref) case WGStridedFragLayout(): - vs_unsupported() + if swizzle != 16: + raise NotImplementedError self._store_untiled_wg_strided(ref) case TiledLayout(): - self._store_untiled_tiled(ref, vector_store=vector_store) + ref_shape = ir.MemRefType(ref.type).shape + ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) + self.store_tiled(ref, swizzle=swizzle, optimized=optimized) case _: raise NotImplementedError(self.layout) + @classmethod + def load_untiled( + cls, + ref: ir.Value, + *, + layout: TiledLayout, + swizzle: int = 16, + is_signed: bool | None = None, + optimized: bool = True, + ): + ref_shape = ir.MemRefType(ref.type).shape + ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) + return cls.load_tiled( + ref, swizzle=swizzle, is_signed=is_signed, layout=layout, optimized=optimized + ) + def _store_untiled_splat(self, ref: ir.Value): vec_size = 64 // mgpu.bitwidth(self.mlir_dtype) if np.prod(self.shape) < vec_size * WARPGROUP_SIZE: @@ -1830,75 +1859,15 @@ def _store_untiled_wg_strided(self, ref: ir.Value): for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) - def _store_untiled_wgmma_row(self, ref: ir.Value): - """Stores an array with a WGMMA row layout.""" - assert self.layout == WGMMA_ROW_LAYOUT - index = ir.IndexType.get() - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - - is_first = arith.cmpi( - arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index) - ) - # Consecutive groups of 4 threads hold the same value in this layout, - # therefore we only need to transfer data from one of them. - with utils.when(is_first): - for (idx,), value in zip( - self.layout.thread_idxs(self.shape), self.registers.flatten() - ): - memref.store(value, ref, [idx]) - - def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): - """Stores an array with a tiled layout. 
Not optimized at the moment.""" - if utils.bitwidth(self.mlir_dtype) < 8: - raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})") - i32 = ir.IntegerType.get_signless(32) - layout = self.layout - assert isinstance(layout, TiledLayout) - ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() - if vector_store and ref_strides[layout.vector_dim] != 1: - raise NotImplementedError( - "Can't use vector stores with non-unit minormost stride" - ) - strides = layout.tiling.tile_strides(ref_strides) - smem_space = ir.Attribute.parse("#gpu.address_space") - ref_space = ir.MemRefType(ref.type).memory_space - memory_space = None - if str(ref_space) == str(smem_space): - memory_space = 3 - elif ref_space: - raise NotImplementedError(f"Unexpected ref space {ref_space}") - ptr = utils.memref_ptr(ref, memory_space=memory_space) - # Fold warp and lane offsets into the pointer once, since they are dynamic. - dyn_strides = [ - arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :] - ] - warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) - lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) - dyn_offset = arith.addi(warp_offset, lane_offset) - ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) - # All warp tile offsets are static and can be fused into the store. - for tile_idx, reg in np.ndenumerate(self.registers): - if vector_store: - elems = [reg] - else: - index = ir.IndexType.get() - elems = [ - vector.extractelement(reg, position=c(i, index)) - for i in range(ir.VectorType(reg.type).shape[0]) - ] - for i, e in enumerate(elems): - tile_idx_local = list(tile_idx) - tile_idx_local[layout.vector_dim] += i - tile_idx_local = list(tile_idx_local) - lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True)) - reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) - llvm.store(e, reg_ptr) - - def store_tiled(self, ref, swizzle: int | None): + def store_tiled(self, ref, swizzle: int | None, optimized: bool = True): if not isinstance(self.layout, TiledLayout): raise NotImplementedError(self.layout) layout, shape = self.layout, self.shape - for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): + # Note that the loop below will "race" for layouts that replicate data. + # However, in that case all of the racing writes store the same data, which + # is ok in the CUDA memory model. + stores = self.transfer_tiled2(ref, swizzle, layout, shape, optimized) + for get, _, ptr in stores: llvm.store(get(self.registers), ptr) @classmethod @@ -1909,6 +1878,7 @@ def load_tiled( *, is_signed: bool | None = None, layout: FragmentedLayout = WGMMA_LAYOUT, + optimized: bool = True, ): ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type @@ -1927,7 +1897,8 @@ def load_tiled( ) registers = np.full(layout.registers_shape(shape), zero, dtype=object) reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) - for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape): + loads = cls.transfer_tiled2(ref, swizzle, layout, shape, optimized) + for _, update, ptr in loads: update(registers, llvm.load(reg_ty, ptr)) case _: raise NotImplementedError(layout) @@ -2023,6 +1994,7 @@ def transfer_tiled2( swizzle: int | None, layout: TiledLayout, shape: tuple[int, ...], + optimized: bool = True, ): """Generate a transfer schedule for a tiled layout. 
@@ -2074,12 +2046,18 @@ transfer_tiled2( raise NotImplementedError("Memory and register tiling incompatible") tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape)) elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides)) - elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims] - lane_shape = [tiled_shape[d] for d in layout.lane_dims] + lane_shape = [ + d.times if isinstance(d, Replicated) else tiled_shape[d] for d in layout.lane_dims + ] + lane_strides = [ + 0 if isinstance(d, Replicated) else elem_tiled_strides[d] for d in layout.lane_dims + ] if elem_tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") - for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): + for d in (*layout.partitioned_lane_dims, layout.vector_dim): tiled_shape[d] = 1 + if not isinstance(layout.warp_dim, Replicated): + tiled_shape[layout.warp_dim] = 1 element_bits = mgpu.bitwidth(dtype) if (layout.vector_length * element_bits) % 8 != 0: @@ -2114,10 +2092,22 @@ transfer_tiled_strides = [s // layout.vector_length for s in elem_tiled_strides] transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype) - plan = plan_tiled_transfer( - tiled_shape, elem_tiled_strides, lane_shape, elem_lane_strides, layout, - element_bits, swizzle - ) + if ref_ty.memory_space is None: + llvm_memory_space = None + elif ref_ty.memory_space == ir.Attribute.parse("#gpu.address_space<workgroup>"): + llvm_memory_space = 3 + else: + raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}") + + if optimized: + if llvm_memory_space != 3: + raise NotImplementedError("Only optimized transfers to SMEM supported") + plan = plan_tiled_transfer( + tiled_shape, elem_tiled_strides, lane_shape, lane_strides, + layout, element_bits, swizzle + ) + else: + plan = TrivialTransferPlan() # All offsets are in units of transfer_dtype. dyn_tiled_strides = [ @@ -2126,9 +2116,7 @@ lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides) warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides) dyn_offset = arith.addi(lane_offset, warp_offset) - if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"): - raise ValueError("Tiled stores can be performed into SMEM") - ptr = utils.memref_ptr(ref, memory_space=3) + ptr = utils.memref_ptr(ref, memory_space=llvm_memory_space) _as_consts = lambda consts: [c(const) for const in consts.tolist()] # This has bits set only for the offset bits that influence swizzling.
swizzle_mask = swizzle_block_transfers - swizzle_tile_transfers @@ -2307,9 +2295,14 @@ def plan_tiled_transfer( num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1) wavefront_lanes = WARP_SIZE // num_wavefronts + lane_mask = np.full(lane_shape, False) + lane_mask[tuple(slice(0, 1) if s == 0 else slice(None) for s in lane_strides)] = True + wavefront_mask = lane_mask.reshape(num_wavefronts, wavefront_lanes) + lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) def has_bank_conflicts(tile_idx_transform): - tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape) + num_tiles = math.prod(tiled_shape) + tile_idxs = np.unravel_index(np.arange(num_tiles), tiled_shape) tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1) # [#tiles, 1, #dims] lane_tile_idx = tile_idx_transform(tile_idxs) # [#tiles, #lanes/1, #dims] assert lane_tile_idx.shape[1] in {1, WARP_SIZE} @@ -2320,10 +2313,17 @@ def has_bank_conflicts(tile_idx_transform): swizzle_bits = swizzle_groups * swizzle_tile_elems lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) - # Order of threads within the wavefront is unimportant. - wavefront_banks = np.sort(wavefront_banks, axis=-1) - # There are no conflicts if each wavefront only contains unique banks. - return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1]) + # We step over wavefronts since they might have a different number of lanes. + wavefront_banks = wavefront_banks.swapaxes(0, 1) + for banks, mask in zip(wavefront_banks, wavefront_mask): + banks = banks[:, mask] + # Order of threads within the wavefront is unimportant. + banks = np.sort(banks, axis=-1) + # There are no conflicts if each wavefront only contains unique banks. + repeats = np.any(banks[..., 1:] == banks[..., :-1]) + if repeats: + return True + return False # We don't need any special treatment if there are no conflicts when each lane # transfers the same tile at a time. @@ -2386,10 +2386,23 @@ def optimization_barrier(*arrays: mgpu.FragmentedArray): index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) + def _repack(regs_it, reg_ty): + if not ir.VectorType.isinstance(reg_ty): + result_reg = next(regs_it) + assert result_reg.type == reg_ty + return result_reg + + num_i32_regs = utils.bitwidth(reg_ty) // 32 + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) + reg = llvm.mlir_undef(i32_reg_ty) + for i_elem in range(num_i32_regs): + val = llvm.bitcast(i32, next(regs_it)) + reg = llvm.insertelement(reg, val, arith.constant(i32, i_elem)) + return vector.bitcast(reg_ty, reg) + regs = [] reg_dtypes = [] reg_constraints = [] - repack_fns = [] # We unpack each array into a flat list of registers, and prepare the # functions that invert the transform in repack_fns. 
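+  # The `_repack` helper defined above inverts this flattening once the i32
+  # results come back from the inline assembly.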
for array in arrays: @@ -2403,36 +2416,25 @@ def optimization_barrier(*arrays: mgpu.FragmentedArray): for reg in array.registers.flat for pos in range(vec_len) ] - def _repack(regs, reg_ty=reg_ty): - reg = llvm.mlir_undef(reg_ty) - [vec_len] = ir.VectorType(reg_ty).shape - for i_elem in range(vec_len): - reg = llvm.insertelement( - reg, next(regs), arith.constant(i32, i_elem) - ) - return reg - repack_fns.append(_repack) else: array_regs = list(array.registers.flat) - repack_fns.append(lambda regs: next(regs)) reg_constraint = "f" elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype): if not ir.VectorType.isinstance(reg_ty): raise NotImplementedError(array.mlir_dtype) [vec_len] = ir.VectorType(reg_ty).shape - if vec_len != 2: + if vec_len % 2: raise NotImplementedError(vec_len) - i32_reg_ty = ir.VectorType.get((1,), i32) + num_i32_regs = vec_len // 2 + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) array_regs = [ vector.extractelement( - vector.bitcast(i32_reg_ty, reg), position=c(0, index) + vector.bitcast(i32_reg_ty, reg), position=c(i, index) ) + for i in range(num_i32_regs) for reg in array.registers.flat ] reg_constraint = "r" - def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): - return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs))) - repack_fns.append(_repack) else: raise NotImplementedError(array.mlir_dtype) regs += array_regs @@ -2460,14 +2462,14 @@ def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): i32 = ir.IntegerType.get_signless(32) results = [] regs_it = iter(regs) - for array, repack_fn in zip(arrays, repack_fns, strict=True): + for array in arrays: num_regs = array.registers.size reg_ty = array.registers.flat[0].type if ir.VectorType.isinstance(reg_ty): reg_ty = ir.VectorType(reg_ty) new_registers = np.empty((num_regs,), dtype=object) for i_vreg in range(num_regs): - reg = repack_fn(regs_it) + reg = _repack(regs_it, reg_ty) assert reg.type == reg_ty, (reg.type, reg_ty) new_registers[i_vreg] = reg results.append( diff --git a/jax/experimental/mosaic/gpu/inference_utils.py b/jax/experimental/mosaic/gpu/inference_utils.py index 6362626404c5..73ce23c427cd 100644 --- a/jax/experimental/mosaic/gpu/inference_utils.py +++ b/jax/experimental/mosaic/gpu/inference_utils.py @@ -95,6 +95,22 @@ def has_out_transforms_set(op: MlirOperation) -> bool: return "out_transforms" in op.attributes +def attr_element( + attr_name: str, op: MlirOperation, index: int +) -> ir.Attribute | None: + """Returns `op.attributes[attr_name][index]` if it exists, otherwise None. + + If `op.attributes[attr_name]` exists, then `index` must be a valid index into + the attribute array. 
+ """ + if attr_name not in op.attributes: + return None + attr = op.attributes[attr_name] + if not attr: + return None + return op.attributes[attr_name][index] # type: ignore + + def _in_attr_for_operand( op: MlirOperation, operand: ir.Value, @@ -109,9 +125,7 @@ def _in_attr_for_operand( operand_number = [o for o in op.operands if predicate(o)].index(operand) - if attr_name not in op.attributes: - return None - return op.attributes[attr_name][operand_number] # type: ignore + return attr_element(attr_name, op, operand_number) in_layout_for_operand = partial( diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index ce432f26dac2..64cdedc779c8 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -19,9 +19,10 @@ import enum import functools import math -from typing import Any +from typing import Any, Literal from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect +from jax._src import lib as jaxlib from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import func @@ -158,7 +159,7 @@ class TransposeTransform(MemRefTransform): def __post_init__(self): if len(self.permutation) != len(set(self.permutation)): - raise ValueError("Permutation must be a permutation") + raise ValueError("All elements of `permutation` must be unique") def apply(self, ref: ir.Value) -> ir.Value: return utils.memref_transpose(ref, self.permutation) @@ -228,6 +229,7 @@ def batch(self, leading_rank: int) -> MemRefTransform: OnDeviceProfiler = profiler.OnDeviceProfiler +ReductionOp = Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] @dataclasses.dataclass() class LaunchContext: @@ -309,6 +311,9 @@ def _get_tma_desc( gmem_transform: tuple[MemRefTransform, ...], transformed_slice_shape: tuple[int, ...], swizzle: int | None, + reduction_op: Literal[ + "add","min","max","inc","dec","and","or","xor" + ] | None, ): tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: @@ -320,6 +325,14 @@ def init_tma_desc(host_ptr): ref = t.apply(ref) ref_ty = ir.MemRefType(ref.type) # TODO(apaszke): Use utils.memref_ptr to compute base_ptr + strides, _ = ref_ty.get_strides_and_offset() + if strides[-1] != 1: + raise ValueError( + "TMA requires the stride of the last dimension after" + " transforming the GMEM reference to be 1, but it is" + f" {strides[-1]}." + ) + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) as_i64 = lambda i: arith.index_cast(i64, i) @@ -337,10 +350,38 @@ def init_tma_desc(host_ptr): ) # TODO(apaszke): Better verification (e.g. slice is non-zero) # TODO(apaszke): We always know strides statically. 
+ if jaxlib.version < (0, 5, 4): + dtype_or_bitwidth = c(utils.bitwidth(ref_ty.element_type), i64) + else: + if isinstance(ref_ty.element_type, ir.IntegerType): + if reduction_op is not None: + raise ValueError( + f"TMA with reduction_op={reduction_op} is not supported with Integers" + ) + bitwidth = utils.bitwidth_impl(ref_ty.element_type) + if bitwidth == 4: + tma_dtype = 0 + elif bitwidth == 8: + tma_dtype = 1 + elif bitwidth == 16: + tma_dtype = 2 + elif bitwidth == 32: + tma_dtype = 3 + elif bitwidth == 64: + tma_dtype = 4 + elif ir.F16Type.isinstance(ref_ty.element_type): + tma_dtype = 5 + elif ir.F32Type.isinstance(ref_ty.element_type): + tma_dtype = 6 + elif ir.BF16Type.isinstance(ref_ty.element_type): + tma_dtype = 7 + else: + raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}") + dtype_or_bitwidth = c(tma_dtype, i64) args = [ host_ptr, base_ptr, - c(utils.bitwidth(ref_ty.element_type), i64), + dtype_or_bitwidth, c(rank, i64), utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), @@ -374,7 +415,10 @@ def async_copy( uniform: bool = True, collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, partitioned: int | None = None, - predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG. + predicate: ( + ir.Value | None + ) = None, # Should select 0 or 1 threads from the WG. + reduction_op: ReductionOp | None = None, ): """Initiates an async copy between GMEM and SMEM. @@ -453,6 +497,13 @@ def async_copy( " multiple of 16 bytes" ) + if reduction_op is not None and jaxlib.version < (0, 5, 4): + raise ValueError("TMA with reduction is only supported with jaxlib >= 0.5.4") + if reduction_op is not None and not isinstance(gmem_ref_ty.element_type, ir.FloatType): + raise ValueError("TMA with reduction is only supported with float dtype") + if reduction_op is not None and reduction_op != "add": + raise ValueError("TMA with reduction is only supported with add operation") + # NOTE: TMA supports OOB indices, so we skip the check. base_indices, slice_shape, is_squeezed = utils.parse_indices( gmem_slice, ir.MemRefType(gmem_ref.type).shape, check_oob=False @@ -597,7 +648,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): multicast_mask = None tma_desc = self._get_tma_desc( - gmem_ref, gmem_transform, tuple(slice_shape), swizzle, + gmem_ref, gmem_transform, tuple(slice_shape), swizzle, reduction_op, ) # We constuct TMA descriptors in column-major order. 
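A minimal usage sketch of the new reduction path, not part of the patch: it assumes a LaunchContext `ctx`, an SMEM ref `acc_smem`, and a float-typed GMEM ref `out_gmem` are in scope, and uses keyword names from the `async_copy` signature in this file.

ctx.async_copy(
    src_ref=acc_smem,    # SMEM source tile
    dst_ref=out_gmem,    # GMEM destination; values are accumulated into it, not overwritten
    reduction_op="add",  # only "add" is accepted, and only for float element types
)
ctx.await_async_copy(allow_groups=0)  # block until the reduction has been committed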
@@ -606,7 +657,8 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): ] uniform_ctx = ( - functools.partial(utils.single_thread, per_block=False) + functools.partial( + utils.single_thread, scope=utils.ThreadSubset.WARPGROUP) if uniform and predicate is None else contextlib.nullcontext ) @@ -618,8 +670,8 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): ) if (zeroth_bw := slice_shape[-1] * element_bitwidth) % 128 != 0: raise ValueError( - "Async copies require the number of bytes copied along the last" - f" dimension to be divisible by 16, but got {zeroth_bw}" + "Async copies require the number of bits copied along the last" + f" dimension to be divisible by 128, but got {zeroth_bw}" ) if ( swizzle is not None @@ -641,6 +693,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): ) barrier_ptr = barrier.get_ptr() with uniform_ctx(): + assert reduction_op is None if collective_size > 1 and partitioned is not None: if predicate is None: predicate = c(1, ir.IntegerType.get_signless(1)) @@ -679,12 +732,28 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): ) else: assert multicast_mask is None - with uniform_ctx(): - nvvm.cp_async_bulk_tensor_global_shared_cta( - tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate - ) - if arrive: - nvvm.cp_async_bulk_commit_group() + if reduction_op is not None: + with uniform_ctx(): + if predicate is None: + predicate = c(1, ir.IntegerType.get_signless(1)) + rank = len(slice_shape) + idx_operands = ",".join(f"${i}" for i in range(3, 3 + rank)) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [predicate,smem_ptr,tma_desc,*rev_dyn_base_indices], + f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];", + "b,r,l" + ",r" * rank, + has_side_effects=True, + ) + if arrive: + nvvm.cp_async_bulk_commit_group() + else: + with uniform_ctx(): + nvvm.cp_async_bulk_tensor_global_shared_cta( + tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate + ) + if arrive: + nvvm.cp_async_bulk_commit_group() def await_async_copy( self, allow_groups: int, await_read_only: bool = False diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 0d2811bb5610..c9c565f331c9 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -21,6 +21,7 @@ import math from typing import cast +from jax._src import lib as jaxlib from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -44,7 +45,9 @@ def _add_layout_inference_rule(op: type[ir.OpView], rule: LayoutInferenceRule): - _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + if op is not None: + _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + return rule def _set_layout_attributes( @@ -192,7 +195,7 @@ def is_array(v: ir.Value) -> bool: # This is left for a future change, and currently we only do "down # propagation". layout = _choose_representative_layout(layouts) - # It is unsafe to t conclude that this op produces a splat if not all inputs + # It is unsafe to conclude that this op produces a splat if not all inputs # have been inferred: some of them might turn out not to be splats! 
if layouts_lib.is_splat_fragmented_layout(layout) and not all_inputs_have_layout: return None @@ -247,6 +250,51 @@ def is_array(v: ir.Value) -> bool: _add_layout_inference_rule(op, _infer_pointwise_op_layouts) +# TODO(bchetioui): remove once minimum jaxlib >= 0.5.3. +OptimizationBarrierOp = getattr(mgpu, "OptimizationBarrierOp", None) + + +@partial(_add_layout_inference_rule, OptimizationBarrierOp) +def _infer_optimization_barrier_op_layout( + op: OptimizationBarrierOp, +) -> OptionalLayouts: + def is_array(v: ir.Value) -> bool: + return ir.VectorType.isinstance(v.type) + + if inference_utils.has_in_layouts_set(op): + op_in_layouts = list(inference_utils.in_layouts(op)) + return op_in_layouts, op_in_layouts + + if inference_utils.has_out_layouts_set(op): + op_out_layouts = list(inference_utils.out_layouts(op)) + return op_out_layouts, op_out_layouts + + layouts = [None] * len(op.operands) + for i, operand in enumerate(filter(is_array, op.operands)): + layouts[i] = inference_utils.value_layout(operand) + + for i, result in enumerate(filter(is_array, op.results)): + possible_layouts = set() + for op_operand_use in cast(ir.OpResult, result).uses: + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + layout = inference_utils.in_layout_for_operand(consumer, op_user) + if layout is not None: + possible_layouts.add(layout) + if possible_layouts and layouts[i] is None: + # TODO(bchetioui): we could actually just pick any user layout here, + # and optimize later. This is fine for now. + layouts[i] = _choose_representative_layout(possible_layouts) + + # TODO(bchetioui): handle annotating layout for only certain operands. + # Otherwise, layouts may not get propagated through optimization barriers, if + # a single branch does not carry any forcing layout, which is pretty bad. + if any(layout is None for layout in layouts): + return None + + return layouts, layouts + + @partial(_add_layout_inference_rule, arith.ConstantOp) def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts: if not ir.VectorType.isinstance(constant_op.result.type): @@ -306,23 +354,46 @@ def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts: return (layouts, []) -@partial(_add_layout_inference_rule, scf.ForOp) -def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: - yield_op = op.body.operations[len(op.body.operations) - 1] - assert isinstance(yield_op, scf.YieldOp) - - if inference_utils.has_in_layouts_set(yield_op): - yield_layouts = list(inference_utils.in_layouts(yield_op)) +def _infer_from_yield_ops(op: ir.Operation) -> list[ir.Attribute] | None: + candidates = [] + for region in op.regions: + [block] = region.blocks + yield_op = block.operations[len(block.operations) - 1] + assert isinstance(yield_op, scf.YieldOp) + if not inference_utils.has_in_layouts_set(yield_op): + continue + yield_layouts = inference_utils.in_layouts(yield_op) if any( layouts_lib.is_splat_fragmented_layout(layout) for layout in yield_layouts ): - return None - return (yield_layouts, yield_layouts) + continue + candidates.append(yield_layouts) + if not candidates: + return None + return [_choose_representative_layout(set(c)) for c in zip(*candidates)] + +@partial(_add_layout_inference_rule, scf.ForOp) +def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: # TODO(bchetioui): we don't attempt to propagate from outside for the moment. # For the existing kernels, propagating from the YieldOp should be enough. 
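+  # `_infer_from_yield_ops` inspects the terminator of each region and picks a
+  # representative of the layouts already assigned to its operands, so the same
+  # helper serves scf.for, scf.if and scf.index_switch below.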
+ if layouts := _infer_from_yield_ops(op): + return layouts, layouts + return None + +@partial(_add_layout_inference_rule, scf.IfOp) +def _infer_if_op_layout(op: scf.IfOp) -> OptionalLayouts: + if layouts := _infer_from_yield_ops(op): + return [], layouts + return None + + +@partial(_add_layout_inference_rule, scf.IndexSwitchOp) +def _infer_index_switch_op_layout(op: scf.IndexSwitchOp) -> OptionalLayouts: + if layouts := _infer_from_yield_ops(op): + return [], layouts return None @@ -333,7 +404,6 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts: shape=cast(ir.ShapedType, splat_op.result.type).shape ) ) - return [], [layout] @@ -374,6 +444,15 @@ def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts: return None +# TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. +if jaxlib.version >= (0, 5, 4): + @partial(_add_layout_inference_rule, mgpu.LayoutCastOp) + def _infer_layout_cast_op_layout( + layout_cast_op: mgpu.LayoutCastOp, + ) -> OptionalLayouts: + return [layout_cast_op.new_layout], [layout_cast_op.new_layout] + + @partial(_add_layout_inference_rule, mgpu.WGMMAOp) def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts: layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT) @@ -479,23 +558,33 @@ def inference_step(op: ir.Operation): # make sure to derive a single vector size in order to avoid relayouts at # lowering time. default_vector_size = math.inf - - def update_default_vector_size(op: ir.OpView): + def update_default_vector_size_from_vector(v: ir.Value): nonlocal default_vector_size - for v in list(op.operands) + list(op.results): - if ir.VectorType.isinstance(v.type): - max_vec_size_for_v = ( - np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE - ) - desired_vec_size = 8 // utils.bytewidth(v.type.element_type) - default_vector_size = min( - default_vector_size, max_vec_size_for_v, desired_vec_size - ) + max_vec_size_for_v = ( + np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE + ) + desired_vec_size = 8 // utils.bytewidth(v.type.element_type) + default_vector_size = min( + default_vector_size, max_vec_size_for_v, desired_vec_size + ) + + def update_default_vector_size_from_op(op: ir.OpView): + for i, v in enumerate( + filter(lambda v: ir.VectorType.isinstance(v.type), op.operands) + ): + if inference_utils.attr_element("in_layouts", op, i) is None: + update_default_vector_size_from_vector(v) + + for i, v in enumerate( + filter(lambda v: ir.VectorType.isinstance(v.type), op.results) + ): + if inference_utils.attr_element("out_layouts", op, i) is None: + update_default_vector_size_from_vector(v) for op in module.body: - traverse_op(op, update_default_vector_size) + traverse_op(op, update_default_vector_size_from_op) - if default_vector_size is None: # Nothing to annotate. + if default_vector_size == math.inf: # Nothing to annotate. 
return def to_default_layout(ty: ir.Type) -> ir.Attribute | None: diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 5c3b23119779..0a4f3ed09116 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -96,7 +96,7 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool: _tiled_layout_attr_pattern = re.compile( r"^#mosaic_gpu.TiledLayout<\[(?P<tiling>.*)\]," - r" warp_dim\s*=\s*(?P<warp_dim>[-\d]+)," + r" warp_dim\s*=\s*(?P<warp_dim>.+)," r" lane_dims\s*=\s*\[(?P<lane_dims>.*)\]," r" vector_dim\s*=\s*(?P<vector_dim>[-\d]+)>$" ) @@ -107,15 +107,29 @@ def to_tiled_layout_attr( ) -> ir.Attribute: """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout.""" + def _int_or_replicated(d: int | fa.Replicated) -> str: + if isinstance(d, fa.Replicated): + return f"#mosaic_gpu.Replicated<times={d.times}>" + return str(d) + tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]" tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]" + lane_dims = ( + "[" + ",".join(_int_or_replicated(d) for d in layout.lane_dims) + "]" + ) + return ir.Attribute.parse( - f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim}," - f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>" + f"#mosaic_gpu.TiledLayout<{tiling}," + f" warp_dim={_int_or_replicated(layout.warp_dim)}," + f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>" ) _list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[") +_int_pattern = re.compile(r"^(?P<num>[-\d]+)(\s*:\s*\w+)?$") +_replicated_pattern = re.compile( + r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P<times>\d+)\s*>\s*$" +) def from_tiled_layout_attr( @@ -133,6 +147,15 @@ def from_tiled_layout_attr( f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}" ) + def _int_or_replicated(replicated_dim: str) -> int | fa.Replicated: + match = _replicated_pattern.fullmatch(replicated_dim) + if match: + return fa.Replicated(int(match.group("times"))) + match = _int_pattern.fullmatch(replicated_dim) + if match: + return int(match.group("num")) + raise ValueError(f"Unexpected format for replicated dim {replicated_dim}") + tiling_str = match.group("tiling") tile_strings = [] if len(tiling_str) > 2: @@ -140,9 +163,12 @@ def from_tiled_layout_attr( tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings) return fa.TiledLayout( tiling=fa.Tiling(tiles), - warp_dim=int(match.group("warp_dim")), - lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")), - vector_dim=int(match.group("vector_dim")) + warp_dim=_int_or_replicated(match.group("warp_dim")), + lane_dims=tuple( + _int_or_replicated(s.strip()) + for s in match.group("lane_dims").split(",") + ), + vector_dim=int(match.group("vector_dim")), ) @@ -155,7 +181,6 @@ def to_layout_attr( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout - | fa.WGMMARowFragLayout ), ) -> ir.Attribute: """Constructs an MLIR attribute that corresponds to the given layout.""" @@ -166,30 +191,18 @@ def to_layout_attr( return to_strided_fragmented_layout_attr(layout) case fa.TiledLayout(): return to_tiled_layout_attr(layout) - case fa.WGMMARowFragLayout(): - return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout") case _: raise NotImplementedError( f"Unsupported layout for conversion to MLIR attribute: {layout}" ) -_wgmma_row_fragmented_layout_attr_pattern = re.compile( - r"^#mosaic_gpu.WGMMARowFragLayout$" -) - - -def is_wgmma_row_fragmented_layout(attr: ir.Attribute) -> bool: - return
bool(_wgmma_row_fragmented_layout_attr_pattern.search(str(attr))) - - def from_layout_attr( attr: ir.Attribute, ) -> ( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout - | fa.WGMMARowFragLayout ): """Constructs a layout from an MLIR attribute.""" if is_splat_fragmented_layout(attr): @@ -198,8 +211,6 @@ def from_layout_attr( return from_strided_fragmented_layout_attr(attr) elif is_tiled_layout(attr): return from_tiled_layout_attr(attr) - elif is_wgmma_row_fragmented_layout(attr): - return fa.WGMMARowFragLayout() else: raise NotImplementedError( f"Unsupported layout for conversion from MLIR attribute: {attr}" diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 0c128f88d169..5b278468b98c 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -17,10 +17,11 @@ import itertools import json import math -from typing import Callable, ParamSpec, TypeVar +from typing import Callable, ParamSpec, TypeAlias, TypeVar import warnings import jax +from jax._src import stages from jax._src.lib import xla_client import jax.numpy as jnp from jaxlib.mlir import ir @@ -97,29 +98,44 @@ def run(*args, **kwargs): return outs, float(elapsed) -def _measure_cupti(f, aggregate): - def run(*args, **kwargs): - mosaic_gpu_lib._mosaic_gpu_ext._cupti_init() - try: - results = jax.block_until_ready(jax.jit(f)(*args, **kwargs)) - finally: - timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings() - return results, timings - - def wrapper(*args, **kwargs): - run(*args, **kwargs) # Warmup. - results, timings = run(*args, **kwargs) - if not timings: - return results, None - elif aggregate: - return results, sum(item[1] for item in timings) - else: - return results, timings - return wrapper - - -def measure(f: Callable, *, mode: str = "events", aggregate: bool = True -) -> Callable: +Timings: TypeAlias = list[tuple[str, float]] | float | None + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Cupti: + """CUPTI-based profiler.""" + + # If `True`, detach CUPTI from the process after measurement. + finalize: bool = True + + def measure( + self, f: Callable[P, T], *, aggregate: bool = True + ) -> Callable[P, tuple[T, Timings]]: + if not isinstance(f, (stages.Wrapped, stages.Compiled)): + f = jax.jit(f) + + def wrapper(*args: P.args, **kwargs: P.kwargs): + jax.block_until_ready(f(*args, **kwargs)) # Warmup. + ext = mosaic_gpu_lib._mosaic_gpu_ext + ext._cupti_init() + try: + results = jax.block_until_ready(f(*args, **kwargs)) + finally: + timings = ext._cupti_get_timings(self.finalize) + + if not timings: + return results, None + elif aggregate: + return results, sum(item[1] for item in timings) + else: + return results, timings + + return wrapper + + +def measure( + f: Callable[P, T], *, mode: str = "events", aggregate: bool = True +) -> Callable[P, tuple[T, Timings]]: """Sets up a function ``f`` for profiling on GPU. ``measure`` is a higher-order function that augments the argument ``f`` to @@ -173,10 +189,10 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True In an attempt to minimize the second effect, internally the events-based implementation may execute ``f`` more than once to "warm up" and exclude compilation time from the measurement. 
- """ + """ # fmt: skip match mode: case "cupti": - return _measure_cupti(f, aggregate) + return Cupti().measure(f, aggregate=aggregate) case "events": if not aggregate: raise ValueError(f"{aggregate=} is not supported with {mode=}") diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 3330500cd6dc..53056ce594b2 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -197,7 +197,7 @@ def mma( ), a_mk, b_nk, - d_type=ir.F32Type.get(), + d_type=d.dtype, m=m_group_elems, n=n_group_elems, collective=collective, @@ -327,36 +327,65 @@ def tmem_relinquish_alloc_permit(): has_side_effects=True, ) -def tmem_load(tmem_addr, shape, num): +def _tmem_access_helper(shape, num, packing: int = 1): if num.bit_count() != 1 or num > 128: raise ValueError(f"num must be a power of 2 and <= 128, got: {num}") match shape: case "16x128b": - num_out_regs = 2 + num_regs = 2 case "16x256b": - num_out_regs = 4 + num_regs = 4 case _: raise NotImplementedError(f"{shape=} is unsupported") - if num * num_out_regs >= 256: + num_regs *= num + if num_regs > 255: raise ValueError( - f"Loading too much TMEM at once: {num=} and each load requires" - f" {num_out_regs} registers, which exceeds the limit of 256" + f"TMEM transation too big : {shape=} and {num=} involve" + f" {num_regs} registers per-thread, which exceeds the limit of 255" ) - num_out_regs *= num + regs_vector = ",".join(f"${i}" for i in range(num_regs)) + regs_vector = "{" + regs_vector + "}" + return num_regs, regs_vector + + +def tmem_load(tmem_addr, shape, num, packing: int = 1): i32 = ir.IntegerType.get_signless(32) - out_regs = ",".join("$" + str(i) for i in range(num_out_regs)) + num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing) + if packing == 1: + pack_mod = "" + elif packing == 2: + pack_mod = ".pack::16b" + else: + raise ValueError(f"Unsupported packing: {packing}") regs = llvm.inline_asm( ir.Type.parse( "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>" ), [tmem_addr], - f"tcgen05.ld.sync.aligned.{shape}.x{num}.b32 {{{out_regs}}}, [${num_out_regs}];", + f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {regs_vector}, [${num_out_regs}];", "=r," * num_out_regs + "r", has_side_effects=True, ) return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)] +def tmem_store(tmem_addr, shape, num, regs, packing: int = 1): + num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing) + if packing == 1: + pack_mod = "" + elif packing == 2: + pack_mod = ".unpack::16b" + else: + raise ValueError(f"Unsupported packing: {packing}") + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [*regs, tmem_addr], + f"tcgen05.st.sync.aligned.{shape}.x{num}{pack_mod}.b32 [${num_out_regs}], {regs_vector};", + "r," * num_out_regs + "r", + has_side_effects=True, + ) + + @dataclasses.dataclass(frozen=True) class TMEMLayout: """Represents the way a shape is laid out in TMEM. 
@@ -521,9 +550,9 @@ def __getitem__(self, *idxs): raise NotImplementedError("Slicing of TMEM not impelmented yet") if self.shape[1] % 8: raise NotImplementedError - if self.dtype != ir.F32Type.get(): - raise NotImplementedError(self.dtype) - layout = _m128_256bit_32bit_layout(self.shape) + if utils.bitwidth(self.dtype) not in {16, 32}: + raise NotImplementedError(f"Unsupported dtype: {self.dtype}") + layout = _m128_layout(self.shape) regs_shape = layout.registers_shape(self.shape) if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): # load_32xcols returns a 4xN array, but the FA tiling we use here tiles @@ -556,47 +585,168 @@ def __getitem__(self, *idxs): ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) + def __setitem__(self, idxs, value): + if not isinstance(idxs, tuple): + idxs = (idxs,) + base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) + if any(is_squeezed): + raise ValueError( + "TMEM stores don't support integer indexing (only slices allowed)" + ) + if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape: + raise NotImplementedError("Slicing parts of TMEM not implemented yet") + if self.shape[1] % 8: + raise NotImplementedError + if utils.bitwidth(self.dtype) not in {16, 32}: + raise NotImplementedError(f"Unsupported dtype: {self.dtype}") + if not isinstance(value, fa.FragmentedArray): + raise ValueError(f"TMEM stores expect a FragmentedArray, got: {value}") + if value.shape != self.shape: + raise ValueError( + f"Stored array has shape {value.shape}, but TMEM has shape" + f" {self.shape}" + ) + if value.mlir_dtype != self.dtype: + raise ValueError( + f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype" + f" {self.dtype}" + ) + if value.layout != LAYOUT: + raise ValueError( + f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT is" + " supported" + ) + if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): + # store_32xcols needs a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + _store_32xcols( + self.address, value.registers.T.reshape((4, -1)) + ) + else: # TODO(apaszke): Collective MMA layout + raise NotImplementedError( + f"Stores only implemented for refs with standard layout, got: {self.layout}" + ) + + +def _transfer_32xcols(base_addr, cols): + i32 = ir.IntegerType.get_signless(32) + cols_per_num = 8 # Here we generate a plan compatible with tcgen05.LAYOUT. + assert cols % cols_per_num == 0 + total_num = cols // cols_per_num + if total_num <= 32: + instr_num = total_num + elif total_num == 64: + instr_num = 32 + else: + raise NotImplementedError(total_num) + # We transfer 16 lanes at a time, but have 32 to deal with. 
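+  # The row index lives in the upper 16 bits of a TMEM address, so the second
+  # group of 16 lanes is reached by bumping that field by 16.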
+ for lane_step in range(2): + addr_row = arith.addi(base_addr, utils.c((lane_step * 16) << 16, i32)) + cols_per_instr = instr_num * cols_per_num + for num_step in range(total_num // instr_num): + num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num) + addr_row_col = arith.addi(addr_row, utils.c(num_step * cols_per_instr, i32)) + yield addr_row_col, instr_num, lane_step, num_slice + + +def _store_32xcols(base_addr, vector_regs): + i32 = ir.IntegerType.get_signless(32) + assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4 + cols = vector_regs.shape[1] * 8 + + packing = 64 // utils.bitwidth(vector_regs.flat[0].type) + if packing == 1: + store_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits + regs = np.empty((4, vector_regs.shape[1], 2), dtype=object) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for idx, vreg in np.ndenumerate(vector_regs): + regs[(*idx, 0)] = llvm.extractelement(vreg, c0) + regs[(*idx, 1)] = llvm.extractelement(vreg, c1) + regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2) + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + assert regs.shape[-2:] == (2, 2) + elif packing == 2: + store_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits + # From a single lane perspective a num tile has 2 registers, 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2) + else: + raise NotImplementedError(packing) + + it = _transfer_32xcols(base_addr, cols) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs_slice = regs[lane_step, num_slice].flat + tmem_store(addr_row_col, store_shape, instr_num, regs_slice, packing) + + def _load_32xcols(base_addr, cols, dtype): - # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b i32 = ir.IntegerType.get_signless(32) - assert cols % 8 == 0 - cols_per_num_tile = 8 - load_shape = "16x256b" - num = cols // 8 - if num <= 32: - num_tiling = num - elif num == 64: - num_tiling = 32 + vec_ty = ir.VectorType.get((2,), dtype) + packing = 32 // utils.bitwidth(dtype) + if packing == 1: + load_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits + elif packing == 2: + load_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits else: - raise NotImplementedError(num) - vector_regs = np.ndarray((4, num), dtype=object) - # We load 16 lanes at a time, but need 32 in total. 
- for row_group in range(2): - addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16)) - regs = [] - for num_group in range(num // num_tiling): - addr_row_col = arith.addi( - addr_row, - arith.constant(i32, num_tiling * num_group * cols_per_num_tile), - ) - regs += tmem_load(addr_row_col, load_shape, num_tiling) - regs = [llvm.bitcast(dtype, r) for r in regs] - undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) - for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)): - high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) - vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) - vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + raise NotImplementedError(packing) + + vector_regs = np.ndarray((4, cols // 8), dtype=object) + + it = _transfer_32xcols(base_addr, cols) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs = tmem_load(addr_row_col, load_shape, instr_num, packing) + row_slice = slice(lane_step * 2, (lane_step + 1) * 2) + # This aliases the original array, so updates will be reflected there. + vector_regs_update = vector_regs[row_slice, num_slice] + assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num) + if packing == 1: + regs = [llvm.bitcast(dtype, r) for r in regs] + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1) + undef = llvm.mlir_undef(vec_ty) + assert regs.shape == (*vector_regs_update.shape, 2) + for idx in np.ndindex(vector_regs_update.shape): + high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0) + vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1) + vector_regs_update[idx] = vreg + else: + assert packing == 2 + regs = [llvm.bitcast(vec_ty, r) for r in regs] + # From a single lane perspective a num tile has 2 registers, 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1) + vector_regs_update[...] = regs + return vector_regs -def _m128_256bit_32bit_layout(shape: tuple[int, ...]): +def _m128_layout(shape: tuple[int, ...]): if len(shape) != 2: raise ValueError(f"Shape {shape} is not 2D") if shape[0] % 128 != 0 or shape[1] % 8 != 0: raise ValueError(f"Shape {shape} is not a multiple of 64x8") - return fa.TiledLayout( - fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), - vector_dim=-1, + return LAYOUT + +# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. 
+# The name is so short, because it's meant to be used qualified (tcgen05.LAYOUT) +LAYOUT = fa.TiledLayout( + fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), + warp_dim=-8, + lane_dims=(-4, -3), + vector_dim=-1, +) + + +def commit_tmem(): + void = ir.Type.parse("!llvm.void") + llvm.inline_asm( + void, [], "tcgen05.wait::st.sync.aligned;", "", has_side_effects=True, ) + utils.warpgroup_barrier() diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index ef2d3661674c..6026cb216166 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -30,6 +30,7 @@ from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector +from jax._src.util import safe_zip from . import fragmented_array as fa from . import inference_utils @@ -60,17 +61,45 @@ def _set_transform_attributes( op.attributes["out_transforms"] = ir.ArrayAttr.get(out_transforms) +def _resolve_transforms( + transforms: ir.ArrayAttr | None, + other_transforms: ir.ArrayAttr | None, +) -> ir.ArrayAttr | None: + """Resolves two sets of competing transforms to a single compatible set. + + Args: + transforms: one optional set of transforms. + other_transforms: another optional set of transforms. + + Returns: + A single set of transforms that is compatible with both `transforms` and + `other_transforms`, or `None` if both transforms are `None`. + Raises: + NotImplementedError: if the two sets of transforms can't be resolved to a + single set. + """ + if transforms is None: + return other_transforms + + if other_transforms is None: + return transforms + + if transforms != other_transforms: + raise NotImplementedError( + f"Conflicting transforms {transforms} != {other_transforms}." + ) + + return transforms + + def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: if len(ref_ty.shape) != 2: raise ValueError(f"Expected a 2D memref, got {ref_ty}") element_bytewidth = utils.bytewidth(ref_ty.element_type) strides, _ = ref_ty.get_strides_and_offset() - - if strides[0] < strides[1]: - raise NotImplementedError("Transpositions aren't handled yet.") - - minor_dim = ref_ty.shape[1] + transposed = strides[0] < strides[1] + minor_dim = ref_ty.shape[0 if transposed else 1] major_tiling = 8 # Try tiling with all swizzling modes starting from the largest one. @@ -86,12 +115,14 @@ def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: break else: # No valid tile transform can be inferred. - raise ValueError( - f"{ref_ty.shape} is not a valid WGMMA shape" - ) + raise ValueError(f"{ref_ty.shape} is not a valid WGMMA shape") + if transposed: + tiling = (minor_tiling, major_tiling) + else: + tiling = (major_tiling, minor_tiling) return ir.ArrayAttr.get([ - mgpu.TileTransformAttr.get((major_tiling, minor_tiling)), + mgpu.TileTransformAttr.get(tiling), mgpu.SwizzleTransformAttr.get(minor_tiling * element_bytewidth), ]) @@ -156,27 +187,12 @@ def _infer_vector_load_store_transforms( f"Got layout {layout} which is not yet supported" ) - if transforms is not None and layout_transforms is not None: - if transforms != layout_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op.base} in {op}: " - f"{transforms} != {layout_transforms}." 
- ) - return [transforms], [] - - if transforms is not None: - return [transforms], [] - - if layout_transforms is not None: - return [layout_transforms], [] - - return None + transforms = _resolve_transforms(transforms, layout_transforms) + return None if transforms is None else ([transforms], []) -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. -SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) -@partial(_add_transform_inference_rule, SliceSMEMOp) -def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: +@partial(_add_transform_inference_rule, mgpu.SliceSMEMOp) +def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: transforms = None uses = cast(ir.OpResult, op.result).uses @@ -186,21 +202,14 @@ def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: out_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: - transforms = out_transforms + transforms = _resolve_transforms(transforms, out_transforms) return None if transforms is None else ([], [transforms]) # TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use # the dialect in all cases. -# The rule is necessary in order to handle the lowering of `utils.memref_ptr` +# The rule is necessary in order to handle the lowering of `utils.memref_ptr` # which is used in `_construct_smem_reftree`. @partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp) def _infer_unrealized_conversion_cast_transforms( @@ -229,14 +238,7 @@ def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: out_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise ValueError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: - transforms = out_transforms + transforms = _resolve_transforms(transforms, out_transforms) # TODO(bchetioui): do we actually need to assign a transform to the input of # the view op? Presumably, it'll only be used to access scratch memory. @@ -252,6 +254,125 @@ def _infer_dynamic_smem_transforms( return None +def _get_tile_and_swizzle_transforms( + transforms: ir.ArrayAttr | None, +) -> tuple[ir.Attribute, ir.Attribute]: + if transforms is None: + return + + if len(transforms) == 2: + tile_transform, swizzle_transform = transforms + if not ( + mgpu.TileTransformAttr.isinstance(tile_transform) + and mgpu.SwizzleTransformAttr.isinstance(swizzle_transform) + ): + raise NotImplementedError(f"Unsupported transforms {transforms}.") + return tile_transform, swizzle_transform + else: + raise NotImplementedError(f"Unsupported transforms {transforms}.") + + +# This is used by Pallas' "_handle_indexing" memory transform. 
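+# The rule below reconciles the transforms requested by the subview's consumers
+# with those already attached to its source, and only forwards them when the
+# sliced axes leave the tiled and swizzled dimensions intact.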
+@partial(_add_transform_inference_rule, memref.SubViewOp) +def _infer_memref_subview_transforms( + op: memref.SubViewOp, +) -> OptionalTransforms: + transforms = None + + for result_use in cast(ir.OpResult, op.result).uses: + consumer = result_use.owner + op_user = consumer.operands[result_use.operand_number] + user_transforms = inference_utils.in_transforms_for_operand( + consumer, op_user + ) + transforms = _resolve_transforms(transforms, user_transforms) + + in_transforms = inference_utils.value_transforms(op.source) + transforms = _resolve_transforms(transforms, in_transforms) + + if transforms is None: + return None + + # Here, we have some transforms to propagate one way or the other. For now, + # we implement only the following basic propagation rules: + # - A tile transform can be propagated bidirectionally if the axes being + # tiled are not sliced, and are the logical minor axes of the source. + # - A swizzle transform can be propagated towards the input of a subview if + # the physical minormost dimension is unchanged. + # - We only propagate transforms if they consist of a single tile transform + # and a single swizzle transform. + # TODO(bchetioui): implement more complex propagation rules. + tile_transform, _ = _get_tile_and_swizzle_transforms(transforms) + + # Check swizzle transform propagation. + strides, _ = ir.MemRefType.get_strides_and_offset(op.source.type) + minor_dim = strides.index(min(strides)) + if op.source.type.shape[minor_dim] != op.static_sizes[minor_dim]: + raise NotImplementedError( + "Swizzle transforms can only be propagated if the minor dimension is " + "unchanged." + ) + + # Check tile transform propagation. + num_tiled_axes = len(mgpu.TileTransformAttr(tile_transform).tiling) + last_n_dims = op.source.type.shape[-num_tiled_axes:] + last_n_sizes = list(op.static_sizes)[-num_tiled_axes:] + for slice_size, dim_size in safe_zip(last_n_sizes, last_n_dims): + if slice_size != dim_size: + raise NotImplementedError( + "Tile transforms are only propagated if the tiled axes are not " + "sliced." + ) + + return [transforms], [transforms] + + +@partial(_add_transform_inference_rule, memref.TransposeOp) +def _infer_memref_transpose_transforms( + op: memref.TransposeOp, +) -> OptionalTransforms: + in_ty = ir.MemRefType(op.in_.type) + if len(in_ty.shape) != 2: + raise NotImplementedError(f"Only 2D memrefs are supported, got {in_ty}") + in_strides, _ = in_ty.get_strides_and_offset() + out_strides, _ = ir.MemRefType(op.result.type).get_strides_and_offset() + transpose = in_strides != out_strides + + users = list(op.result.uses) + if len(users) != 1: + raise NotImplementedError( + f"Only memref.transpose ops with a single use are supported, got {op}" + ) + + op_operand_use = users[0] + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + out_transforms = inference_utils.in_transforms_for_operand(consumer, op_user) + + in_transforms = [] + if not transpose: + in_transforms = out_transforms + else: + tile_transform, swizzle_transform = _get_tile_and_swizzle_transforms( + out_transforms + ) + transposed_tiling = mgpu.TileTransformAttr(tile_transform).tiling[::-1] + in_transforms.append(mgpu.TileTransformAttr.get(transposed_tiling)) + in_transforms.append(swizzle_transform) + + return [ir.ArrayAttr.get(in_transforms)], [out_transforms] + + +# `memref.load` is used to load barrier phases---the rule needn't do anything +# interesting, but we need to have it in order to avoid crashing on it.
+@partial(_add_transform_inference_rule, memref.LoadOp) +def _infer_memref_load_transforms(op: memref.LoadOp) -> OptionalTransforms: + if not ir.MemRefType(op.memref.type).shape: + # memref.load returns a scalar, so there is nothing interesting to do here. + return None + raise NotImplementedError("Non-scalar memref.load transforms") + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 28534cf4025b..3c7532dde99d 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -135,7 +135,7 @@ def _debug_scalar_ty_format(arg): return "%llu", arg if ir.F32Type.isinstance(arg.type): return "%f", arg - if ir.F16Type.isinstance(arg.type): + if ir.BF16Type.isinstance(arg.type) or ir.F16Type.isinstance(arg.type): arg = arith.extf(ir.F32Type.get(), arg) return "%f", arg raise NotImplementedError(f"Can't print the type {arg.type}") @@ -164,7 +164,7 @@ def debug_print(fmt, *args, uniform=True): raise NotImplementedError(arg.type) type_formats.append(ty_format) ctx = ( - functools.partial(single_thread, per_block=False) + functools.partial(single_thread, scope=ThreadSubset.WARPGROUP) if uniform else contextlib.nullcontext ) @@ -258,6 +258,7 @@ def warpgroup_idx(sync=True): class ThreadSubset(enum.IntEnum): + WARP = enum.auto() WARPGROUP = enum.auto() BLOCK = enum.auto() @@ -266,25 +267,34 @@ class ThreadSubset(enum.IntEnum): _ONCE_PER: ThreadSubset | None = None -def single_thread_predicate(per_block=True): +def single_thread_predicate(scope: ThreadSubset = ThreadSubset.BLOCK): + """Returns a predicate that selects a single thread. + + Args: + scope: What level of the thread hierarchy to select a thread from. + For example, if the scope is BLOCK, only one thread per block will be + selected. + """ + elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) + if scope == ThreadSubset.WARP: + return elected warp = warp_idx() - if not per_block: + if scope is not ThreadSubset.BLOCK: warp = arith.remui(warp, c(4, warp.type)) first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) - elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) return arith.andi(first_warp, elected) @contextlib.contextmanager -def single_thread(per_block=True): +def single_thread(scope: ThreadSubset = ThreadSubset.BLOCK): """Runs the context only from a single thread. Args: - per_block: If True, only one thread per block will run the context. - Otherwise, only one thread per warp group will run the context. + scope: What level of the thread hierarchy to select a thread from. + For example, if the scope is BLOCK, only one thread per block will be + selected. """ global _ONCE_PER - scope = ThreadSubset.BLOCK if per_block else ThreadSubset.WARPGROUP # If we're already in a single-thread context, we don't have to do anything. 
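+  # ThreadSubset is an IntEnum ordered from narrowest (WARP) to widest (BLOCK),
+  # so an enclosing restriction at a scope at least as wide as the requested one
+  # already implies it.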
if _ONCE_PER is not None and _ONCE_PER >= scope: yield @@ -293,7 +303,7 @@ def single_thread(per_block=True): prev_scope = _ONCE_PER _ONCE_PER = scope try: - if_op = scf.IfOp(single_thread_predicate(per_block)) + if_op = scf.IfOp(single_thread_predicate(scope)) with ir.InsertionPoint(if_op.then_block): yield scf.YieldOp([]) @@ -708,7 +718,7 @@ def initialize(address: ir.Value, num_barriers: int, arrival_count: int = 1) -> ptr = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>") phases = memref.alloca(ir.MemRefType.get((), i32), [], []) memref.store(c(0, i32), phases, []) - with single_thread(per_block=True): + with single_thread(scope=ThreadSubset.BLOCK): for i in range(num_barriers): nvvm.mbarrier_init_shared( llvm.getelementptr(ptr, address, [], [i], i64), @@ -870,7 +880,7 @@ def arrive(self): if self.barrier.num_barriers != 1: raise ValueError("Can only arrive on a single barrier") if self.cluster_mask is None: - with single_thread(per_block=False): + with single_thread(scope=ThreadSubset.WARPGROUP): self.barrier.arrive() return i32 = ir.IntegerType.get_signless(32) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 2bde1fbeadc4..7be349f0fc8f 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -99,8 +99,11 @@ def _identity_fn(x): def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: - reps = sharding_impls.GSPMDSharding.get_replicated( - inp.sharding._device_assignment) + if isinstance(inp.sharding, sharding_impls.NamedSharding): + reps = inp.sharding.with_spec(P()) + else: + reps = sharding_impls.GSPMDSharding.get_replicated( + inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind) out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 1e0abacfc25f..2144be0fb18b 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -15,7 +15,7 @@ """Module for Pallas, a JAX extension for custom kernels. See the Pallas documentation at -https://jax.readthedocs.io/en/latest/pallas.html. +https://docs.jax.dev/en/latest/pallas.html. 
""" from jax._src.pallas.core import Blocked as Blocked @@ -30,6 +30,7 @@ from jax._src.pallas.core import MemorySpace as MemorySpace from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import no_block_spec as no_block_spec +from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.core import Unblocked as Unblocked from jax._src.pallas.core import unblocked as unblocked from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost @@ -55,8 +56,12 @@ from jax._src.pallas.primitives import program_id as program_id from jax._src.pallas.primitives import reciprocal as reciprocal from jax._src.pallas.primitives import run_scoped as run_scoped +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.primitives import store as store from jax._src.pallas.primitives import swap as swap +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.utils import cdiv as cdiv from jax._src.pallas.utils import next_power_of_2 as next_power_of_2 from jax._src.pallas.utils import strides_from_shape as strides_from_shape diff --git a/jax/experimental/pallas/fuser.py b/jax/experimental/pallas/fuser.py index 729a447b7408..d4ec7e89cc7d 100644 --- a/jax/experimental/pallas/fuser.py +++ b/jax/experimental/pallas/fuser.py @@ -19,6 +19,6 @@ from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 631b4f720984..d74ffe6eae1b 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -18,15 +18,21 @@ """ from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier +from jax._src.pallas.mosaic_gpu.core import ClusterBarrier as ClusterBarrier from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh +from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace from jax._src.pallas.mosaic_gpu.core import kernel as kernel +from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform +from jax._src.pallas.mosaic_gpu.core import transform_ref as transform_ref from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref +from jax._src.pallas.mosaic_gpu.core import untile_ref as untile_ref +from jax._src.pallas.mosaic_gpu.core import unswizzle_ref as unswizzle_ref from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform +from jax._src.pallas.mosaic_gpu.core import WarpMesh as WarpMesh from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as 
ACC # noqa: F401 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline @@ -36,18 +42,24 @@ from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group +from jax._src.pallas.mosaic_gpu.primitives import GPUShapeDtypeStruct as GPUShapeDtypeStruct from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast +from jax._src.pallas.mosaic_gpu.primitives import load as load +from jax._src.pallas.mosaic_gpu.primitives import RefType as RefType from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait -from jax.experimental.mosaic.gpu.core import ThreadSemantics as ThreadSemantics +from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics #: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`. GMEM = GPUMemorySpace.GMEM #: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`. SMEM = GPUMemorySpace.SMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.TMEM`. 
+TMEM = GPUMemorySpace.TMEM diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 8883878f5f0e..6a20b448ca54 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -43,8 +43,8 @@ def __post_init__(self): raise ValueError(f"{self.max_concurrent_steps=} must be at least 2") -@functools.partial(jax.jit, static_argnames=["config"]) -def attention(q, k, v, config: TuningConfig): +@functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention(q, k, v, config: TuningConfig, save_residuals: bool = False): if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -69,12 +69,12 @@ def attention(q, k, v, config: TuningConfig): ) block_q, block_kv = config.block_q, config.block_kv - def kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): batch = lax.axis_index("batch") q_head = lax.axis_index("heads") smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped wg_idx = lax.axis_index("wg") - qo_smem2, k_smem, v_smem = smem_buffers + qo_smem2, k_smem, v_smem, lse_smem2 = smem_buffers k_barriers, v_barriers, q_barriers = buffer_barriers k_consumed_barriers, v_consumed_barriers = consumed_barriers def perform_schedule_barrier(): @@ -85,6 +85,7 @@ def perform_schedule_barrier(): def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q plgpu.copy_gmem_to_smem( @@ -162,15 +163,23 @@ def _wait(): 0, kv_seq_len // block_kv, kv_loop, (acc, m_i, l_i) ) pl.when(wg_idx == 0)(perform_schedule_barrier) - del m_i # Not needed anymore # TODO(apaszke): Invert and multiply to avoid expensive divisions. acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] 
= m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) @pl.when(wg_idx == 2) def _memory_wg(): @@ -191,9 +200,9 @@ def kv_loop(kv_step, _): plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) - def entry(q_ref, k_ref, v_ref, out_ref): + def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): compute_wgs = 2 - tiling = plgpu.TilingTransform((64, 64)) + tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) qo_scratch = plgpu.SMEM( (compute_wgs, block_q, head_dim), jnp.float16, @@ -201,15 +210,18 @@ def entry(q_ref, k_ref, v_ref, out_ref): ) k_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, - transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle), + transforms=(tiling, swizzle), ) v_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, transforms=(tiling, swizzle), ) + scratch = [qo_scratch, k_scratch, v_scratch, None] + if save_residuals: + scratch[3] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) pl.run_scoped( - lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args), - (qo_scratch, k_scratch, v_scratch), + lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), + scratch, ( plgpu.Barrier(1, num_barriers=max_concurrent_steps), plgpu.Barrier(1, num_barriers=max_concurrent_steps), @@ -223,17 +235,32 @@ def entry(q_ref, k_ref, v_ref, out_ref): if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - return plgpu.kernel( + out_shape = [q, None] + if save_residuals: + # Note that we keep seq_len in the minor-most dimension so that we can do + # 1D TMAs on chunks of `block_q`. 
+ out_shape[1] = jax.ShapeDtypeStruct( + (batch_size, num_q_heads, q_seq_len), jnp.float32 + ) + + out, lse = plgpu.kernel( entry, - out_shape=q, + out_shape=out_shape, grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), + thread_name="wg", compiler_params=plgpu.GPUCompilerParams(approx_math=True), )(q, k, v) -@functools.partial(jax.jit, static_argnames=["config"]) -def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): + if save_residuals: + assert lse is not None + return out, (lse,) + + return out + +@functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residuals=False): if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -262,14 +289,14 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - tiling = plgpu.TilingTransform((64, 64)) + tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) - transpose = plgpu.TransposeTransform((0, 2, 1, 3, 4)) - def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): batch = lax.axis_index("batch") wg_idx = lax.axis_index("wg") - qo_smem2, q_barriers, schedule_barrier = scoped + smem_buffers, q_barriers, schedule_barrier = scoped + qo_smem2, lse_smem2 = smem_buffers q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q q_head = lax.axis_index("heads") kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) @@ -281,6 +308,7 @@ def perform_schedule_barrier(): def _compute_thread(): qo_smem = qo_smem2.at[wg_idx] + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None m_i = plgpu.layout_cast( jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, ) @@ -299,18 +327,26 @@ def _compute_thread(): plgpu.barrier_wait(q_barriers.at[wg_idx]) pl.when(wg_idx == 1)(perform_schedule_barrier) final_carry = (yield (acc, m_i, l_i)) - del m_i # Unused pl.when(wg_idx == 0)(perform_schedule_barrier) - acc, _, l_i = final_carry + acc, m_i, l_i = final_carry acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] 
= m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) - def kv_pipeline(k_smem, v_smem, + def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry): acc, m_i, l_i = carry @@ -353,7 +389,7 @@ def compute_pv(acc_ref): plgpu.GPUBlockSpec( # k block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0), - transforms=[tiling, transpose, swizzle]), + transforms=[tiling, swizzle]), plgpu.GPUBlockSpec( # v block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0), @@ -366,11 +402,12 @@ def compute_pv(acc_ref): pipeline(k_ref, v_ref) mesh = plgpu.GPUMesh( grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), + thread_name="wg", ) def run(refs): - q_ref, k_ref, v_ref, out_ref = refs + q_ref, k_ref, v_ref, out_ref, lse_ref = refs @pl.core_map(mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True), ) @@ -379,22 +416,36 @@ def _kernel_entry(): (compute_wgs, block_q, head_dim), jnp.float16, transforms=(tiling, swizzle), ) + scratch = [qo_scratch, None] + if save_residuals: + scratch[1] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) pl.run_scoped( - lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, args), - qo_scratch, + lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), + scratch, plgpu.Barrier(1, num_barriers=compute_wgs), plgpu.Barrier(num_arrivals=compute_wgs), ) @jax.jit - def run_function(q, k, v, o): - _, _, _, out = pl.run_state(run)((q, k, v, o)) - return out - out = run_function(q, k, v, jnp.full_like(q, jnp.inf)) + def run_function(q, k, v, o, lse): + _, _, _, out, lse = pl.run_state(run)((q, k, v, o, lse)) + return out, lse + + lse = ( + jnp.full((batch_size, num_q_heads, q_seq_len), -jnp.inf, dtype=jnp.float32) + if save_residuals + else None + ) + out, lse = run_function(q, k, v, jnp.full_like(q, jnp.inf), lse) + + if save_residuals: + assert lse is not None + return out, (lse,) + return out -@jax.jit -def attention_reference(q, k, v): +@functools.partial(jax.jit, static_argnames=["save_residuals"]) +def attention_reference(q, k, v, save_residuals=False): batch_size, q_seq_len, num_q_heads, head_dim = q.shape num_kv_heads = k.shape[2] q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v)) @@ -406,8 +457,16 @@ def attention_reference(q, k, v): unnormalized = jnp.exp(logits - m) l = unnormalized.sum(axis=-1, keepdims=True) weights = unnormalized / l - return jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) - + out = jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) + + if save_residuals: + log2e = math.log2(math.e) + l = l.reshape(*q.shape[:-1]) + m = m.reshape(*q.shape[:-1]) + lse = m * log2e + jnp.log2(l) + return out, (lse.swapaxes(-1, -2),) + else: + return out def main(unused_argv): num_q_heads = 16 diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index d37afaf4d9e0..187d74ee1fd9 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -247,7 +247,7 @@ def layer_norm_backward( grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + 
compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index ff224c6dfde7..baeaeb8a57b3 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -228,7 +228,7 @@ def rms_norm_backward( grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, @@ -264,8 +264,8 @@ def rms_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict( - triton=dict(num_warps=num_warps, num_stages=num_stages) + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps, num_stages=num_stages ), grid=(), out_shape=out_shape, diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 0cb3d798d09e..ef8dd61abacb 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -391,9 +391,9 @@ def body(i, _): l_prev = l_scratch_ref[batch_idx] q = q_tile_ref[batch_idx] # [block_q, head_dim] start_k = i * block_k - k = pl.load( - k_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) # [block_k, head_dim] + k = k_tile_ref[ + (*batch_idx, pl.dslice(start_k, block_k), slice(None)) + ] # [block_k, head_dim] s = jax.lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 @@ -403,10 +403,9 @@ def body(i, _): # TODO(tanburn) Should the attention bias be added before or after # multiplication by sm_scale? if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, + ab = ab_tile_ref[ (*batch_idx, pl.dslice(None), pl.dslice(start_k, block_k)) - ).astype(jnp.float32) + ].astype(jnp.float32) s += ab if sm_scale != 1.0: @@ -422,10 +421,9 @@ def body(i, _): q_segment_ids = pltpu.repeat( q_segment_ids_tile_ref[batch_idx[0]], repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, - (batch_idx[0], pl.dslice(1), pl.dslice(start_k, block_k)), - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + batch_idx[0], :1, pl.dslice(start_k, block_k) + ] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -471,9 +469,7 @@ def body(i, _): l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe) - v = pl.load( - v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) + v = v_tile_ref[(*batch_idx, pl.dslice(start_k, block_k), slice(None))] o_curr = jax.lax.dot( p.astype(v.dtype), v, preferred_element_type=jnp.float32 ) @@ -529,15 +525,13 @@ def _flash_attention_kernel_single_batch_single_step( raise NotImplementedError( f"kv block size must be a multiple of {NUM_LANES}" ) - q_segment_ids = pl.load( - q_segment_ids_tile_ref, (batch_idx[0],) - ) # [block_q, NUM_LANES]. + q_segment_ids = q_segment_ids_tile_ref[ + batch_idx[0] + ] # [block_q, NUM_LANES]. q_segment_ids = pltpu.repeat( q_segment_ids, repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (batch_idx[0], pl.dslice(1)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[batch_idx[0], :1] # [1, block_k]. 
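      # The loads in this kernel now use plain indexing on the refs; for reads
      # this is equivalent to the previous pl.load spelling, e.g. the
      # kv_segment_ids load above is the indexed form of
      # pl.load(kv_segment_ids_tile_ref, (batch_idx[0], pl.dslice(1))).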
mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -840,33 +834,27 @@ def q_body(j, _): start_q = j * block_q def k_body(i, _): start_k = i * block_k - k = pl.load(k_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - v = pl.load(v_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - q = pl.load(q_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, head_dim] - l = pl.load(l_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - m = pl.load(m_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - do = pl.load(do_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - di = pl.load(di_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ).astype(jnp.float32) # [block_q, 128] + k = k_tile_ref[0, 0, pl.ds(start_k, block_k), :] + v = v_tile_ref[0, 0, pl.ds(start_k, block_k), :] + q = q_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, head_dim] + l = l_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + m = m_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + do = do_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + di = di_tile_ref[0, 0, pl.ds(start_q, block_q), :].astype( + jnp.float32 + ) # [block_q, 128] capped_logits = lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 ) # [block_q_major, block_k] if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, - ( - 0, - 0, - pl.dslice(j * block_q, block_q), - pl.dslice(i * block_k, block_k), - ), - ).astype(jnp.float32) + ab = ab_tile_ref[ + 0, + 0, + pl.dslice(j * block_q, block_q), + pl.dslice(i * block_k, block_k), + ].astype(jnp.float32) capped_logits += ab if sm_scale != 1.0: @@ -878,15 +866,15 @@ def k_body(i, _): if rem: raise NotImplementedError( ) - q_segment_ids = pl.load( - q_segment_ids_tile_ref, (0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, NUM_LANES]. + q_segment_ids = q_segment_ids_tile_ref[ + 0, pl.ds(start_q, block_q), : + ] # [block_q, NUM_LANES]. q_segment_ids = pltpu.repeat( q_segment_ids, repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, pl.ds(start_k, block_k)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + :, 0, pl.ds(start_k, block_k) + ] # [1, block_k]. 
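      # Segment ids tag which packed sequence each position belongs to; the
      # mask below keeps only (q, kv) pairs drawn from the same segment, so
      # attention does not leak across sequence boundaries in a packed batch.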
mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -913,9 +901,9 @@ def k_body(i, _): 1 / l, block_k // MIN_BLOCK_SIZE, axis=1 ) # [block_q_major, block_k_major] dv = lax.dot(p.T.astype(do.dtype), do, preferred_element_type=jnp.float32) - pl.store(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dv.astype(dv_scratch_ref.dtype)) + dv_scratch_ref[pl.ds(start_k, block_k), :] += dv.astype( + dv_scratch_ref.dtype + ) # di: [block_q, 128] # do: [block_q, head_dim] @@ -931,9 +919,9 @@ def k_body(i, _): # ds: [block_q_major, block_k_major] # q: [block_q_major, head_dim] dk = lax.dot(ds.T.astype(do.dtype), q, preferred_element_type=jnp.float32) - pl.store(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dk.astype(dk_scratch_ref.dtype)) + dk_scratch_ref[pl.ds(start_k, block_k), :] += dk.astype( + dk_scratch_ref.dtype + ) lax.fori_loop(0, block_k_major // block_k, k_body, None, unroll=True) if causal: @@ -1192,12 +1180,8 @@ def start_new_sequence(): def body(i, _): k_slice = pl.ds(i * block_k, block_k) q = q_tile_ref[0, 0, :, :] - k = pl.load( - k_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] - v = pl.load( - v_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] + k = k_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] + v = v_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] l = l_tile_ref[0, 0, :, :] # [block_q_major, 128] m = m_tile_ref[0, 0, :, :] # [block_q_major, 128] do = do_tile_ref[0, 0, :, :] # [block_q_major, head_dim] @@ -1208,9 +1192,9 @@ def body(i, _): ) if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)) - ).astype(jnp.float32) + ab = ab_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)].astype( + jnp.float32 + ) capped_logits += ab if sm_scale != 1.0: @@ -1226,9 +1210,7 @@ def body(i, _): q_segment_ids = pltpu.repeat( q_segment_ids_tile_ref[0], repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, k_slice) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[:, 0, k_slice] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -1269,10 +1251,8 @@ def body(i, _): ds = ds * sm_scale if ds_tile_ref is not None: - pl.store( - ds_tile_ref, - (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)), - ds.astype(ds_tile_ref.dtype), + ds_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)] = ds.astype( + ds_tile_ref.dtype ) # dp: [block_q_major, block_k] diff --git a/jax/experimental/pallas/ops/tpu/matmul.py b/jax/experimental/pallas/ops/tpu/matmul.py index 4ff82acbb5dd..06d868168f9e 100644 --- a/jax/experimental/pallas/ops/tpu/matmul.py +++ b/jax/experimental/pallas/ops/tpu/matmul.py @@ -14,7 +14,7 @@ """Example matmul TPU kernel. -See discussion in https://jax.readthedocs.io/en/latest/pallas/tpu/matmul.html. +See discussion in https://docs.jax.dev/en/latest/pallas/tpu/matmul.html. 
""" import functools diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index eb1e11df17da..4c03fb01be2b 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -114,7 +114,7 @@ def paged_flash_attention_kernel( lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -223,16 +223,12 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): @pl.when(i * bk < length) def flash_attention(): # pylint: disable=unused-variable - step = step_ref[0] + init_flag = init_flag_ref[0] + init_flag_ref[0] = 0 buffer_index = buffer_index_ref[0] + next_b, next_h, next_i = compute_block_indices(b, h, i + 1) - @pl.when(i == 0) - def init(): # pylint: disable=unused-variable - m_ref[...] = jnp.full_like(m_ref, -jnp.inf) - l_ref[...] = jnp.zeros_like(l_ref) - o_ref[...] = jnp.zeros_like(o_ref) - - @pl.when(step == 0) + @pl.when(init_flag) def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k, async_copy_v = create_kv_async_copy_descriptors( b, h, i, buffer_index @@ -240,7 +236,11 @@ def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k.start() async_copy_v.start() - next_b, next_h, next_i = compute_block_indices(b, h, i + 1) + @pl.when(i == 0) + def init(): # pylint: disable=unused-variable + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) @pl.when(next_b < batch_size) def prefetch_next_block(): # pylint: disable=unused-variable @@ -257,7 +257,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable ) q = q_ref[...].astype(jnp.float32) k = async_copy_k.wait_and_get_loaded() - qk = jnp.einsum('hd,td->ht', q, k, preferred_element_type=jnp.float32) + qk = jnp.einsum("gd,td->gt", q, k, preferred_element_type=jnp.float32) if attn_logits_soft_cap is not None: capped_qk = jnp.tanh(qk / attn_logits_soft_cap) qk = capped_qk * attn_logits_soft_cap @@ -274,24 +274,21 @@ def prefetch_next_block(): # pylint: disable=unused-variable alpha = jnp.exp(m_prev - m_next) beta = jnp.exp(m_curr - m_next) l_next = alpha * l_prev + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + m_ref[...], l_ref[...] = m_next, l_next v = async_copy_v.wait_and_get_loaded() - o_curr_times_l_curr = jnp.dot(s_curr, v) + o_curr = jnp.einsum("gt,td->gd", s_curr, v) - m_ref[...], l_ref[...] = m_next, l_next_safe o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + (l_prev * alpha * o_ref[...] + beta * o_curr) / l_next ).astype(o_ref.dtype) - step_ref[0] = step + 1 - def paged_flash_attention_kernel_inline_seq_dim( lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -326,7 +323,7 @@ def body(i, _): lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -387,7 +384,7 @@ def paged_attention( """Paged grouped query attention. Args: - q: A [batch_size, num_heads, head_dim] jax.Array. + q: A [batch_size, num_q_heads, head_dim] jax.Array. k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. 
lengths: A i32[batch_size] jax.Array the length of each example. @@ -412,7 +409,7 @@ def paged_attention( one kernel. Returns: - The output of attention([batch_size, num_heads, head_dim]). + The output of attention([batch_size, num_q_heads, head_dim]). """ if isinstance(k_pages, quantization_utils.QuantizedTensor): k_pages, k_scales_pages = k_pages.weight, k_pages.scales @@ -431,7 +428,7 @@ def paged_attention( else: v_scales_pages = None - batch_size, num_heads, head_dim = q.shape + batch_size, num_q_heads, head_dim = q.shape num_kv_heads, _, page_size, head_dim_k = k_pages.shape batch_size_paged_indices, pages_per_sequence = page_indices.shape @@ -440,10 +437,10 @@ def paged_attention( f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" f" {v_pages.shape}" # pytype: disable=attribute-error ) - if num_heads % num_kv_heads != 0: + if num_q_heads % num_kv_heads != 0: raise ValueError( "Number of Q heads must be divisible by number of KV heads. Got" - f" {num_heads} and {num_kv_heads}." + f" {num_q_heads} and {num_kv_heads}." ) if head_dim_k != head_dim: raise ValueError( @@ -480,40 +477,41 @@ def paged_attention( else: raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]") - if (num_heads // num_kv_heads) % 8 != 0: + num_groups = num_q_heads // num_kv_heads + if (num_groups) % 8 != 0: # Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a # <8x128> layout for a <1x128> memref inside the kernel and error out. - q = q.reshape(batch_size, num_heads, 1, head_dim) + q = q.reshape(batch_size, num_q_heads, 1, head_dim) if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h, 0, 0), ) q_dtype_for_kernel_launch = jnp.float32 else: if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h, 0), ) q_dtype_for_kernel_launch = q.dtype @@ -632,7 +630,7 @@ def paged_attention( ), grid_spec=pltpu.PrefetchScalarGridSpec( # There are 4 scalars prefetched per kernel call: `lengths_ref`, - # `page_indices_ref`, `buffer_index_ref`, `step_ref` + # `page_indices_ref`, `buffer_index_ref`, `init_flag_ref` num_scalar_prefetch=4, in_specs=in_specs, out_specs=[ @@ -644,7 +642,8 @@ def paged_attention( scratch_shapes=scratch_shapes, ), compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=dimension_semantics), + dimension_semantics=dimension_semantics + ), out_shape=[ jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), @@ -654,11 +653,11 @@ def paged_attention( 
lengths, page_indices.reshape(-1), jnp.zeros((1,), jnp.int32), # buffer index - jnp.zeros((1,), jnp.int32), # step + jnp.ones((1,), jnp.int32), # init flag q.astype(q_dtype_for_kernel_launch), k_pages, k_scales_pages, v_pages, v_scales_pages, ) - return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype) + return out.reshape(batch_size, num_q_heads, head_dim).astype(q.dtype) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/util.py b/jax/experimental/pallas/ops/tpu/paged_attention/util.py new file mode 100644 index 000000000000..6d6ceca3733f --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/paged_attention/util.py @@ -0,0 +1,82 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX reference implementation of grouped query attention.""" + +import jax +from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +import jax.numpy as jnp + +MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) + + +def grouped_query_attention_reference( + queries: jax.Array, # [batch_size, num_q_heads, head_dim] + k_pages: jax.Array, # [batch_size, num_kv_heads, max_seq_len, head_dim] + v_pages: jax.Array, # [batch_size, num_kv_heads, max_seq_len, head_dim] + seq_lens: jax.Array, # i32[batch_size] + soft_cap: float | None = None, + debug: bool = False, +) -> jax.Array: # [batch_size, num_q_heads, head_dim] + """Grouped query attention with a single query per request.""" + # Check input shapes + assert k_pages.shape == v_pages.shape + batch_size, num_q_heads, head_dim = queries.shape + batch_size2, num_kv_heads, max_seq_len, head_dim2 = k_pages.shape + assert batch_size2 == batch_size + assert head_dim2 == head_dim + + # Unquantize kv pages if necessary + if isinstance(k_pages, quantization_utils.QuantizedTensor): + k_pages = quantization_utils.unquantize_from_int8( + k_pages, dtype=jnp.float32 + ) + if isinstance(v_pages, quantization_utils.QuantizedTensor): + v_pages = quantization_utils.unquantize_from_int8( + v_pages, dtype=jnp.float32 + ) + + # Reshape for num_groups queries per k head + assert num_q_heads % num_kv_heads == 0 + num_groups = num_q_heads // num_kv_heads + queries = queries.reshape(batch_size, num_kv_heads, num_groups, head_dim) + + # Compute the dot product q*k and apply soft cap if necessary + qk = jnp.einsum( + "bhgd,bhtd->bhgt", + queries.astype(jnp.float32), + k_pages.astype(jnp.float32), + ) + if soft_cap is not None and soft_cap != 0.0: + qk = jnp.tanh(qk / soft_cap) * soft_cap + assert qk.shape == (batch_size, num_kv_heads, num_groups, max_seq_len) + if debug: + jax.debug.print("qk: {qk}", qk=qk) + + # Enfore causal mask (adding dimensions when necessary) + mask = jnp.arange(max_seq_len)[None] < seq_lens[:, None] + qk += jnp.where(mask, 0.0, MASK_VALUE)[:, None, None, :] + if debug: + jax.debug.print("masked: {qk}", qk=qk) + + # Generate probability distribution using softmax + probs = jax.nn.softmax(qk, axis=-1).astype(v_pages.dtype) + assert probs.shape == (batch_size, num_kv_heads, num_groups, 
max_seq_len) + if debug: + jax.debug.print("softmax: {probs}", probs=probs) + + # Attention is probability-weighted sum of v heads + attention = jnp.einsum("bhgt,bhtd->bhgd", probs, v_pages) + assert attention.shape == (batch_size, num_kv_heads, num_groups, head_dim) + return attention.reshape(batch_size, num_q_heads, head_dim) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 6600d765024c..203dc8a7602a 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -19,7 +19,6 @@ specifications. It supports mixed prefill and decoding, enhancing throughput during inference. """ - import functools import jax from jax import lax @@ -35,8 +34,8 @@ class MultiPageAsyncCopyDescriptor: def __init__( self, - pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim] - vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim] + vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sem, page_indices_ref, # i32[max_num_seqs, pages_per_seq] offset, # [seq_idx, kv_pages_start] @@ -73,17 +72,34 @@ def wait(self): def ref_ragged_paged_attention( queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs: jax.Array, # i32[1], *, sm_scale: float = 1.0, - mask_value: float = DEFAULT_MASK_VALUE, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, ): - _, _, num_kv_heads, head_dim = k_pages.shape + static_validate_inputs( + queries, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + ) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + _, _, num_combined_kv_heads, head_dim = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 num_q_heads = queries.shape[1] assert num_q_heads % num_kv_heads == 0 num_query_per_kv = num_q_heads // num_kv_heads @@ -95,8 +111,12 @@ def ref_ragged_paged_attention( kv_len = kv_lens[i] indices = page_indices[i] q = queries[q_start:q_end] - k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] - v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] + v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] k = jnp.repeat(k, num_query_per_kv, axis=1) v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) @@ -105,7 +125,12 @@ def ref_ragged_paged_attention( jnp.int32, attn.shape, 1 ) kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) - attn += jnp.where(q_span < kv_span, mask_value, 0.0) + mask = q_span < kv_span + if sliding_window is not None: + mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span) 
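      # With a sliding window, query position q attends only to keys in
      # (q - sliding_window, q]: keys with q_span - kv_span >= sliding_window
      # are masked in addition to the causal mask. E.g. with sliding_window=2,
      # query 5 keeps keys 4 and 5; key 3 is dropped since 5 - 2 >= 3.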
+ if soft_cap is not None: + attn = soft_cap * jnp.tanh(attn / soft_cap) + attn += jnp.where(mask, mask_value, 0.0) attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) outputs.append(out) @@ -113,26 +138,47 @@ def ref_ragged_paged_attention( return jnp.concatenate(outputs, axis=0) -# Expect to run these checkes during runtime. -def validate_inputs_on_runtime( +# Expect to run these checks during runtime. +def dynamic_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + # Kernel specific params. + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, ): - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, ) max_num_batched_tokens = q.shape[0] - page_size = k_pages.shape[1] + page_size = kv_pages.shape[1] max_num_seqs, pages_per_seq = page_indices.shape if num_seqs[0] > max_num_seqs: raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") max_kv_len = jnp.max(kv_lens) - min_pages_per_seq = ceil_div(max_kv_len, page_size) + min_pages_per_seq = cdiv(max_kv_len, page_size) if pages_per_seq < min_pages_per_seq: raise ValueError( f"{pages_per_seq=} must be greater or equal to" @@ -153,24 +199,31 @@ def validate_inputs_on_runtime( # Expect to run these checks during compile time. -def check_inputs_shapes( +def static_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + # Kernel specific params. 
+ num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, ): _, num_q_heads, head_dim = q.shape - _, _, num_kv_heads, head_dim_k = k_pages.shape - max_num_seqs, _ = page_indices.shape + _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 + max_num_seqs, pages_per_seq = page_indices.shape if num_seqs.shape != (1,): raise ValueError(f"{num_seqs.shape=} must be (1,)") - if k_pages.shape != v_pages.shape: - raise ValueError( - f"{k_pages.shape=} and {v_pages.shape=} must have the same shape." - ) if head_dim_k != head_dim: raise ValueError( f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}." @@ -197,6 +250,23 @@ def check_inputs_shapes( ) if num_q_heads % num_kv_heads != 0: raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") + if sliding_window is not None and sliding_window <= 0: + raise ValueError(f"{sliding_window=} must be positive.") + if soft_cap is not None and soft_cap == 0.0: + raise ValueError(f"{soft_cap=} must not be 0.0.") + if ( + num_kv_pages_per_block is not None + and not 0 < num_kv_pages_per_block <= pages_per_seq + ): + raise ValueError( + f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]." + ) + if num_queries_per_block is not None and num_queries_per_block <= 0: + raise ValueError(f"{num_queries_per_block=} must be positive.") + if vmem_limit_bytes is not None and vmem_limit_bytes <= 0: + raise ValueError(f"{vmem_limit_bytes=} must be positive.") + del sm_scale # No constraints on sm_scale. + del mask_value # No consstraints on mask_value. def ragged_paged_attention_kernel( @@ -209,23 +279,29 @@ def ragged_paged_attention_kernel( num_seqs_ref, # Input q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] - k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] # Output o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] # Scratch - k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] - v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + kv_bufs, # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sems, # [2, 2] l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + acc_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] *, sm_scale: float, - mask_value: float, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, ): + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape num_seqs = num_seqs_ref[0] - _, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape + _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( + kv_bufs.shape + ) + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 num_kv_per_blk = num_kv_pages_per_blk * page_size num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk heads_blk_idx, q_blk_idx = ( @@ -242,41 +318,36 @@ def create_kv_async_copy_descriptors( heads_blk_idx, seq_idx, kv_blk_idx, buf_idx ): offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) - heads_start = heads_blk_idx * 
num_kv_heads_per_blk - async_copy_k = MultiPageAsyncCopyDescriptor( - k_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - k_bufs.at[buf_idx], - sems.at[buf_idx, 0], - page_indices_ref, - offset, - ) - async_copy_v = MultiPageAsyncCopyDescriptor( - v_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - v_bufs.at[buf_idx], - sems.at[buf_idx, 1], + heads_start = heads_blk_idx * num_combined_kv_heads_per_blk + async_copy_kv = MultiPageAsyncCopyDescriptor( + kv_pages_hbm_ref.at[ + :, :, pl.ds(heads_start, num_combined_kv_heads_per_blk), : + ], + kv_bufs.at[buf_idx], + sems.at[buf_idx], page_indices_ref, offset, ) - return async_copy_k, async_copy_v + return async_copy_kv # TODO(jevinjiang): Add these to Mosaic: # 1. Support arbitrary strided load/store for any dtype. # 2. Support arbitrary strided load/store for any last dimension. def strided_load_kv(ref, start, step): if ref.dtype == jnp.float32: - return ref[start::step, :] + return ref[start::step, :], ref[start + 1 :: step, :] packing = get_dtype_packing(ref.dtype) assert ref.dtype == jnp.bfloat16 assert step % packing == 0 b_start = start // packing - b_offset = start % packing b_step = step // packing - b_ref = ref.bitcast(jnp.int32) + b_ref = ref.bitcast(jnp.uint32) b = b_ref[b_start::b_step, :] - bw = 32 // packing - b = jnp.right_shift(b, bw * b_offset) - b = jnp.left_shift(b, bw * (packing - 1)) - return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16) + bk = b << 16 + bv = b & jnp.uint32(0xffff0000) + k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16) + v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16) + return k, v def fold_on_2nd_minor(vec): assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32 @@ -289,15 +360,16 @@ def fold_on_2nd_minor(vec): @pl.when(heads_blk_idx + q_blk_idx == 0) def prefetch_first_kv_blk(): - async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, init_seq_idx, 0, init_buf_idx ) - async_copy_k.start() - async_copy_v.start() + async_copy_kv.start() def is_cur_q_blk_needed(q_states): done, cur_seq_idx, _ = q_states - return jnp.logical_and(done == 0, cur_seq_idx < num_seqs) + should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs], + cur_seq_idx < num_seqs) + return jnp.logical_and(done == 0, should_run) def compute_with_cur_q_blk(q_states): done, cur_seq_idx, cur_buf_idx = q_states @@ -342,7 +414,7 @@ def flash_attention( v, # [num_kv_per_blk, head_dim] head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] - head_o_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] + head_acc_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] *, kv_blk_idx, ): @@ -363,7 +435,7 @@ def flash_attention( num_q_per_blk * num_q_heads_per_kv_head, 128, ) - assert head_o_ref.shape == ( + assert head_acc_ref.shape == ( num_q_per_blk, num_q_heads_per_kv_head, head_dim, @@ -373,7 +445,7 @@ def flash_attention( def masked_store(ref, val, start, end, group=1): iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group mask = jnp.logical_and(iota >= start, iota < end) - pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask) + pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) qk = ( jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) @@ -399,8 +471,8 @@ def init_scratch_ref(): num_q_heads_per_kv_head, ) masked_store( - head_o_ref, - jnp.zeros_like(head_o_ref), + 
head_acc_ref, + jnp.zeros_like(head_acc_ref), store_start, store_end, ) @@ -422,6 +494,11 @@ def init_scratch_ref(): 1, ) causal_mask = row_ids < col_ids + if sliding_window is not None: + causal_mask = jnp.logical_or(causal_mask, + row_ids - sliding_window >= col_ids) + if soft_cap is not None: + qk = soft_cap * jnp.tanh(qk / soft_cap) qk += jnp.where(causal_mask, mask_value, 0.0) m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) @@ -461,17 +538,17 @@ def broadcast_to_shape(arr, shape): [arr for _ in range(shape[1] // arr.shape[1])], axis=1 ) - o_curr = head_o_ref[...].reshape(-1, head_dim) + o_curr = head_acc_ref[...].reshape(-1, head_dim) l_alpha = broadcast_to_shape(l_alpha, qkv.shape) beta = broadcast_to_shape(beta, qkv.shape) l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) out = lax.div( l_alpha * o_curr + beta * qkv, l_next_safe, - ).astype(head_o_ref.dtype) + ) masked_store( - head_o_ref, - out.reshape(head_o_ref.shape), + head_acc_ref, + out.reshape(head_acc_ref.shape), store_start, store_end, ) @@ -493,21 +570,18 @@ def prefetch_next_kv_blk(): # TODO(jevinjiang): reuse the same buffer if it is already prefetched! # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and # DMA to fixed size buffer! - next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors( + next_async_copy_kv = create_kv_async_copy_descriptors( next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx ) - next_async_copy_k.start() - next_async_copy_v.start() + next_async_copy_kv.start() - cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors( + cur_async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx ) - kv_to_load_shape = ( - num_kv_pages_per_blk * page_size * num_kv_heads_per_blk, + kv_ref = cur_async_copy_kv.wait().reshape( + num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk, head_dim, ) - k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape) - v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape) for kv_head_idx in range(num_kv_heads_per_blk): q_head_idx = kv_head_idx * num_q_heads_per_kv_head # TODO(jevinjiang): extra handlig for packed type that can start at @@ -515,15 +589,16 @@ def prefetch_next_kv_blk(): q = fold_on_2nd_minor( q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] ) - k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) - v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) + k, v = strided_load_kv( + kv_ref, kv_head_idx * 2, num_combined_kv_heads_per_blk + ) flash_attention( q, k, v, l_ref.at[kv_head_idx], m_ref.at[kv_head_idx], - o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], + acc_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], kv_blk_idx=kv_blk_idx, ) return kv_blk_idx + 1, next_buf_idx @@ -545,9 +620,10 @@ def prefetch_next_kv_blk(): # Reset seq_idx for next kv_heads_blk if run out of seqs! seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) seq_buf_idx_ref[1] = buf_idx + o_ref[...] 
= acc_ref[...].astype(q_ref.dtype) -def ceil_div(a, b): +def cdiv(a, b): assert b != 0 return (a + b - 1) // b @@ -564,7 +640,9 @@ def get_dtype_packing(dtype): raise ValueError(f"Not implemented: unsupported {dtype=}") -def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype): +def get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q_dtype, kv_dtype +): q_packing = get_dtype_packing(q_dtype) kv_packing = get_dtype_packing(kv_dtype) @@ -575,22 +653,26 @@ def can_be_xla_fully_tiled(x, packing): return x in (1, 2, 4, 8) or x % 8 == 0 # TODO(jevinjiang): support unaligned number of heads! - if not can_be_xla_fully_tiled(num_kv_heads, kv_packing): + if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing): raise ValueError( - f"Not implemented: {num_kv_heads=} can not be XLA fully tiled." + f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled." ) + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 assert num_q_heads % num_kv_heads == 0 ratio = num_q_heads // num_kv_heads # TODO(jevinjiang): we can choose smaller tiling for packed type if large # second minor tiling is not on. - max_kv_tiling = 8 * kv_packing - min_kv_heads = ( - max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads + max_combined_kv_tiling = 8 * kv_packing + min_combined_kv_heads = ( + max_combined_kv_tiling + if num_combined_kv_heads % max_combined_kv_tiling == 0 + else num_combined_kv_heads ) - min_q_heads = min_kv_heads * ratio + min_q_heads = min_combined_kv_heads // 2 * ratio if can_be_xla_fully_tiled(min_q_heads, q_packing): - return min_q_heads, min_kv_heads - return num_q_heads, num_kv_heads + return min_q_heads, min_combined_kv_heads + return num_q_heads, num_combined_kv_heads @functools.partial( @@ -601,20 +683,23 @@ def can_be_xla_fully_tiled(x, packing): "num_kv_pages_per_block", "num_queries_per_block", "vmem_limit_bytes", + "sliding_window", + "soft_cap", ], ) def ragged_paged_attention( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] # TODO(jevinjiang): create a write_to_kv_cache kernel! - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs: jax.Array, # i32[1] *, sm_scale: float = 1.0, - mask_value: float = DEFAULT_MASK_VALUE, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, num_kv_pages_per_block: int = 16, num_queries_per_block: int = 128, vmem_limit_bytes: int | None = None, @@ -623,8 +708,7 @@ def ragged_paged_attention( Args: q: concatenated all sequences' queries. - k_pages: paged K cache. Normally in HBM. - v_pages: paged V cache. Normally in HBM. + kv_pages: paged K cache. Normally in HBM. kv_lens: padded kv lengths. Only the first num_seqs values are valid. page_indices: the first index indicates which page to use in the kv cache for each sequence. Only the first num_seqs values are valid. @@ -632,6 +716,8 @@ def ragged_paged_attention( kv_lens, only the first num_seqs+1 values are valid. num_seqs: the dynamic number of sequences. sm_scale: the softmax scale which will be applied to the Q@K^T. + sliding_window: the sliding window size for the attention. 
+ soft_cap: the logit soft cap for the attention. mask_value: mask value for causal mask. num_kv_pages_per_block: number of kv pages to be processed in one flash attention block in the pallas kernel. @@ -642,18 +728,36 @@ def ragged_paged_attention( Returns: The output of the attention. """ - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, ) - _, num_q_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_pages.shape + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + num_q_tokens, num_q_heads, head_dim = q.shape + _, page_size, num_combined_kv_heads, _ = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = ceil_div(cu_q_lens[num_seqs[0]], num_q_per_blk) - num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( - num_q_heads, num_kv_heads, q.dtype, k_pages.dtype + num_q_blks = cdiv(num_q_tokens, num_q_per_blk) + num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype ) + assert num_combined_kv_heads_per_blk % 2 == 0 + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 num_heads_blks = num_q_heads // num_q_heads_per_blk grid = (num_heads_blks, num_q_blks) @@ -668,7 +772,6 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): in_specs = [ q_block_spec, pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), ] out_specs = q_block_spec lm_scratch = pltpu.VMEM( @@ -677,22 +780,26 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128), jnp.float32, ) + acc_scratch = pltpu.VMEM( + (num_q_per_blk, num_q_heads_per_blk, head_dim), + jnp.float32, + ) double_buf_scratch = pltpu.VMEM( ( 2, # For double buffering during DMA copies. num_kv_pages_per_blk, page_size, - num_kv_heads_per_blk, + num_combined_kv_heads_per_blk, head_dim, ), - k_pages.dtype, + kv_pages.dtype, ) scratch_shapes = [ - double_buf_scratch, # k_bufs - double_buf_scratch, # v_bufs - pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem] + double_buf_scratch, # kv_bufs + pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers. lm_scratch, # l_ref lm_scratch, # m_ref + acc_scratch, ] scalar_prefetches = ( kv_lens, @@ -705,6 +812,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): functools.partial( ragged_paged_attention_kernel, sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, mask_value=mask_value, ), grid_spec=pltpu.PrefetchScalarGridSpec( @@ -721,9 +830,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): ), vmem_limit_bytes=vmem_limit_bytes, ), - out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32), + out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), name="ragged_paged_attention_kernel", ) - # TODO(jevinjiang): Use f32 acc scratch for output! So we only need - # to transfer output with desired dtype back to HBM. 
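# [Editor's note, not part of the patch] A minimal jax.numpy sketch of what the
# new sliding_window and soft_cap options do inside the kernel. The helper
# names below (apply_soft_cap, sliding_window_causal_mask) are illustrative
# only: keys further than `sliding_window` positions behind a query are masked
# in addition to the causal mask, and logits are squashed into
# (-soft_cap, soft_cap) via tanh before masking.
#
#   import jax.numpy as jnp
#
#   def apply_soft_cap(qk, soft_cap):
#     # Roughly linear for |qk| much smaller than soft_cap, bounded otherwise.
#     return soft_cap * jnp.tanh(qk / soft_cap)
#
#   def sliding_window_causal_mask(row_ids, col_ids, sliding_window):
#     # True means "masked out": future tokens, or tokens at distance
#     # >= sliding_window behind the query.
#     return (row_ids < col_ids) | (row_ids - sliding_window >= col_ids)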
- return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype) + + return kernel(*scalar_prefetches, q, kv_pages) diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 4b6e4a41c43b..b69b0e36f177 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -599,9 +599,9 @@ def _apply_mask_and_soft_cap( masks = [] if mask_ref is not None: if k_in_lanes: - mask = pl.load(mask_ref, (slice(None), k_slice)) + mask = mask_ref[:, k_slice] else: - mask = pl.load(mask_ref, (k_slice, slice(None))) + mask = mask_ref[k_slice, :] masks.append( jnp.bitwise_or(mask, jnp.broadcast_to(should_not_mask, mask.shape)) @@ -630,7 +630,7 @@ def _apply_mask_and_soft_cap( k_sequence = k_offset + jax.lax.broadcasted_iota( jnp.int32, (k_slice.size, bq), 0 ) - q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq] + q_sequence = q_sequence_ref[:1, :] # [1, bq] q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) assert q_sequence.shape == k_sequence.shape @@ -644,7 +644,7 @@ def _apply_mask_and_soft_cap( if q_segment_ids_ref is not None: if k_in_lanes: - kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice] + kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice] repeats, rem = divmod(kv_ids.shape[1], NUM_LANES) if rem: raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}") @@ -655,9 +655,9 @@ def _apply_mask_and_soft_cap( if rem: raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}") kv_ids = pltpu.repeat( - pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1 + kv_segment_ids_ref[k_slice, :], repeats, axis=1 ) # [k_slice, bq] - q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq] + q_ids = q_segment_ids_ref[:1, :] # [1, bq] masks.append(q_ids == kv_ids) def cap_logits(logits): @@ -743,9 +743,9 @@ def body(kv_compute_index, _): q = q_ref[...] if q_layout == HEAD_DIM_MINOR else q_ref[...].T qk_dims = NT_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR: - k = pl.load(k_ref, (slice_k, slice(None))) + k = k_ref[slice_k, :] else: - k = pl.load(k_ref, (slice(None), slice_k)) + k = k_ref[:, slice_k] qk = lax.dot_general(q, k, qk_dims, preferred_element_type=float32) assert qk.shape == (bq, bkv_compute) @@ -794,9 +794,9 @@ def body(kv_compute_index, _): sv_dims = NN_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR: - v = pl.load(v_ref, (slice_k, slice(None))) + v = v_ref[slice_k, :] else: - v = pl.load(v_ref, (slice(None), slice_k)) + v = v_ref[:, slice_k] v = v.astype(float32) o_curr = lax.dot_general(s_curr, v, sv_dims) @@ -1688,13 +1688,13 @@ def body(i, _): q = q_ref[...] # We keep q potentially transposed, since it's always RHS def _load_kv(ref, layout): if layout == HEAD_DIM_MINOR: - return pl.load(ref, (slice_k, slice(None))) - return pl.load(ref, (slice(None), slice_k)).T + return ref[slice_k, :] + return ref[:, slice_k].T k = _load_kv(k_ref, k_layout) v = _load_kv(v_ref, v_layout) - logsumexp = pl.load(logsumexp_ref, (pl.ds(1), slice(None))) + logsumexp = logsumexp_ref[:1, :] do = do_ref[...] 
- di = pl.load(di_ref, (pl.ds(1), slice(None))) + di = di_ref[:1, :] qk_dims = NT_DIM_NUMBERS if q_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS qk_uncapped = lax.dot_general( @@ -1718,10 +1718,8 @@ def _load_kv(ref, layout): ) p = jnp.exp(qk - logsumexp) dv = lax.dot(p.astype(do.dtype), do, preferred_element_type=jnp.float32) - dv = dv.astype(dv_scratch_ref.dtype) + pl.load( - dv_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dv_scratch_ref, (slice_k, slice(None)), dv) + dv = dv.astype(dv_scratch_ref.dtype) + dv_scratch_ref[slice_k, :] + dv_scratch_ref[slice_k, :] = dv dp = lax.dot_general( v, do, NT_DIM_NUMBERS, @@ -1737,10 +1735,8 @@ def _load_kv(ref, layout): dk = lax.dot_general( ds.astype(do.dtype), q, dk_dims, preferred_element_type=jnp.float32 ) - dk = dk.astype(dk_scratch_ref.dtype) + pl.load( - dk_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dk_scratch_ref, (slice_k, slice(None)), dk) + dk = dk.astype(dk_scratch_ref.dtype) + dk_scratch_ref[slice_k, :] + dk_scratch_ref[slice_k, :] = dk if dq_scratch_ref is not None or dq_ref is not None: dq = lax.dot_general( ds.T.astype(k.dtype), k, NN_DIM_NUMBERS, @@ -2293,6 +2289,26 @@ def _splash_attention( mask_function: MaskFunctionType | None, interpret: bool, ) -> SplashCustomReturnType: + """ + For dynamic masks, `partial_mask_blocks` has shape (head_count, q_blocks, kv_blocks, block_q, block_kv). + This shape allows sharding across both head count and query sequence dimensions. + + Note: The leading dimensions (head_count, q_blocks, kv_blocks) must be + collapsed into a single dimension before being passed to the kernel. + """ + def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None): + if mask_info is None or mask_info.partial_mask_blocks is None: + return mask_info + + return mask_info._replace( + partial_mask_blocks=mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ) + ) + + fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info) + dq_mask_info = _collapse_partial_mask_blocks(dq_mask_info) + dkv_mask_info = _collapse_partial_mask_blocks(dkv_mask_info) return _splash_attention_custom( fwd_mask_info, dq_mask_info, @@ -2352,13 +2368,16 @@ def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding): spec = sharding.spec assert len(spec) == 2 replicated = jax.sharding.PartitionSpec() + partial_mask_blocks_spec = ( + spec if self.fwd_mask_info.is_dynamic_mask else replicated + ) # Shard q_sequence over the sequence dimension only. 
q_sequence_spec = jax.sharding.PartitionSpec(spec[1]) mask_info_specs = mask_info_lib.MaskInfo( # pytype: disable=wrong-arg-types data_next=spec if self.fwd_mask_info.data_next is not None else None, mask_next=spec if self.fwd_mask_info.mask_next is not None else None, block_mask=spec if self.fwd_mask_info.block_mask is not None else None, - partial_mask_blocks=replicated + partial_mask_blocks=partial_mask_blocks_spec if self.fwd_mask_info.partial_mask_blocks is not None else None, q_sequence=q_sequence_spec diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py index eab2a695dc02..e43f30e7791c 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py @@ -92,6 +92,35 @@ def make_local_attention_mask( return mask.astype(np.bool_) +def make_chunk_attention_mask( + shape: tuple[int, int], chunk_size: int +) -> np.ndarray: + """Makes a chunked causal attention mask. + + Args: + shape: The desired shape of the mask (q_seq_len, kv_seq_len). + chunk_size: The size of the attention chunks. + + Returns: + A boolean mask of shape `shape` where True indicates attention is + allowed according to chunked causal rules, and False otherwise. + + Raises: + ValueError: If chunk_size is not positive. + """ + if chunk_size <= 0: + raise ValueError('chunk_size must be positive') + + q_seq_len, kv_seq_len = shape + q_idx = np.arange(q_seq_len, dtype=np.int32) + kv_idx = np.arange(kv_seq_len, dtype=np.int32) + + # Chunk mask calculation: same chunk and causal. + same_chunk = (q_idx[:, None] // chunk_size) == (kv_idx[None, :] // chunk_size) + mask = same_chunk & (q_idx[:, None] >= kv_idx[None, :]) + return mask + + def make_random_mask( shape: tuple[int, int], sparsity: float, seed: int ) -> np.ndarray: @@ -196,15 +225,20 @@ def __hash__(self): class _ComputableMask(Mask): """Superclass for all masks that can be computed inside the kernel using a callable object. + This class is designed to be used with Splash Attention. + It allows the mask logic to be computed on-the-fly or fused into the attention + kernel, avoiding the memory cost of materializing the full + (sequence_length, sequence_length) boolean mask array, which can be excessive + for long sequences. + Attributes: _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len). offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax. - q_sequence: Indices of Q sequence. - q_sequence is reused across __getitem__ calls which is important for - compile-time performance. + q_sequence: Indices of Q sequence. q_sequence is reused across __getitem__ + calls which is important for compile-time performance. mask_function: Function used by the SplashAttention kernel to compute the mask rather than loading it. """ @@ -314,6 +348,66 @@ def __hash__(self): )) +class ChunkedCausalMask(_ComputableMask): + """Lazy chunked causal mask. + + Attention is causal within each chunk: tokens in (0, K), (K, 2K), (2K, 3K), ... + attend to each other but not across chunks. + Llama 4 models use interleaved chunked attention along with global attention. + + Attributes: + chunk_size: The size of each attention chunk. 
+ """ + + chunk_size: int + + def __init__( + self, + shape: tuple[int, int], + chunk_size: int, + shard_count: int = 1, + ): + if chunk_size <= 0: + raise ValueError('chunk_size must be positive') + self.chunk_size = chunk_size + + # Define the mask function for chunk attention + def chunked_causal_mask_function(q_ids, kv_ids): + """Computes the mask logic for the given slice indices.""" + # Condition 1: Same chunk + same_chunk = (q_ids // self.chunk_size) == (kv_ids // self.chunk_size) + + # Condition 2: Causal + causal = q_ids >= kv_ids + + return same_chunk & causal + + super().__init__( + shape=shape, + mask_function=chunked_causal_mask_function, + shard_count=shard_count, + ) + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return NotImplemented + + return ( + self.shape == other.shape + and self.chunk_size == other.chunk_size + and np.array_equal(self.q_sequence, other.q_sequence) + ) + + def __hash__(self): + return hash(( + type(self), + self.shape, + self.chunk_size, + self.q_sequence.tobytes() if self.q_sequence is not None else None, + )) + + class LocalMask(Mask): """Lazy local mask, prevents model from attending to tokens outside window. diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index 65081e79c0cf..9c79fbbf7e09 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -67,6 +67,10 @@ class MaskInfo(NamedTuple): q_sequence: A i32[q_sequence_length] NumPy array. When using causal masking, this contains the list of indices that correspond to q tokens. For plain causal this is just np.arange(q_sequence_length). + is_dynamic_mask: A bool indicating whether the mask is dynamic or static. + When True, the leading dimensions of `partial_mask_blocks` (num_heads, + q_blocks, kv_blocks) are not collapsed, allowing us to shard it along + those dimensions. """ data_next: np.ndarray | jax.Array | None @@ -74,6 +78,7 @@ class MaskInfo(NamedTuple): block_mask: np.ndarray | jax.Array | None partial_mask_blocks: np.ndarray | jax.Array | None q_sequence: np.ndarray | None + is_dynamic_mask: bool = None def _downcast_to_small_type(array: np.ndarray) -> np.ndarray: @@ -168,7 +173,7 @@ def __eq__(self, other: object) -> bool: def _get_mask_info_for_shard( output_shape: tuple[int, int, int], has_mask_next: bool, - mask: mask_lib.MultiHeadMask, + mask: mask_lib.MultiHeadMask | jax.Array, block_shape: tuple[int, int], coords_to_partial_mask_block_index: dict[tuple[int, int, int], int], masks_per_head_shard: int, @@ -338,7 +343,8 @@ def _process_dynamic_mask( launched. q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is launched. - shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored. + shrink_grid: Whether or not we should apply the grid shrinking optimization. + This is currently ignored. Returns: `MaskInfo`, a sparse representation of the dense mask. @@ -349,11 +355,6 @@ def _process_dynamic_mask( """ del shrink_grid - - # TODO(pobudzey): Properly support sharding. 
- if head_shards != 1 or q_seq_shards != 1: - raise ValueError('Dynamic mask processing does not support sharding.') - if len(mask.shape) != 3: raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.') @@ -370,6 +371,18 @@ def _process_dynamic_mask( if kv_mod != 0: raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.') + q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards) + if mod != 0: + raise ValueError(f'{q_seq_shards=} should divide {q_seq_len=}.') + + q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size) + if mod != 0: + raise ValueError(f'{q_block_size=} should divide {q_seq_len_per_shard=}.') + + heads_per_shard, mod = divmod(head_count, head_shards) + if mod != 0: + raise ValueError(f'{head_shards=} should divide {head_count=}.') + block_mask_shape = ( head_count, q_blocks_count, @@ -398,26 +411,66 @@ def _process_dynamic_mask( block_mask = jnp.where(is_full_mask, 2, block_mask) block_mask = jnp.where(is_empty_mask, 0, block_mask) - # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline. - mask_next = jnp.where( - jnp.logical_or(is_empty_mask, is_full_mask), - 0, - jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape( - block_mask_shape - ), - ) + q_sequence_axis = 1 + head_axis = 0 - # data_next stores the index of the next non-empty data block in the sequence. - # The indices of empty blocks are set to 0 to avoid copying extra data when - # pipeling. - if is_dkv: - data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None] - else: - data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :] - data_next = jnp.broadcast_to(data_next, block_mask_shape) - data_next = jnp.where(is_empty_mask, 0, data_next) + # Each iteration of the loop processes a slice of the mask info + # tensors of this shape: + mask_info_slice_shape = (heads_per_shard, q_blocks_per_shard, kv_blocks_count) + + # Collect mask_info shards along the head dimension, concatenate (or + # broadcast) them after the loop. + data_next_per_head_list, mask_next_per_head_list = [], [] + for head_shard in range(head_shards): + head_start = head_shard * heads_per_shard + mask_head_slice = slice(head_start, head_start + heads_per_shard) + + # Collect mask_info shards along the q_sequence dimension, concatenate them + # after the loop. + data_next_sequence_slices, mask_next_sequence_slices = [], [] + for q_seq_len_shard in range(q_seq_shards): + q_seq_len_start = q_seq_len_shard * q_blocks_per_shard + blocked_q_seq_len_slice = slice( + q_seq_len_start, q_seq_len_start + q_blocks_per_shard + ) + local_block_mask = block_mask[mask_head_slice, blocked_q_seq_len_slice] + + mask_next_slice = jnp.arange( + math.prod(mask_info_slice_shape), dtype=np.int32 + ).reshape(mask_info_slice_shape) + mask_next_slice = jnp.where(local_block_mask == 1, mask_next_slice, 0) + + # data_next stores the index of the next non-empty data block in the sequence. + # The indices of empty blocks are set to 0 to avoid copying extra data when + # pipelining. + if is_dkv: + data_next_slice = jnp.arange(q_blocks_per_shard, dtype=np.int32)[ + None, :, None + ] + else: + data_next_slice = jnp.arange(kv_blocks_count, dtype=np.int32)[ + None, None, : + ] + data_next_slice = jnp.broadcast_to(data_next_slice, mask_info_slice_shape) + data_next_slice = jnp.where(local_block_mask == 0, 0, data_next_slice) + + data_next_sequence_slices.append(data_next_slice) + mask_next_sequence_slices.append(mask_next_slice) + + # Concatenate the sequence shards. 
+ data_next_per_head = jnp.concatenate( + data_next_sequence_slices, axis=q_sequence_axis + ) + data_next_per_head_list.append(data_next_per_head) + mask_next_per_head = jnp.concatenate( + mask_next_sequence_slices, axis=q_sequence_axis + ) + mask_next_per_head_list.append(mask_next_per_head) + + # Concatenate (or broadcast) the head shards. + data_next = jnp.concatenate(data_next_per_head_list, axis=head_axis) + mask_next = jnp.concatenate(mask_next_per_head_list, axis=head_axis) - partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape) if is_dkv: partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2) @@ -438,9 +491,11 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array: if downcast_smem_data: block_mask = block_mask.astype(np.int8) # values are in the range [0, 1, 2] data_next = _downcast( - data_next, q_blocks_count if is_dkv else kv_blocks_count + data_next, q_blocks_per_shard if is_dkv else kv_blocks_count + ) + mask_next = _downcast( + mask_next, heads_per_shard * q_blocks_per_shard * kv_blocks_count ) - mask_next = _downcast(mask_next, math.prod(block_mask_shape)) return ( MaskInfo( @@ -449,6 +504,7 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array: block_mask=block_mask, partial_mask_blocks=partial_mask_blocks, q_sequence=None, + is_dynamic_mask=True, ), None, ) diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index ecc9d0d15120..21976c47166b 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -21,7 +21,6 @@ from jax._src.pallas.mosaic.core import GridDimensionSemantics as GridDimensionSemantics from jax._src.pallas.mosaic.core import PARALLEL as PARALLEL from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec -from jax._src.pallas.mosaic.core import semaphore as semaphore from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams @@ -40,8 +39,6 @@ from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy from jax._src.pallas.mosaic.primitives import bitcast as bitcast from jax._src.pallas.mosaic.primitives import delay as delay -from jax._src.pallas.mosaic.primitives import device_id as device_id -from jax._src.pallas.mosaic.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy @@ -49,12 +46,17 @@ from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed from jax._src.pallas.mosaic.primitives import repeat as repeat from jax._src.pallas.mosaic.primitives import roll as roll -from jax._src.pallas.mosaic.primitives import semaphore_read as semaphore_read -from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal -from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.mosaic.random import sample_block as sample_block from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key +# Those primitives got moved to Pallas core. Keeping the updated imports +# here for backward compatibility. 
+from jax._src.pallas.core import semaphore as semaphore +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait + import types from jax._src.pallas.mosaic.verification import assume from jax._src.pallas.mosaic.verification import pretend diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py index 55cf2b3bae70..be06aba2db13 100644 --- a/jax/experimental/rnn.py +++ b/jax/experimental/rnn.py @@ -463,7 +463,7 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw): rnn_fwd_p.def_abstract_eval(rnn_abstract_eval) if gpu_rnn: mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda') - if hasattr(gpu_rnn, "miopen_rnn_fwd_lowering"): + if hasattr(gpu_rnn, "miopen_rnn_lowering"): mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_lowering, platform='rocm') diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 1edd1e0649b1..1a84095a0e31 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -14,6 +14,7 @@ from collections import defaultdict from dataclasses import replace import itertools as it +from typing import Sequence import numpy as np from jax._src import ad_util @@ -35,6 +36,8 @@ from jax.experimental import roofline from jax.experimental import shard_map +# One FMA (Fused Multiply Add) takes 2 flops to compute. +_FMA_FLOPS_FACTOR = 2 for prim in it.chain( ad_util.__dict__.values(), @@ -156,7 +159,7 @@ def _dot_general_roofline( (lhs_contract, _), (lhs_batch, _) = dimension_numbers flops = ( - 2 + _FMA_FLOPS_FACTOR * lhs.size * rhs.size / np.prod([lhs.shape[i] for i in lhs_contract]) @@ -177,16 +180,208 @@ def _dot_general_roofline( unfused_hbm_bytes=hbm_bytes, ) + +def _get_spatial_valid_position_count_for_one_dim( + window_dim_stride: int, + base_dilation: int, + window_dilation: int, + kernel_limit: int, + input_limit: int, + output_limit: int, + padding: tuple[int, int], +) -> int: + """Gets the valid position count for conv for a single spatial dimension. + + Args: + window_dim_stride: The stride of the window along this dimension. + base_dilation: The base dilation factor along this dimension. + window_dilation: The window dilation factor along this dimension. + kernel_limit: The size of the kernel along this dimension. + input_limit: The size of the input along this dimension. + output_limit: The size of the output along this dimension. + padding: The padding applied to the input along this dimension. + """ + padding_low = padding[0] + padding_high = padding[1] + + # These two conditions will create an N^2 iteration pattern with only N + # valid elements. This is a performance optimization and produces the same + # result as the whole loop. 
+ if ( + input_limit == output_limit + and kernel_limit == output_limit + and input_limit == base_dilation + and window_dilation == 1 + and max(1, input_limit - 1) == window_dim_stride + and padding_low == 0 + and padding_high == 0 + ): + return input_limit + + if ( + input_limit == 1 + and kernel_limit == output_limit + and window_dilation == 1 + and base_dilation == 1 + and window_dim_stride == 1 + and padding_low == output_limit - 1 + and padding_high == output_limit - 1 + ): + return output_limit + + valid_position_count = 0 + # Loop over each point in the kernel + for kernel_idx in range(kernel_limit): + + # Skip loop for trivial stride and base_dilation + if window_dim_stride == 1 and base_dilation == 1: + undilated_index_base = padding_low - kernel_idx * window_dilation + upper_limit = min( + input_limit + undilated_index_base, + output_limit, + ) + lower_limit = max(0, undilated_index_base) + + valid_position_count += max(upper_limit - lower_limit, 0) + continue + + # Loop over each point in the output + for output_idx in range(output_limit): + # Calculate lhs (input) index without taking base dilation into account + undilated_index = ( + output_idx * window_dim_stride + - padding_low + + kernel_idx * window_dilation + ) + # Calculate the actual lhs (input) index after dilation + lhs_spatial_index = int(undilated_index / base_dilation) + + # Skip if the lhs (input) index is to be dilated. + if undilated_index != lhs_spatial_index * base_dilation: + continue + # Skip if input index is not in bound. + if lhs_spatial_index < 0 or lhs_spatial_index >= input_limit: + continue + + valid_position_count += 1 + return valid_position_count + + +def _get_spatial_valid_position_count( + dnums: convolution.ConvDimensionNumbers, + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], +) -> int: + """Gets the number of valid spatial positions for conv_general_dilated. + + Args: + dnums: The dimension numbers for the convolution. + lhs: The shape of the left-hand side of the convolution. + rhs: The shape of the right-hand side of the convolution. + out: The shape of the output of the convolution. + window_strides: The stride of the window along each spatial dimension. + padding: The padding applied to the input along each spatial dimension. + lhs_dilation: The dilation factor for the left-hand side along each spatial + dimension. + rhs_dilation: The dilation factor for the right-hand side along each spatial + dimension. + """ + input_spatial_dims, kernel_spatial_dims, out_spatial_dims = ( + dnums.lhs_spec[2:], + dnums.rhs_spec[2:], + dnums.out_spec[2:], + ) + + valid_position_counts = 1 + # Loop over each spatial dimension and determine how many valid positions + # there are for each dimension. 
+ for d in range(len(input_spatial_dims)): + valid_position_counts *= _get_spatial_valid_position_count_for_one_dim( + window_dim_stride=window_strides[d], + base_dilation=lhs_dilation[d], + window_dilation=rhs_dilation[d], + kernel_limit=rhs.shape[kernel_spatial_dims[d]], + input_limit=lhs.shape[input_spatial_dims[d]], + output_limit=out.shape[out_spatial_dims[d]], + padding=padding[d], + ) + + return valid_position_counts + + +def _calculate_conv_flops( + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, +) -> int: + """Calculates roofline unfused flops for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ + dnums = convolution.conv_dimension_numbers( + lhs.shape, rhs.shape, dimension_numbers + ) + + spatial_valid_position_counts = _get_spatial_valid_position_count( + dnums, lhs, rhs, out, window_strides, padding, lhs_dilation, rhs_dilation + ) + + batch = lhs.shape[dnums.lhs_spec[0]] + num_output_features = out.shape[dnums.out_spec[1]] + num_input_features = rhs.shape[dnums.rhs_spec[1]] + num_output_batch = batch / batch_group_count + + non_spatial_dims_factor = ( + num_input_features * num_output_features * num_output_batch + ) + + fma_count = non_spatial_dims_factor * spatial_valid_position_counts + flops = fma_count * _FMA_FLOPS_FACTOR + return int(flops) + + @roofline.register_roofline(convolution.conv_general_dilated_p) def _conv_general_dilated_roofline( - ctx: roofline.RooflineRuleContext, - *args, - **kw, + ctx: roofline.RooflineRuleContext, + *args, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, + **kw, ) -> roofline.RooflineResult: + """Roofline for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) - # TODO(b/394648206): support computing unfused_flops for conv. 
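# [Editor's worked example, not part of the patch] For lhs of shape
# (N=1, C=3, W=8), rhs of shape (O=4, I=3, KW=3), stride 1 and no
# padding/dilation, the output width is 6 and the single spatial dimension
# contributes 3 (kernel) * 6 (output) = 18 valid positions. The FMA count is
# then 3 * 4 * 1 * 18 = 216, i.e. 432 flops with _FMA_FLOPS_FACTOR = 2, which
# matches the familiar 2 * N * W_out * C_out * C_in * KW formula.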
+ return roofline.RooflineResult( + unfused_flops=_calculate_conv_flops( + lhs, + rhs, + out, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + batch_group_count, + ), unfused_hbm_bytes=( lhs.dtype.itemsize * lhs.size + rhs.dtype.itemsize * rhs.size @@ -272,7 +467,7 @@ def _scalar_collective_roofline( roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline) -@roofline.register_roofline(shard_map.psum2_p) +@roofline.register_roofline(lax_parallel.psum_invariant_p) def _psum2_roofline( ctx: roofline.RooflineRuleContext, *args, diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 66b70c6c2d34..51e6acae8aee 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -33,7 +33,7 @@ from jax._src import callback from jax._src import config from jax._src import core -from jax._src import custom_derivatives +from jax._src import custom_derivatives as cd from jax._src import debugging from jax._src import dispatch from jax._src import dtypes @@ -46,7 +46,8 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import util -from jax._src.core import Tracer +from jax._src.core import pvary, pvary_p +from jax._src.core import Tracer, typeof from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh, get_abstract_mesh) from jax._src.api import _shared_code_pmap, _prepare_pmap @@ -136,7 +137,7 @@ def shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, Examples: For examples, refer to :ref:`sharded-computation` or `SPMD multi-device parallelism with shard_map`_. - .. _SPMD multi-device parallelism with shard_map: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html + .. _SPMD multi-device parallelism with shard_map: https://docs.jax.dev/en/latest/notebooks/shard_map.html """ return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto) @@ -189,14 +190,13 @@ def out_names_thunk(): raise e('shard_map out_specs') from None return tuple(map(_canonicalize_spec, out_specs_flat)) - if rewrite := check_rep: - fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk) + if check_rep: + fun = _implicit_pvary_on_output(fun, out_names_thunk) try: out_flat = shard_map_p.bind( fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_rep=check_rep, rewrite=rewrite, - auto=auto) + out_names_thunk=out_names_thunk, check_rep=check_rep, auto=auto) except _SpecError as e: fails, = e.args if not callable(out_specs): @@ -516,22 +516,22 @@ def _shard_map_staging( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_rep: bool, - rewrite: bool, auto: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: in_tracers = map(trace.to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] - in_avals_ = map(partial(_shard_aval, mesh, auto), in_names, in_avals) + in_avals_ = map(partial(_shard_aval, mesh, auto, check_rep), in_names, + in_avals) manual_mesh = _as_manual_mesh(mesh, auto) - with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh): + with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), + config._check_rep(check_rep)): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_rep: - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - out_rep = _check_rep(mesh, jaxpr, in_rep) - _check_reps(mesh, out_names_thunk(), out_rep) + out_rep = 
[_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] + _check_reps(mesh, auto, out_names_thunk(), out_rep) out_avals = map(_check_shapedarray, out_avals_) - out_avals = [_check_shapedarray(_unshard_aval(mesh, names, aval)) + out_avals = [_check_shapedarray(_unshard_aval(mesh, check_rep, names, aval)) for names, aval in zip(out_names_thunk(), out_avals)] source_info = source_info_util.current() out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] @@ -539,11 +539,12 @@ def _shard_map_staging( constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh): + with (_extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh), + config._check_rep(check_rep)): jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, - check_rep=check_rep, rewrite=rewrite, auto=auto) + check_rep=check_rep, auto=auto) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, effs, source_info) @@ -557,30 +558,33 @@ def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: assert isinstance(aval, core.ShapedArray) return aval -def _shard_aval(mesh: Mesh, auto, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: +def _shard_aval(mesh: Mesh, auto, check_rep, names: AxisNames, + aval: core.AbstractValue) -> core.AbstractValue: if type(aval) in core.shard_aval_handlers: - return core.shard_aval_handlers[type(aval)](mesh, auto, names, aval) + return core.shard_aval_handlers[type(aval)](mesh, auto, check_rep, names, + aval) raise NotImplementedError(f"Unsupported aval type: {type(aval)}") -def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: +def _unshard_aval(mesh: Mesh, check_rep, names: AxisNames, + aval: core.AbstractValue) -> core.AbstractValue: if type(aval) in core.unshard_aval_handlers: - return core.unshard_aval_handlers[type(aval)](mesh, names, aval) + return core.unshard_aval_handlers[type(aval)](mesh, check_rep, names, aval) else: raise NotImplementedError(f"Unsupported aval type: {type(aval)}") -def _shard_shaped_array(mesh: Mesh, auto: frozenset, names: AxisNames, +def _shard_shaped_array(mesh: Mesh, auto: frozenset, check_rep, names: AxisNames, aval: core.AbstractValue) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) manual_mesh = _as_manual_mesh(mesh, auto) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) - return aval.update(shape=new_shape, sharding=new_sharding) + vma = (frozenset({n for ns in names.values() for n in ns}) + if check_rep else frozenset()) + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array -def _unshard_shaped_array(mesh: Mesh, names: AxisNames, +def _unshard_shaped_array(mesh: Mesh, check_rep, names: AxisNames, aval: core.AbstractValue,) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) @@ -606,31 +610,35 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames, new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else 
get_abstract_mesh()) new_sharding = NamedSharding(new_mesh, out_spec) - return aval.update(shape=new_shape, sharding=new_sharding) + manual_axes = set(new_mesh.manual_axes) + vma = (frozenset(v for v in aval.vma if v in manual_axes) + if check_rep else frozenset()) + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array # Type-checking -RepType = Union[set[AxisName], None] +RepType = Any def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): + check_rep, auto): # TODO(mattjj,parkers): check auto for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): - if not core.typecompat(v.aval, _shard_aval(mesh, auto, in_name, x.aval)): + if not core.typecompat(v.aval, _shard_aval( + mesh, auto, check_rep, in_name, x.aval)): raise core.JaxprTypeError("shard_map argument avals not compatible with " "jaxpr binder avals and in_names") - with _extend_axis_env(mesh, auto): + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): core.check_jaxpr(jaxpr) if check_rep: - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - out_rep = _check_rep(mesh, jaxpr, in_rep) + out_rep = [_vma_to_rep(mesh, auto, v.aval.vma) for v in jaxpr.outvars] for rep, dst in zip(out_rep, out_names): - if not _valid_repeats(mesh, rep, dst): + if not _valid_repeats(mesh, auto, rep, dst): raise core.JaxprTypeError("shard_map can't prove output is " "sufficiently replicated") out_avals_sharded = [x.aval for x in jaxpr.outvars] - out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded) + out_avals = map(partial(_unshard_aval, mesh, check_rep), out_names, + out_avals_sharded) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) return out_avals, effs core.custom_typechecks[shard_map_p] = _shard_map_typecheck @@ -655,15 +663,16 @@ def write(v: core.Var, val: RepType) -> None: rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive)) out_rep = rule(mesh, *map(read, e.invars), **e.params) if e.primitive.multiple_results: - out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep + out_rep = (out_rep if isinstance(out_rep, (list, tuple)) else + [out_rep] * len(e.outvars)) foreach(write, e.outvars, out_rep) else: write(e.outvars[0], out_rep) core.clean_up_dead_vars(e, env, last_used) return map(read, jaxpr.outvars) -def _valid_repeats(mesh: Mesh, rep: RepType, dst: AxisNames) -> bool: - return rep is None or set(_unmentioned(mesh, dst)).issubset(rep) +def _valid_repeats(mesh: Mesh, auto, rep: RepType, dst: AxisNames) -> bool: + return rep is None or (set(_unmentioned(mesh, dst)) - auto).issubset(rep) def _rule_missing(prim: core.Primitive, *_, **__): raise NotImplementedError( @@ -687,12 +696,12 @@ def _shardy_shard_map_sharding( for dim_sharding in sdy_sharding.dimension_shardings: # Only allow dimensions which have no sharding to be auto-sharded. 
if not dim_sharding.axes: - dim_sharding.is_closed = False + dim_sharding.is_open = True return sdy_sharding def _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto): + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_rep): in_avals_ = [v.aval for v in jaxpr.invars] if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): # Nested `ManualComputationOp`s cannot refer to axes that are already @@ -711,7 +720,7 @@ def _shard_map_lowering_shardy( if a in shardy_manual_axes] if np.prod([mesh.shape[a] for a in manual_axes]) == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. - with _extend_axis_env(mesh, auto): + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): out_nodes, _ = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *in_nodes, dim_var_values=ctx.dim_var_values) @@ -730,7 +739,8 @@ def _shard_map_lowering_shardy( ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) block = ir.Block.create_at_start( manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) - with ir.InsertionPoint(block), _extend_axis_env(mesh, auto): + with (ir.InsertionPoint(block), _extend_axis_env(mesh, auto), + config._check_rep(check_rep)): out_nodes_, _ = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, dim_var_values=ctx.dim_var_values) @@ -740,12 +750,10 @@ def _shard_map_lowering_shardy( def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): - del check_rep, rewrite - + check_rep, auto): if config.use_shardy_partitioner.value: return _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto) + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto, check_rep) in_avals_ = [v.aval for v in jaxpr.invars] out_avals_ = [x.aval for x in jaxpr.outvars] @@ -754,7 +762,7 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, manual_axes = frozenset(mesh.axis_names) - auto new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - with _extend_axis_env(mesh, auto): + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): out_nodes_, tokens_out = mlir.call_lowering( "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values, @@ -767,13 +775,12 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, def _make_scoped_manual_sharding(ctx, mesh, axes): axis_ctx = ctx.module_context.axis_context + mesh = mesh.abstract_mesh if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): - manual_axes = axis_ctx.manual_axes - else: - manual_axes = frozenset({}) + mesh = mesh.update_axis_types( + {a: AxisType.Manual for a in axis_ctx.manual_axes}) return NamedSharding( - mesh, sharding_impls.array_mapping_to_axis_resources(axes), # pytype: disable=wrong-arg-types - _manual_axes=manual_axes) + mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in, aval_out, x): @@ -833,23 +840,42 @@ def get_mesh_from_args(args_flat, mesh): assert isinstance(mesh, Mesh) return mesh +def _rep_to_vma(mesh, auto, rep: frozenset[AxisName]) -> frozenset[AxisName]: + return frozenset((set(mesh.axis_names) - auto) - rep) + +def _rep_to_spec(mesh, auto, rep): + return _vma_to_spec(mesh, _rep_to_vma(mesh, 
auto, rep)) + +def _vma_to_spec(mesh, vma): + return P(tuple(i for i in mesh.axis_names if i in vma)) + +def _names_to_vma(names): + return {n for ns in names.values() for n in ns} + +def _vma_to_rep(mesh, auto, vma): + return frozenset((set(mesh.axis_names) - auto) - vma) + def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, - check_rep, rewrite, auto): + check_rep, auto): if auto: raise NotImplementedError del prim if isinstance(mesh, AbstractMesh): mesh = get_mesh_from_args(args, mesh) - args = map(partial(_unmatch_spec, mesh, context_mesh=get_abstract_mesh()), + cur_mesh = get_abstract_mesh() + args = map(partial(_unmatch_spec, mesh, check_rep, context_mesh=cur_mesh), in_names, args) in_rep = map(partial(_in_names_to_rep, mesh), in_names) - outs, out_rep = _run_shmap(fun, mesh, auto, args, in_rep, check_rep, - get_abstract_mesh()) + outs, out_rep = _run_shmap(fun, mesh, auto, args, in_rep, check_rep, cur_mesh) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types if check_rep: - _check_reps(mesh, out_names_thunk(), out_rep) - pspecs = map(_names_to_pspec, out_names_thunk()) - return map(partial(_match_spec, mesh, check_rep), pspecs, outs) + _check_reps(mesh, auto, out_names_thunk(), out_rep) + src_pspecs = tuple(_rep_to_spec(mesh, auto, r) for r in out_rep) + else: + src_pspecs = tuple(P(mesh.axis_names) for _ in out_rep) + dst_pspecs = map(_names_to_pspec, out_names_thunk()) + return map(partial(_match_spec, mesh, check_rep), src_pspecs, dst_pspecs, + outs) core.EvalTrace.process_shard_map = _shard_map_impl def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh): @@ -857,7 +883,7 @@ def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh): in_tracers = map(partial(ShardMapTracer, trace), reps, args) manual_mesh = _as_manual_mesh(mesh, auto) with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), - use_abstract_mesh(manual_mesh)): + use_abstract_mesh(manual_mesh), config._check_rep(check_rep)): ans = f.call_wrapped(*in_tracers) outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans)) return outs, out_rep @@ -867,43 +893,56 @@ def _names_to_pspec(names: AxisNames) -> PartitionSpec: unpack = lambda t: t[0] if t is not None and len(t) == 1 else t return PartitionSpec(*(unpack(names.get(i)) for i in range(ndmin))) -def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType, context_mesh) -> JaxType: +def _unmatch_spec(mesh: Mesh, check_rep, src: AxisNames, x: JaxType, + context_mesh) -> JaxType: with (core.eval_context(), jax.disable_jit(False), use_abstract_mesh(context_mesh)): - return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x) + return jax.jit(HashablePartial(_unmatch, mesh, check_rep, + tuple(src.items())))(x) -def _unmatch(mesh, src_tup, x): +def _unmatch(mesh, check_rep, src_tup, x): src = _names_to_pspec(dict(src_tup)) - dst = P(mesh.axis_names) - return shard_map(_add_singleton, mesh, (src,), dst, check_rep=False)(x) + if check_rep: + used_axes = {i for _, ns in src_tup for i in ns} + dst = P(tuple(i for i in mesh.axis_names if i in used_axes)) + else: + dst = P(mesh.axis_names) + check_rep = False + return shard_map(_add_singleton, mesh, (src,), dst, check_rep=check_rep)(x) def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] ) -> None: fail = [a if n and not max(n) < a.ndim else no_fail for n, a in zip(names, avals)] - if any(f is not no_fail for f in fail): raise 
_SpecError(fail) -class _SpecError(Exception): pass + if any(f is not no_fail for f in fail): + raise _SpecError(fail) + +class _SpecError(Exception): + pass -def _check_reps(mesh, names, reps): - fail = [r if not _valid_repeats(mesh, r, n) else no_fail +def _check_reps(mesh, auto, names, reps): + fail = [r if not _valid_repeats(mesh, auto, r, n) else no_fail for n, r in zip(names, reps)] - if any(f is not no_fail for f in fail): raise _RepError(fail) -class _RepError(Exception): pass + if any(f is not no_fail for f in fail): + raise _RepError(fail) + +class _RepError(Exception): + pass def _check_reps2(mesh, reps_dest, reps): fail = [src if not dst.issubset(src) else no_fail for dst, src in zip(reps_dest, reps)] if any(f is not no_fail for f in fail): raise _RepError(fail) -def _match_spec(mesh: Mesh, check_rep: bool, - pspec: PartitionSpec, x: JaxType) -> JaxType: - fn = HashablePartial(_match, mesh, check_rep, pspec) +def _match_spec(mesh: Mesh, check_rep, src_pspec: PartitionSpec, + dst_pspec: PartitionSpec, x: JaxType) -> JaxType: + fn = HashablePartial(_match, mesh, check_rep, src_pspec, dst_pspec) with core.eval_context(), jax.disable_jit(False): - return jax.jit(fn, out_shardings=NamedSharding(mesh, pspec))(x) + return jax.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x) -def _match(mesh, check_rep, pspec, x): - src = P(mesh.axis_names) - return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x) +def _match(mesh, check_rep, src_pspec, dst_pspec, x): + return shard_map(_rem_singleton, mesh, src_pspec, dst_pspec, + check_rep=check_rep)(x) def _rem_singleton(x): return jnp.squeeze(x, axis=0) def _add_singleton(x): return jnp.expand_dims(x, axis=0) @@ -938,24 +977,40 @@ def to_val_rep_pair(self, val): elif isinstance(val, Tracer): raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") else: - val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh) - return val_, None + val_ = _unmatch_spec(self.mesh, self.check, {}, val, self.context_mesh) + if self.check: + return val_, frozenset(self.mesh.axis_names) - self.auto + else: + return val_, None def process_primitive(self, prim, tracers, params): in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) + if self.check: + in_vma = tuple(map(partial(_rep_to_vma, self.mesh, self.auto), in_rep)) + out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params) + out_avals = tuple(out_avals) if type(out_avals) is list else out_avals + out_vma = tree_map(lambda a: a.vma, out_avals) + out_rep = tree_map(partial(_vma_to_rep, self.mesh, self.auto), out_vma) + in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma)) + out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma) + else: + out_rep = frozenset() + in_specs = out_specs = P(self.mesh.axis_names) + eager_rule = eager_rules.get(prim) if eager_rule: out_vals = eager_rule(self.mesh, *in_vals, **params) else: - f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh) + f = HashablePartial( + _prim_applier, prim, self.check, tuple(params.items()), self.mesh, + in_specs, out_specs) with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False), jax.debug_infs(False), use_abstract_mesh(self.context_mesh)): out_vals = jax.jit(f)(*in_vals) _maybe_check_special(out_vals) - rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) - out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() if prim.multiple_results: - out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else 
out_rep + out_rep = (out_rep if isinstance(out_rep, (list, tuple)) + else [out_rep] * len(out_vals)) return map(partial(ShardMapTracer, self), out_rep, out_vals) return ShardMapTracer(self, out_rep, out_vals) @@ -974,11 +1029,6 @@ def process_map(self, map_primitive, fun, tracers, params): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Since ShardMapTrace is only used as a base main, we can drop the jvp. - if symbolic_zeros: - msg = ("custom_jvp symbolic_zeros support with shard_map is not " - "implemented; please open an issue at " - "https://github.com/jax-ml/jax/issues") - raise NotImplementedError(msg) del prim, jvp, symbolic_zeros in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) out_vals, out_rep = _run_shmap(fun, self.mesh, self.auto, in_vals, in_rep, self.check, @@ -1015,7 +1065,10 @@ def aval(self): new_sharding = NamedSharding( _as_manual_mesh(self._trace.mesh, self._trace.auto), out.sharding.spec) # pytype: disable=attribute-error - return out.update(sharding=new_sharding) + manual_axes = set(self._trace.mesh.axis_names) - self._trace.auto + vma = (frozenset(manual_axes - self.rep) if config._check_rep.value else + frozenset()) + return out.update(sharding=new_sharding, vma=vma) def to_concrete_value(self): if self.rep == set(self._trace.mesh.axis_names): @@ -1025,6 +1078,9 @@ def to_concrete_value(self): return None def __str__(self) -> str: + pb_names = set(self._trace.mesh.axis_names) - _rep_to_vma( + self._trace.mesh, self._trace.auto, self.rep) + self = pvary(self, tuple(pb_names)) with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): blocks = list(self.val) mesh = self._trace.mesh @@ -1034,24 +1090,32 @@ def __str__(self) -> str: for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) __repr__ = __str__ # for debuggers, like `p x` -def _prim_applier(prim, params_tup, mesh, *args): +def _prim_applier(prim, check_rep, params_tup, mesh, in_specs, out_specs, *args): def apply(*args): outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) return tree_map(_add_singleton, outs) - spec = P(mesh.axis_names) - return shard_map(apply, mesh, spec, spec, False)(*args) + out_specs = list(out_specs) if type(out_specs) is tuple else out_specs + return shard_map(apply, mesh, in_specs, out_specs, check_rep=check_rep)(*args) eager_rules: dict[core.Primitive, Callable] = {} + # TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually -def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any], - effect: debugging.DebugEffect): +def _debug_callback_eager_rule( + mesh, + *args, + callback: Callable[..., Any], + effect: debugging.DebugEffect, + partitioned: bool, +): del effect with core.eval_context(): all_blocks = zip(*map(list, args)) for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks): callback(*blocks) return [] + + eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): @@ -1063,46 +1127,6 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): return xs eager_rules[dispatch.device_put_p] = _device_put_eager_rule -# New primitives for efficient transposition - -# psum2_p is like psum_p except has a different transpose, so mostly copied: -psum2_p = core.Primitive('psum2') -psum2_p.multiple_results = True -psum2_p.def_impl(lax_parallel.psum_p.impl) -psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval) -mlir.register_lowering(psum2_p, 
mlir._lowerings[lax_parallel.psum_p]) -batching.fancy_primitive_batchers[psum2_p] = \ - partial(lax_parallel._batched_reduction_collective, psum2_p, - lambda v, axis_size: axis_size * v) -batching.skippable_batchers[psum2_p] = partial(lax_parallel._names_in_param, 'axes') - -def _psum2_transpose_rule(cts, *args, axes, axis_index_groups): - del args - return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) -ad.deflinear2(psum2_p, _psum2_transpose_rule) - -# pbroadcast_p is exactly the transpose of psum2_p -def pbroadcast(x, axis_name): - axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name - if not axis_name: return x - xs, treedef = tree_flatten(x) - ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None) - return tree_unflatten(treedef, ys) -pbroadcast_p = core.Primitive('pbroadcast') -pbroadcast_p.multiple_results = True -pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args) -pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args) -mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x) -def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups): - if any(type(axis) is int for axis in axes): raise NotImplementedError - vals_out = pbroadcast_p.bind(*vals_in, axes=axes, - axis_index_groups=axis_index_groups) - return vals_out, dims_in -batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher -ad.deflinear2(pbroadcast_p, - lambda cts, *_, axes, axis_index_groups: - psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)) - # Rewrite rules and static replication checking for efficient transposition _rewrite_rules: dict[core.Primitive, Callable] = {} @@ -1117,6 +1141,14 @@ def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups): register_standard_check = \ lambda prim: _check_rules.setdefault(prim, partial(_standard_check, prim)) +def _eq_rep(mesh, r1, r2) -> bool: + if r1 != r2 and r1 is None or r2 is None: + r1, r2 = _remove_none_rep(mesh, r1), _remove_none_rep(mesh, r2) + return r1 == r2 + +def _remove_none_rep(mesh, r): + return set(mesh.axis_names) if r is None else r + def _no_rewrite(prim, rule, mesh, in_rep, *args, **params): out_vals = prim.bind(*args,**params) out_rep = rule(mesh, *in_rep, **params) @@ -1129,7 +1161,7 @@ def _no_rewrite(prim, rule, mesh, in_rep, *args, **params): def _standard_rewrite_rule(prim, mesh, in_rep, *args, **params): # The standard rewrite inserts pbroadcasts but doesn't change the primitive. out_rep_ = set.intersection(*in_rep) if in_rep else set(mesh.axis_names) - args_ = [pbroadcast(x, tuple(n for n in src if n not in out_rep_)) + args_ = [pvary(x, tuple(n for n in src if n not in out_rep_)) if src - out_rep_ else x for x, src in zip(args, in_rep)] out_vals_ = prim.bind(*args_, **params) out_rep = [out_rep_] * len(out_vals_) if prim.multiple_results else [out_rep_] @@ -1140,7 +1172,7 @@ def _standard_check(prim, mesh, *in_rep, **__): # The standard check require args' and outputs' replications to be the same, # except for Nones which correspond to constants. in_rep_ = [r for r in in_rep if r is not None] - if in_rep_ and not in_rep_[:-1] == in_rep_[1:]: + if in_rep_ and in_rep_[:-1] != in_rep_[1:]: raise Exception(f"Primitive {prim} requires argument replication types " f"to match, but got {in_rep}. 
Please open an issue at " "https://github.com/jax-ml/jax/issues and as a temporary " @@ -1172,7 +1204,7 @@ def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params): x_rep, = in_rep axis_name_set = set(axis_name) if pbroadcast_axis_name := axis_name_set & x_rep: - x = pbroadcast(x, tuple(pbroadcast_axis_name)) + x = pvary(x, tuple(pbroadcast_axis_name)) out_val = prim.bind(x, axis_name=axis_name, **params) return [out_val], [x_rep - axis_name_set] @@ -1194,7 +1226,7 @@ def _reduction_collective_rewrite(prim, mesh, in_rep, x, axes, **params): x_rep, = in_rep axes_set = set(axes) if pbroadcast_axes := axes_set & x_rep: - x = pbroadcast(x, tuple(pbroadcast_axes)) + x = pvary(x, tuple(pbroadcast_axes)) out_val, = prim.bind(x, axes=axes, **params) return [out_val], [x_rep | axes_set] @@ -1229,13 +1261,14 @@ def _psum_rewrite(mesh, in_rep, *args, axes, axis_index_groups): axes = (axes,) if not isinstance(axes, tuple) else axes axes_ = set(axes) out_rep = [r | axes_ for r in in_rep] # TODO determinism (and elsewhere) - args_ = [pbroadcast(x, tuple(n for n in mesh.axis_names if n in axes_ & src)) + args_ = [pvary(x, tuple(n for n in mesh.axis_names if n in axes_ & src)) for x, src in zip(args, in_rep)] - out_val = psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups) + out_val = lax_parallel.psum_invariant_p.bind( + *args_, axes=axes, axis_index_groups=axis_index_groups) return out_val, out_rep -@register_check(psum2_p) +@register_check(lax_parallel.psum_invariant_p) def _psum2_check(mesh, *in_rep, axes, axis_index_groups): assert type(axes) is tuple if any(set(axes) & r for r in in_rep if r is not None): @@ -1246,10 +1279,10 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups): "workaround pass the check_rep=False argument to shard_map") in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) return [r | set(axes) for r in in_rep] -register_norewrite(psum2_p) +register_norewrite(lax_parallel.psum_invariant_p) -@register_check(pbroadcast_p) +@register_check(pvary_p) def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): assert type(axes) is tuple if not all(r is None or set(axes) & r for r in in_rep): @@ -1261,7 +1294,7 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): "workaround pass the check_rep=False argument to shard_map") in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) return [r - set(axes) for r in in_rep] -register_norewrite(pbroadcast_p) +register_norewrite(pvary_p) register_standard_collective(lax_parallel.all_gather_p) @@ -1343,7 +1376,7 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): _, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry]) out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep) carry_rep_out, _ = split_list(out_rep, [num_carry]) - if not carry_rep_in == carry_rep_out: + if not all(map(partial(_eq_rep, mesh), carry_rep_in, carry_rep_out)): raise Exception("Scan carry input and output got mismatched replication " f"types {carry_rep_in} and {carry_rep_out}. 
Please open an " "issue at https://github.com/jax-ml/jax/issues, and as a " @@ -1366,7 +1399,7 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): else: assert False, 'Fixpoint not reached' - args = [pbroadcast(x, tuple(n for n in src if n not in dst)) + args = [pvary(x, tuple(n for n in src if n not in dst)) if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)] out_rep = [*carry_rep_out, *ys_rep] jaxpr_ = _replication_rewrite_match(mesh, jaxpr, in_rep_, out_rep) @@ -1375,13 +1408,51 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params) return out_vals, out_rep +@register_check(control_flow.loops.while_p) +def _while_check(mesh, *in_rep, body_jaxpr, cond_nconsts, body_nconsts, **_): + _, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts]) + carry_rep_out = _check_rep(mesh, body_jaxpr.jaxpr, [*bconst_rep, *carry_rep_in]) + if tuple(carry_rep_in) != tuple(carry_rep_out): + raise Exception("while_loop carry input and output got mismatched " + f"replication types {carry_rep_in} and {carry_rep_out}. " + "Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return carry_rep_out + +@register_rewrite(control_flow.loops.while_p) +def _while_rewrite(mesh, in_rep, *args, cond_jaxpr, body_jaxpr, cond_nconsts, + body_nconsts): + # while while isn't transposable, we insert pbroadcasts for consistent carry + cconst_rep, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts]) + num_carry = len(args) - cond_nconsts - body_nconsts + for _ in range(1 + num_carry): + in_rep_ = [*bconst_rep, *carry_rep_in] + _, carry_rep_out = _replication_rewrite_nomatch(mesh, body_jaxpr, in_rep_) + if tuple(carry_rep_in) == tuple(carry_rep_out): + break + carry_rep_in = map(op.and_, carry_rep_in, carry_rep_out) + else: + assert False, "Fixpoint not reached" + + cond_jaxpr_, _ = _replication_rewrite_nomatch( + mesh, cond_jaxpr, (*cconst_rep, *carry_rep_in)) + body_jaxpr_ = _replication_rewrite_match( + mesh, body_jaxpr, (*bconst_rep, *carry_rep_in), carry_rep_out) + args_ = [pvary(x, tuple(n for n in src if n not in dst)) + if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)] + out_vals = control_flow.loops.while_p.bind( + *args_, cond_jaxpr=cond_jaxpr_, body_jaxpr=body_jaxpr_, + cond_nconsts=cond_nconsts, body_nconsts=body_nconsts) + return out_vals, carry_rep_out + @register_check(control_flow.conditionals.cond_p) def _cond_rule(mesh, *in_rep, branches): _, *args_rep = in_rep out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) for branch in branches[1:]: out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep) - if not out_rep_ == out_rep: + if not all(map(partial(_eq_rep, mesh), out_rep, out_rep_)): raise Exception("The branches of cond produced mismatched replication " "types. 
Please open an issue at " "https://github.com/jax-ml/jax/issues, and as a " @@ -1421,12 +1492,12 @@ def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs): return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) -@register_check(custom_derivatives.custom_jvp_call_p) +@register_check(cd.custom_jvp_call_p) def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_fun, num_consts, symbolic_zeros): return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) -@register_rewrite(custom_derivatives.custom_vjp_call_jaxpr_p) +@register_rewrite(cd.custom_vjp_call_jaxpr_p) def _custom_vjp_call_jaxpr_rewrite( mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees, symbolic_zeros): @@ -1449,13 +1520,13 @@ def fwd_jaxpr_thunk_(*zeros): bwd_ = _rewrite_bwd(bwd, mesh, lambda: out_rep2[0], in_rep_) - outs = custom_derivatives.custom_vjp_call_jaxpr_p.bind( + outs = cd.custom_vjp_call_jaxpr_p.bind( *args, fun_jaxpr=fun_jaxpr_, fwd_jaxpr_thunk=fwd_jaxpr_thunk_, bwd=bwd_, num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) out_rep = out_rep2[0] if out_rep2 else out_rep return outs, out_rep -@register_check(custom_derivatives.custom_vjp_call_jaxpr_p) +@register_check(cd.custom_vjp_call_jaxpr_p) def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_): return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep) @@ -1487,7 +1558,6 @@ def _shard_map_batch( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_rep: bool, - rewrite: bool, auto: frozenset) -> Sequence[batching.BatchTracer]: in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) if any(isinstance(d, batching.RaggedAxis) for d in in_dims): @@ -1513,7 +1583,7 @@ def new_out_names_thunk(): new_params = dict(mesh=mesh, in_names=new_in_names, out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + auto=auto) with core.set_current_trace(trace.parent_trace): out_vals = prim.bind(fun, *in_vals, **new_params) make_tracer = partial(batching.BatchTracer, trace, @@ -1536,7 +1606,7 @@ def _batch_out_names(spmd_axis_name, dims, out_names): # Autodiff def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): + out_names_thunk, check_rep, auto): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not ad.Zero for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents] @@ -1551,7 +1621,7 @@ def new_out_names_thunk(): return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + auto=auto) f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) primal_out, tangent_out = tree_unflatten(out_tree(), result) @@ -1562,47 +1632,66 @@ def new_out_names_thunk(): def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): + out_names_thunk, check_rep, auto): tracers = map(trace.to_jaxpr_tracer, tracers) in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) - all_names = _all_newly_manual_mesh_names(mesh, auto, trace) - in_avals_sharded = map(partial(_shard_aval, mesh, auto), unk_in_names, 
in_avals) + in_avals_sharded = map(partial(_shard_aval, mesh, auto, check_rep), + unk_in_names, in_avals) f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) f = _promote_scalar_residuals(f) - f_known, aux = pe.partial_eval_wrapper_nounits( + f_known, aux = pe.partial_eval_wrapper_nounits2( f, (*in_knowns,), (*in_avals_sharded,)) + all_names = _all_newly_manual_mesh_names(mesh, auto, trace) @as_hashable_function(closure=out_names_thunk) def known_out_names(): - in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux() + _, _, out_knowns, res_avals, _, _ = aux() _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) - num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - return (*out_known_names, *({0: all_names},) * num_res) + if check_rep: + res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} + for a in res_avals] + else: + res_names = [{0: all_names}] * len(res_avals) + return (*out_known_names, *res_names) known_params = dict(mesh=mesh, in_names=(*known_in_names,), out_names_thunk=known_out_names, check_rep=check_rep, - rewrite=rewrite, auto=auto) - out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) - in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux() + auto=auto) + out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), + known_params) + in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux() num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) out_consts, non_fwd_res = split_list(out, [len(out) - num_res]) assert not jaxpr.constvars unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk()) known_out_names_ = known_out_names() res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res) - res_names = [known_in_names[f1] if f1 is not None else - known_out_names_[f2] if f2 is not None else - {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] + # TODO make res_avals be the full set, not just the non-fwd ones + res_avals_iter = iter(res_avals) + res_names = [] + for f1, f2 in zip(in_fwd, out_fwd): + if f1 is not None: + res_names.append(known_in_names[f1]) + elif f2 is not None: + res_names.append(known_out_names_[f2]) + else: + if check_rep: + res_vma = next(res_avals_iter).vma + res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) + else: + res_names.append({0: all_names}) unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) # type: ignore[assignment] const_tracers = map(trace.new_instantiated_const, res) env_tracers = map(trace.to_jaxpr_tracer, env) unk_arg_tracers = [t for t in tracers if not t.is_known()] + out_avals_sharded = [v.aval for v in jaxpr.outvars] unk_params = dict(mesh=mesh, in_names=unk_in_names, - out_names=unk_out_names, jaxpr=jaxpr, check_rep=False, - rewrite=rewrite, auto=auto) - out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded) + out_names=unk_out_names, jaxpr=jaxpr, + check_rep=check_rep, auto=auto) + out_avals = map(partial(_unshard_aval, mesh, check_rep), unk_out_names, + out_avals_sharded) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in out_avals] effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) @@ -1615,41 +1704,54 @@ def known_out_names(): def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): + out_names_thunk, check_rep, auto): primals, tangents = 
unzip2(map(trace.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not ad.Zero for t in tangents) f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) - tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz] - res_names = _all_newly_manual_mesh_names(mesh, auto, trace) + all_names = _all_newly_manual_mesh_names(mesh, auto, trace) @as_hashable_function(closure=linearize_outs_thunk) def fwd_out_names_thunk(): - _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + res_avals, _, _, _, _, _ = linearize_outs_thunk() out_names = out_names_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - # This is incorrect so we set `check_rep=False` in the tangent (as in JVP). - return (*({0: res_names} for _ in range(num_res_out)), *out_names) + if check_rep: + res_names = [{0: tuple(i for i in mesh.axis_names if i in a.vma)} + for a in res_avals] + else: + res_names = [{0: all_names}] * len(res_avals) + return (*res_names, *out_names) fwd_params = dict( mesh=mesh, in_names=in_names, - out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, auto=auto) all_fwd_results = shard_map_p.bind_with_trace( trace.parent_trace, (f_primal, *primals), fwd_params) - residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() + res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) non_fwd_res = all_fwd_results[:num_res_out] primals_out = all_fwd_results[num_res_out:] residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None - for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)] - with _extend_axis_env(mesh, auto), use_abstract_mesh(_as_manual_mesh(mesh, auto)): + for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] + with (_extend_axis_env(mesh, auto), + use_abstract_mesh(_as_manual_mesh(mesh, auto)), + config._check_rep(check_rep)): lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) out_names = out_names_thunk() - residual_names = [in_names[f1] if f1 is not None else - out_names[f2] if f2 is not None else - {0: res_names} for f1, f2 in zip(in_fwd, out_fwd)] - new_in_names = (*residual_names, *({} for _ in range(len(env))), + res_avals_iter = iter(res_avals) + res_names = [] + for f1, f2 in zip(in_fwd, out_fwd): + if f1 is not None: + res_names.append(in_names[f1]) + elif f2 is not None: + res_names.append(out_names[f2]) + else: + if check_rep: + res_vma = next(res_avals_iter).vma + res_names.append({0: tuple(n for n in mesh.axis_names if n in res_vma)}) + else: + res_names.append({0: all_names}) + new_in_names = (*res_names, *({} for _ in range(len(env))), *(ax for ax, nz in zip(in_names, nzs_in) if nz)) tangent_out_names = tuple(ax for ax, nz in zip(out_names_thunk(), nzs_out) if nz) @as_hashable_function(closure=tangent_out_names) @@ -1657,7 +1759,7 @@ def tangent_out_names_thunk(): return tangent_out_names tangent_params = dict( mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, - check_rep=False, rewrite=rewrite, auto=auto) + check_rep=check_rep, auto=auto) # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here def f_tangent(*args): @@ -1695,8 +1797,7 @@ def 
_promote_scalar_residuals(f: Callable, *args, **kwargs): for x in out_consts] return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) -def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, - which: Sequence[bool]): +def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]): def fun(*res_and_args): res, args = split_list(res_and_args, [len(jaxpr.constvars)]) res = [_rem_singleton(x) if w else x for x, w in zip(res, which)] @@ -1720,16 +1821,16 @@ def _unmentioned2(mesh: Mesh, names: AxisNames, def _shard_map_transpose(out_cts, *args, jaxpr: core.Jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): + check_rep, auto): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ - ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite or dtypes.dtype(x) == dtypes.float0 + ad.Zero(_shard_aval(mesh, auto, check_rep, ns, x.aval)) + if type(x) is ad.Zero else x if check_rep or dtypes.dtype(x) == dtypes.float0 else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) for ns, x in zip(out_names, out_cts) ] args = tuple(x if type(x) is not ad.UndefinedPrimal else - ad.UndefinedPrimal(_shard_aval(mesh, auto, ns, x.aval)) + ad.UndefinedPrimal(_shard_aval(mesh, auto, check_rep, ns, x.aval)) for ns, x in zip(in_names, args)) all_args, in_tree = tree_flatten((out_cts, args)) @@ -1744,8 +1845,8 @@ def fun_trans_callable(out_cts, args): jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts )[len(res_reshaped):] _, in_ct_names = partition_list(in_undef, in_names) - in_cts = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite + in_cts = [ad.Zero(_unshard_aval(mesh, check_rep, ns, x.aval)) + if type(x) is ad.Zero else x if check_rep else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) for ns, x in zip(in_ct_names, in_cts)] res_zeros = [ad_util.zero_from_primal(r) for r in res] @@ -1765,7 +1866,7 @@ def new_out_names_thunk(): try: out_flat = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, + out_names_thunk=new_out_names_thunk, check_rep=check_rep, auto=auto) except (FloatingPointError, ZeroDivisionError) as e: print("Invalid nan value encountered in the backward pass of a shard_map " @@ -1777,7 +1878,7 @@ def new_out_names_thunk(): _ = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + auto=auto) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None else: @@ -1793,8 +1894,8 @@ def _partial_eval_jaxpr_custom_rule( ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], list[core.Var]]: jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] - auto = eqn.params['auto'] - with _extend_axis_env(mesh, auto): + check_rep, auto = eqn.params['check_rep'], eqn.params['auto'] + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) num_out_primals = len(jaxpr_known.outvars) - num_res @@ -1805,7 +1906,8 @@ def _partial_eval_jaxpr_custom_rule( which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] mesh = eqn.params['mesh'] with (_extend_axis_env(mesh, auto), - use_abstract_mesh(_as_manual_mesh(mesh, auto))): + use_abstract_mesh(_as_manual_mesh(mesh, auto)), + 
config._check_rep(check_rep)): jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) @@ -1815,11 +1917,24 @@ def _partial_eval_jaxpr_custom_rule( _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) newvar = core.gensym() - params_known, params_staged, res_names = _pe_custom_params( - unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which, + residuals, staged_in_res_names = [], [] + for var, w in zip(jaxpr_staged.invars[:num_res], which): + if w: + rn = ({0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} # type: ignore + if check_rep else {0: _all_newly_manual_mesh_names(mesh, auto)}) + residuals.append(newvar(_unshard_aval(mesh, check_rep, rn, var.aval))) + staged_in_res_names.append(rn) + if check_rep: + out_res_names_known = [ + {0: tuple(i for i in mesh.axis_names if i in var.aval.vma)} + for var, o in zip(res_vars, out_fwd) if o is None + ] + else: + out_res_names_known = [{0: _all_newly_manual_mesh_names(mesh, auto)}] * sum(which) + params_known, params_staged = _pe_custom_params( + unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, + out_res_names_known, staged_in_res_names, dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) - residuals = [newvar(_unshard_aval(mesh, {0: res_names}, var.aval)) - for var, w in zip(jaxpr_staged.invars[:num_res], which) if w] eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info, eqn.ctx) @@ -1866,27 +1981,28 @@ def staged(*args): return jaxpr_known, jaxpr_staged def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, - in_fwd, out_fwd, which, params_known, params_staged): + in_fwd, out_fwd, out_res_names_known, staged_in_res_names, + params_known, params_staged): # prune inputs to jaxpr_known according to unks_in - mesh = params_known['mesh'] - auto = params_known['auto'] - res_names_ = _all_newly_manual_mesh_names(mesh, auto) in_names_known, _ = partition_list(unks_in, params_known['in_names']) _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) - out_names_known = out_names_known + [{0: res_names_}] * sum(which) + out_names_known = out_names_known + out_res_names_known + assert len(out_names_known) == len(params_known['jaxpr'].outvars) new_params_known = dict(params_known, in_names=tuple(in_names_known), out_names=tuple(out_names_known)) # added num_res new inputs to jaxpr_staged, pruning according to inst_in _, in_names_staged = partition_list(inst_in, params_staged['in_names']) + iter_staged = iter(staged_in_res_names) res_names = [in_names_known[f1] if f1 is not None else out_names_known[f2] if f2 is not None else - {0: res_names_} for f1, f2 in zip(in_fwd, out_fwd)] + next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)] + in_names_staged = res_names + in_names_staged _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), - out_names=tuple(out_names_staged), check_rep=False) - return new_params_known, new_params_staged, res_names_ + out_names=tuple(out_names_staged)) + return new_params_known, new_params_staged # TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_mesh_names_except_spmd( @@ 
-1904,7 +2020,7 @@ def _all_newly_manual_mesh_names( vmap_spmd_names = set(axis_env.spmd_axis_names) if not (ctx_mesh := get_abstract_mesh()).empty: mesh = ctx_mesh - already_manual_names = set(ctx_mesh._axis_types_dict.get(AxisType.Manual, ())) + already_manual_names = set(ctx_mesh.manual_axes) else: # TODO(mattjj): remove this mechanism when we revise mesh scopes already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names @@ -1921,7 +2037,8 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn return [False] * len(eqn.invars), None mesh = eqn.params["mesh"] auto = eqn.params["auto"] - with _extend_axis_env(mesh, auto): + check_rep = eqn.params["check_rep"] + with _extend_axis_env(mesh, auto), config._check_rep(check_rep): jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: return used_inputs, None @@ -2012,6 +2129,11 @@ def _get_devices(p, backend): return devs[:p.global_axis_size] return devs[:p.local_axis_size] +@lu.transformation2 +def _implicit_pvary_on_output(f, out_names_thunk, *args, **kwargs): + out_flat = f(*args, **kwargs) + return [pvary(o, tuple(_names_to_vma(n) - typeof(o).vma)) + for o, n in zip(out_flat, out_names_thunk())] ### Rewrite! @@ -2073,25 +2195,16 @@ def process_call(self, call_primitive, f, in_tracers, params): return map(partial(RewriteTracer, self), out_reps(), out_vals) def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - if symbolic_zeros: - msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " - "as a temporary workaround pass the check_rep=False argument to " - "shard_map") - raise NotImplementedError(msg) in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) - jvp, out_reps2 = _rewrite_subtrace(jvp, self.tag, self.mesh, in_reps * 2) + jvp, out_reps2 = _rewrite_jvp_subtrace(jvp, self.tag, self.mesh, in_reps * 2) with core.set_current_trace(self.parent_trace): out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) - if not fst: - assert out_reps == out_reps[:len(out_reps) // 2] * 2 - out_reps = out_reps[:len(out_reps) // 2] return map(partial(RewriteTracer, self), out_reps, out_vals) def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, - fwd: lu.WrappedFun, bwd: lu.WrappedFun, - tracers, + fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, out_trees: Callable[[], Sequence[PyTreeDef]], symbolic_zeros: bool): if symbolic_zeros: @@ -2123,7 +2236,7 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): with core.take_current_trace() as parent: tag = core.TraceTag() - t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) + t = RewriteTrace(parent_trace=parent, tag=tag, mesh=mesh) in_tracers = map(partial(RewriteTracer, t), in_reps, args) with core.set_current_trace(t): ans = f(*in_tracers) @@ -2138,7 +2251,7 @@ def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args): out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_ out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_ _check_reps2(mesh, out_reps_dst, out_reps_src) - outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst + outs = [pvary(x, tuple(n for n in src if n not in dst)) if src - dst else x for 
x, src, dst in zip(outs, out_reps_src, out_reps_dst)] return outs @@ -2169,8 +2282,8 @@ def _replication_rewrite_nomatch( return core.ClosedJaxpr(jaxpr_, consts), out_rep() @lu.transformation_with_aux2 -def _rewrite_subtrace(f: Callable, store: lu.Store, - tag: core.TraceTag, mesh: Mesh, in_reps, *in_vals): +def _rewrite_subtrace(f: Callable, store: lu.Store, tag: core.TraceTag, + mesh: Mesh, in_reps, *in_vals): with core.take_current_trace() as parent_trace: assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) t = RewriteTrace(parent_trace, tag, mesh) @@ -2181,6 +2294,31 @@ def _rewrite_subtrace(f: Callable, store: lu.Store, store.store(out_reps) return out_vals +@lu.transformation_with_aux2 +def _rewrite_jvp_subtrace(f: Callable, store: lu.Store, tag: core.TraceTag, + mesh: Mesh, in_reps, *in_vals): + with core.take_current_trace() as parent_trace: + assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) + t = RewriteTrace(parent_trace, tag, mesh) + in_tracers = [x if type(x) is cd.SymbolicZero else RewriteTracer(t, r, x) + for r, x in zip(in_reps, in_vals)] + with core.set_current_trace(t): + out_tracers: list[RewriteTracer | cd.SymbolicZero] = f(*in_tracers) + out_vals, out_reps = unzip2(map(t.to_val_rep_pair, out_tracers)) + out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) + out_primal_reps, out_tangent_reps = split_list(out_reps, [len(out_vals) // 2]) + out_reps = map(_merge_reps, out_primal_reps, out_tangent_reps, out_tangents) + out_tangents = map(_match_replication, out_tangent_reps, out_reps, out_tangents) + store.store(out_reps) + return out_primals + out_tangents + +def _merge_reps(primal_rep, tangent_rep, error_message_val): + if primal_rep - tangent_rep: + raise ValueError("custom_jvp primal output is more replicated than its " + "corresponding tangent of type " + f"{core.typeof(error_message_val).str_short()}") + return primal_rep + def _rewrite_bwd(bwd: lu.WrappedFun, mesh: Mesh, in_reps, reps_dst) -> lu.WrappedFun: def new_bwd(*args): @@ -2192,18 +2330,8 @@ def new_bwd(*args): def _match_replication(src, dst, x): if dst - src: - x, = psum2_p.bind(x, axes=tuple(n for n in dst if n not in src), - axis_index_groups=None) + x, = lax_parallel.psum_invariant_p.bind( + x, axes=tuple(n for n in dst if n not in src), axis_index_groups=None) if src - dst: - x = pbroadcast(x, tuple(n for n in src if n not in dst)) + x = pvary(x, tuple(n for n in src if n not in dst)) return x - -# TODO(parkers,mattjj): change implementation when we have sharding-in-types. 
-def get_replication(x: jax.Array) -> set[AxisName]: - """For a jax.Array, return what axes it is known to be replicated along.""" - - if isinstance(x, RewriteTracer): - return x.rep - if isinstance(x, batching.BatchTracer): - return get_replication(x.val) - raise ValueError("get_replication not defined on %s" % repr(type(x))) diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index 7739af0291f1..36d84cb0db62 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -19,18 +19,8 @@ import jax from jax._src import core -from jax._src import ffi from jax._src import util from jax._src.typing import Array -from jax._src.lib import gpu_sparse - - -if hasattr(gpu_sparse, "registrations"): - for platform, targets in gpu_sparse.registrations().items(): - for name, value, api_version in targets: - ffi.register_ffi_target( - name, value, platform=platform, api_version=api_version - ) class JAXSparse(util.StrictABC): diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py index 6962ef78bcff..76e74d13ed69 100644 --- a/jax/experimental/sparse/_lowerings.py +++ b/jax/experimental/sparse/_lowerings.py @@ -18,13 +18,29 @@ """ from functools import partial +from typing import Any from jax._src import core from jax._src import dispatch +from jax._src import ffi from jax._src.interpreters import mlir from jax._src.lib import gpu_sparse import numpy as np +if hasattr(gpu_sparse, "registrations"): + for platform, targets in gpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + +def _get_module(target_name_prefix: str) -> Any: + if target_name_prefix == "cu": + return gpu_sparse._cusparse + elif target_name_prefix == "hip": + return gpu_sparse._hipsparse + else: + raise ValueError(f"Unsupported target_name_prefix: {target_name_prefix}") SUPPORTED_DATA_DTYPES = [np.float32, np.float64, np.complex64, np.complex128] SUPPORTED_INDEX_DTYPES = [np.int32] @@ -54,27 +70,30 @@ def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _coo_spmv_gpu_lowering(coo_spmv_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmv_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_matvec_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype, + rows, cols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matvec_ffi") + return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1] coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval) dispatch.simple_impl(coo_spmv_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.cuda_coo_matvec), + partial(_coo_spmv_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.rocm_coo_matvec), + 
partial(_coo_spmv_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -103,27 +122,51 @@ def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _coo_spmm_gpu_lowering(coo_spmm_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmm_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + _, Ccols = x_aval.shape + + batch_count = 1 + if len(shape) == 2: + rows, cols = shape + elif len(shape) == 3: + batch_count, rows, cols = shape + nnz = nnz // batch_count + else: + raise NotImplementedError(f"Unsupported shape: {shape}") + + # TODO(tianjianlu): use batch stride to trigger different mode of batch + # computation. Currently batch_stride = 0 is not allowed because of the issue + # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 + # Set batch stride to be the matrix size for now. + lhs_batch_stride = nnz + B_rows = rows if transpose else cols + rhs_batch_stride = B_rows * Ccols + + buffer_size, opaque = _get_module(target_name_prefix).build_coo_matmat_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype, + rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride, + rhs_batch_stride) + + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matmat_ffi") + return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1] + coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval) dispatch.simple_impl(coo_spmm_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_spmm_p, - partial(_coo_spmm_gpu_lowering, gpu_sparse.cuda_coo_matmat), + partial(_coo_spmm_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_spmm_p, - partial(_coo_spmm_gpu_lowering, gpu_sparse.rocm_coo_matmat), + partial(_coo_spmm_gpu_lowering, target_name_prefix='hip'), platform='rocm') # csr_spmv_p @@ -151,30 +194,33 @@ def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _csr_spmv_gpu_lowering(csr_spmv_hlo, ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, indices_aval, _, x_aval = ctx.avals_in - return [csr_spmv_hlo( - data, indices, indptr, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=indices_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_matvec_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype, + rows, cols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matvec_ffi") + return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1] csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval) dispatch.simple_impl(csr_spmv_p) if gpu_sparse.cuda_is_supported: 
mlir.register_lowering( csr_spmv_p, - partial(_csr_spmv_gpu_lowering, gpu_sparse.cuda_csr_matvec), + partial(_csr_spmv_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_spmv_p, - partial(_csr_spmv_gpu_lowering, gpu_sparse.rocm_csr_matvec), + partial(_csr_spmv_gpu_lowering, target_name_prefix='hip'), platform='rocm') - # csr_spmm_p +# csr_spmm_p # This is an internal-only primitive that calls into cusparse CSR SpMM. # This is a raw lowering that does no validation of inputs; the indices are # assumed to be lexicographically sorted, deduplicated, and in-bounds. @@ -199,25 +245,71 @@ def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _csr_spmm_gpu_lowering(csr_spmm_hlo, ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, indices_aval, _, x_aval = ctx.avals_in - return [csr_spmm_hlo( - data, indices, indptr, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=indices_aval.dtype, - B_dtype=x_aval.dtype)] + nnz, = data_aval.shape + _, Ccols = x_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_matmat_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype, + rows, cols, Ccols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matmat_ffi") + return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1] csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval) dispatch.simple_impl(csr_spmm_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_spmm_p, - partial(_csr_spmm_gpu_lowering, gpu_sparse.cuda_csr_matmat), + partial(_csr_spmm_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_spmm_p, - partial(_csr_spmm_gpu_lowering, gpu_sparse.rocm_csr_matmat), + partial(_csr_spmm_gpu_lowering, target_name_prefix='hip'), platform='rocm') + +def coo_todense_gpu_lowering(ctx, data, row, col, *, shape, target_name_prefix): + data_aval, row_aval, _ = ctx.avals_in + nnz, = data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_todense_descriptor( + data_aval.dtype, row_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_todense_ffi") + return rule(sub_ctx, data, row, col, opaque=opaque)[0] + +def coo_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] + +def csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): + data_aval, indices_aval, _, = ctx.avals_in + nnz, = 
data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_todense_descriptor( + data_aval.dtype, indices_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_todense_ffi") + return rule(sub_ctx, data, indices, indptr, opaque=opaque)[0] + +def csr_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py index 018047e3d5e1..861ef5289cdd 100644 --- a/jax/experimental/sparse/ad.py +++ b/jax/experimental/sparse/ad.py @@ -22,7 +22,7 @@ from jax._src import core from jax import tree_util from jax._src.api_util import _ensure_index, _ensure_index_tuple -from jax.util import safe_zip +from jax._src.util import safe_zip from jax._src.util import split_list, wraps from jax._src.traceback_util import api_boundary from jax.experimental.sparse._base import JAXSparse diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 42820fe73651..0365f93d551a 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -38,7 +38,7 @@ from jax.experimental.sparse._lowerings import coo_spmv_p, coo_spmm_p from jax._src.interpreters import mlir import jax.numpy as jnp -from jax.util import safe_zip, unzip2, split_list +from jax._src.util import safe_zip, unzip2, split_list from jax._src import api_util from jax._src import config from jax._src import core diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7fefd1572f45..4b01f362bb83 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -27,12 +27,13 @@ import jax.numpy as jnp from jax import lax from jax import tree_util +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse import bcoo from jax.experimental.sparse.util import ( nfold_vmap, _count_stored_elements, _csr_to_coo, CuSparseEfficiencyWarning, SparseInfo, Shape) -from jax.util import split_list, safe_zip +from jax._src.util import split_list, safe_zip from jax._src import api_util from jax._src import config @@ -620,9 +621,9 @@ def _bcsr_correct_out_of_bound_indices(data, indices, indptr, rhs, *, shape): _bcsr_correct_out_of_bound_indices, multiple_results=True) def _bcsr_dot_general_gpu_lowering( - csr_matvec_lowering, csr_matmat_lowering, + # csr_matvec_lowering, csr_matmat_lowering, ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers, - preferred_element_type, lhs_spinfo: SparseInfo): + preferred_element_type, lhs_spinfo: SparseInfo, target_name_prefix): if not config.bcoo_cusparse_lowering.value: return _bcsr_dot_general_default_lowering( @@ -674,22 +675,23 @@ def _bcsr_dot_general_gpu_lowering( lhs_data, lhs_indices = _bcsr_correct_out_of_bound_indices_lowered( ctx, lhs_data, lhs_indices, lhs_indptr, rhs, shape=lhs_spinfo.shape) + sub_ctx = ctx 
if rhs_aval.ndim == 1: - dot_general_fn = csr_matvec_lowering - x_dtype = 'x_dtype' + dot_general_fn = _lowerings._csr_spmv_gpu_lowering elif rhs_aval.ndim == 2: - dot_general_fn = csr_matmat_lowering - x_dtype = 'B_dtype' + dot_general_fn = _lowerings._csr_spmm_gpu_lowering if rhs_contract[0] == 1: rhs = hlo.transpose(rhs, permutation=mlir.dense_int_array([1, 0])) + *avals_in, rhs_aval = sub_ctx.avals_in + rhs_aval = core.ShapedArray( + shape=(rhs_aval.shape[1], rhs_aval.shape[0]), dtype=rhs_aval.dtype) + sub_ctx = sub_ctx.replace(avals_in=[*avals_in, rhs_aval]) else: raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.") - return [dot_general_fn(lhs_data, lhs_indices, lhs_indptr, rhs, - shape=lhs_spinfo.shape, transpose=False, - data_dtype=lhs_data_aval.dtype, - index_dtype=lhs_indices_aval.dtype, - **{x_dtype: rhs_aval.dtype})] + return dot_general_fn(sub_ctx, lhs_data, lhs_indices, lhs_indptr, rhs, + shape=lhs_spinfo.shape, transpose=False, + target_name_prefix=target_name_prefix) _bcsr_dot_general_default_lowering = mlir.lower_fun( _bcsr_dot_general_impl, multiple_results=False) @@ -700,14 +702,12 @@ def _bcsr_dot_general_gpu_lowering( if gpu_sparse.cuda_is_supported: mlir.register_lowering(bcsr_dot_general_p, partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.cuda_csr_matvec, - gpu_sparse.cuda_csr_matmat), + target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering(bcsr_dot_general_p, partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.rocm_csr_matvec, - gpu_sparse.rocm_csr_matmat), + target_name_prefix='hip'), platform='rocm') diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index c65bc87235d6..014fe9128c1b 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -26,6 +26,7 @@ import jax from jax import lax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning from jax import tree_util @@ -205,7 +206,7 @@ def _coo_todense_abstract_eval(data, row, col, *, spinfo): _coo_todense_lowering = mlir.lower_fun( _coo_todense_impl, multiple_results=False) -def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): +def _coo_todense_gpu_lowering(ctx, data, row, col, *, spinfo, target_name_prefix): data_aval, row_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): @@ -226,8 +227,13 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): "back to the default implementation.", CuSparseEfficiencyWarning) return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo) - result = coo_todense_hlo( - data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype) + sub_ctx = ctx + if transpose: + out_aval, = ctx.avals_out + out_aval = core.ShapedArray(shape=out_aval.shape[::-1], dtype=out_aval.dtype) + sub_ctx = sub_ctx.replace(avals_out=[out_aval]) + result = _lowerings.coo_todense_gpu_lowering( + sub_ctx, data, row, col, shape=shape, target_name_prefix=target_name_prefix) return ( [hlo.transpose(result, mlir.dense_int_array([1, 0]))] if transpose else [result]) @@ -255,12 +261,12 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense), + 
partial(_coo_todense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense), + partial(_coo_todense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -325,20 +331,15 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype): _coo_fromdense_lowering = mlir.lower_fun( _coo_fromdense_impl, multiple_results=True) -def _coo_fromdense_gpu_lowering(coo_fromdense_hlo, ctx, mat, *, nse, - index_dtype): +def _coo_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, row, col = coo_fromdense_hlo( - mat, nnz=nse, - data_dtype=dtype, - index_dtype=np.dtype(index_dtype), - index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, row, col] - + return _lowerings.coo_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype): M, = primals @@ -373,12 +374,12 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense), + partial(_coo_fromdense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense), + partial(_coo_fromdense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -444,8 +445,8 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose): _coo_matvec_lowering = mlir.lower_fun( _coo_matvec_impl, multiple_results=False) -def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, - transpose): +def _coo_matvec_gpu_lowering(ctx, data, row, col, v, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -466,9 +467,9 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose) - return [coo_matvec_hlo( - data, row, col, v, shape=shape, transpose=transpose, - index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)] + return _lowerings._coo_spmv_gpu_lowering( + ctx, data, row, col, v, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose): @@ -497,12 +498,12 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_matvec_p, - partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec), + partial(_coo_matvec_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_matvec_p, - 
partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec), + partial(_coo_matvec_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -567,8 +568,8 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose): _coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False) -def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, - transpose): +def _coo_matmat_gpu_lowering(ctx, data, row, col, B, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,10 +590,9 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose) - return [coo_matmat_hlo(data, row, col, B, shape=shape, - transpose=transpose, x_dtype=B_aval.dtype, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype)] + return _lowerings._coo_spmm_gpu_lowering( + ctx, data, row, col, B, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose): @@ -618,10 +618,10 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat), + partial(_coo_matmat_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat), + partial(_coo_matmat_gpu_lowering, target_name_prefix='hip'), platform='rocm') diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 84171855b85e..cbc5bad1100b 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -23,6 +23,7 @@ import jax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning @@ -249,17 +250,16 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape): _csr_todense_lowering = mlir.lower_fun( _csr_todense_impl, multiple_results=False) -def _csr_todense_gpu_lowering(csr_todense_hlo, ctx, data, indices, indptr, *, - shape): +def _csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): data_aval, indices_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for {dtype=}. 
" "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape) - return [csr_todense_hlo( - data, indices, indptr, shape=shape, data_dtype=dtype, - index_dtype=indices_aval.dtype)] + return [_lowerings.csr_todense_gpu_lowering( + ctx, data, indices, indptr, shape=shape, + target_name_prefix=target_name_prefix)] def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape): @@ -284,12 +284,12 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense), + partial(_csr_todense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense), + partial(_csr_todense_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -359,16 +359,16 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype): _csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl, multiple_results=True) -def _csr_fromdense_gpu_lowering(csr_fromdense_hlo, ctx, mat, *, nse, index_dtype): +def _csr_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, + target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, indices, indptr = csr_fromdense_hlo( - mat, nnz=nse, index_dtype=np.dtype(index_dtype), - data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, indices, indptr] + return _lowerings.csr_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype): @@ -404,12 +404,12 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense), + partial(_csr_fromdense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense), + partial(_csr_fromdense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -470,8 +470,8 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose): _csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False) -def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, - shape, transpose): +def _csr_matvec_gpu_lowering(ctx, data, indices, indptr, v, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, v_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -479,10 +479,9 @@ def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape, transpose=transpose) - return [csr_matvec_hlo( - data, indices, indptr, v, shape=shape, 
transpose=transpose, - data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)] - + return _lowerings._csr_spmv_gpu_lowering( + ctx, data, indices, indptr, v, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose): return _csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose) @@ -511,12 +510,12 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec), + partial(_csr_matvec_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec), + partial(_csr_matvec_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -580,8 +579,8 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose): _csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False) -def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, - shape, transpose): +def _csr_matmat_gpu_lowering(ctx, data, indices, indptr, B, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,11 +588,9 @@ def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape, transpose=transpose) - return [csr_matmat_hlo( - data, indices, indptr, B, shape=shape, transpose=transpose, - index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype, - B_dtype=B_aval.dtype)] - + return _lowerings._csr_spmm_gpu_lowering( + ctx, data, indices, indptr, B, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose): return _csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose) @@ -621,10 +618,10 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.cuda_csr_matmat), + partial(_csr_matmat_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.rocm_csr_matmat), + partial(_csr_matmat_gpu_lowering, target_name_prefix='hip'), platform='rocm') diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index a931b0a30dcf..b2e57caba9a6 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -29,7 +29,6 @@ from jax._src import core from jax._src import ffi from jax._src.interpreters import ad -from jax._src.lib import gpu_solver import numpy as np from scipy.sparse import csr_matrix, linalg @@ -534,11 +533,6 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder): def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder): - # TODO(danfm): remove after JAX 0.5.1 release. 
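
The coo/csr hunks above all make the same change: each GPU lowering rule now takes a target_name_prefix ('cu' for cuSPARSE, 'hip' for hipSPARSE) and defers to the shared helpers in jax.experimental.sparse._lowerings, instead of closing over a platform-specific *_hlo callable. A minimal sketch of that registration shape, using a placeholder primitive and rule rather than the real sparse ones:

from functools import partial

from jax.extend.core import Primitive
from jax.interpreters import mlir

# Placeholder primitive and rule; the real ones are coo_matvec_p, csr_matmat_p,
# etc., whose bodies dispatch to the shared _lowerings helpers.
_demo_p = Primitive("demo_sparse_op")

def _demo_gpu_lowering(ctx, *args, target_name_prefix):
  # A real rule would emit a custom call whose target name starts with
  # target_name_prefix ('cu' or 'hip').
  raise NotImplementedError(target_name_prefix)

for _platform, _prefix in (("cuda", "cu"), ("rocm", "hip")):
  mlir.register_lowering(
      _demo_p,
      partial(_demo_gpu_lowering, target_name_prefix=_prefix),
      platform=_platform)

Binding the prefix with functools.partial at registration time is what lets a single rule serve both backends, which is why the per-platform cuda_*/rocm_* wrapper callables can be dropped.
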
- if hasattr(gpu_solver, "cuda_csrlsvqr"): - data_aval, _, _, _, = ctx.avals_in - return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices, - indptr, b, tol, reorder) return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")( ctx, data, indices, indptr, b, tol=np.float64(tol), reorder=np.int32(reorder)) diff --git a/jax/experimental/sparse/random.py b/jax/experimental/sparse/random.py index f90c2572d282..a9146b7746e0 100644 --- a/jax/experimental/sparse/random.py +++ b/jax/experimental/sparse/random.py @@ -18,7 +18,7 @@ from jax import dtypes from jax import vmap from jax import random -from jax.util import split_list +from jax._src.util import split_list import jax.numpy as jnp from jax.experimental import sparse diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 77c97513041c..63e035d2d1ac 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -29,7 +29,7 @@ from jax._src.typing import DTypeLike from jax.experimental import sparse import jax.numpy as jnp -from jax.util import safe_zip, split_list +from jax._src.util import safe_zip, split_list import numpy as np MATMUL_TOL = { diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index ce1d3f4af9d0..a16756d42c45 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -68,7 +68,7 @@ from jax._src.lib import pytree from jax._src.interpreters import partial_eval as pe from jax.tree_util import tree_flatten, tree_map, tree_unflatten -from jax.util import safe_map, safe_zip, split_list +from jax._src.util import safe_map, safe_zip, split_list from jax._src.lax.control_flow import _check_tree_and_avals from jax._src.numpy import indexing as jnp_indexing from jax.experimental import sparse diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 36e9a9c51664..7c6bfb1ec345 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -25,7 +25,7 @@ from jax._src import core from jax._src.api_util import flatten_axes import jax.numpy as jnp -from jax.util import safe_zip +from jax._src.util import safe_zip from jax._src.lax.lax import _dot_general_shape_rule, DotDimensionNumbers from jax._src.typing import Array diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index bbb5925ab41a..c875abb9c598 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -16,24 +16,24 @@ The :mod:`jax.extend` module provides modules for access to JAX internal machinery. See -`JEP #15856 `_. +`JEP #15856 `_. This module is not the only means by which JAX aims to be extensible. For example, the main JAX API offers mechanisms for `customizing derivatives -`_, +`_, `registering custom pytree definitions -`_, +`_, and more. API policy ---------- Unlike the -`public API `_, +`public API `_, this module offers **no compatibility guarantee** across releases. Breaking changes will be announced via the -`JAX project changelog `_. +`JAX project changelog `_. 
""" from jax.extend import ( diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index d8a10154cf4a..60d8cd24a949 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -149,7 +149,6 @@ igamma_p as igamma_p, lgamma_p as lgamma_p, polygamma_p as polygamma_p, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta_p as zeta_p, ) diff --git a/jax/extend/mlir/dialects/sdy.py b/jax/extend/mlir/dialects/sdy.py index 48586cc26760..d83fd90ecdf4 100644 --- a/jax/extend/mlir/dialects/sdy.py +++ b/jax/extend/mlir/dialects/sdy.py @@ -14,8 +14,4 @@ # ruff: noqa: F403 -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects.sdy import * -except ImportError: - pass +from jaxlib.mlir.dialects.sdy import * diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 0f32799f7ea9..10c8d1e9e671 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -33,7 +33,7 @@ aval_to_ir_type as aval_to_ir_type, aval_to_ir_types as aval_to_ir_types, core_call_lowering as core_call_lowering, - custom_call as custom_call, + custom_call as _custom_call, dense_bool_elements as dense_bool_elements, dense_bool_array as dense_bool_array, dense_int_array as dense_int_array, @@ -43,8 +43,6 @@ flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401 flatten_ir_values as flatten_ir_values, unflatten_ir_values_like_types as unflatten_ir_values_like_types, - func_dialect as func_dialect, - hlo as hlo, i32_attr as i32_attr, i64_attr as i64_attr, ir as ir, @@ -63,7 +61,6 @@ register_lowering as register_lowering, shape_tensor as shape_tensor, token_type as token_type, - xla_computation_to_mlir_module as xla_computation_to_mlir_module, ) from jax._src.mesh import Mesh as Mesh @@ -80,3 +77,23 @@ from jax._src.callback import ( emit_python_callback as emit_python_callback, ) + +_deprecations = { + # Added Apr 7 2025 + "custom_call": ( + "mlir.custom_call is deprecated; use the APIs provided by jax.ffi instead.", + _custom_call, + ) +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + custom_call = _custom_call +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing +del _custom_call diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index b546d774a2e9..a2d988f6bea3 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -81,7 +81,6 @@ trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, tracers_to_jaxpr as tracers_to_jaxpr, - trivial_ctx as trivial_ctx, ) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index bd3b83e37d24..2f8417ade1f8 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -38,19 +38,6 @@ "jax.interpreters.xla.pytype_aval_mappings is deprecated.", _src_core.pytype_aval_mappings ), - # Finalized 2024-10-24; remove after 2025-01-24 - "xb": ( - ("jax.interpreters.xla.xb was removed in JAX v0.4.36. " - "Use jax.lib.xla_bridge instead."), None - ), - "xc": ( - ("jax.interpreters.xla.xc was removed in JAX v0.4.36. " - "Use jax.lib.xla_client instead."), None - ), - "xe": ( - ("jax.interpreters.xla.xe was removed in JAX v0.4.36. 
" - "Use jax.lib.xla_extension instead."), None - ), } import typing as _typing diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 4e376fb666d1..e8ec74f59a7d 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -198,6 +198,7 @@ select as select, select_n as select_n, select_n_p as select_n_p, + shape_as_value as shape_as_value, shift_left as shift_left, shift_left_p as shift_left_p, shift_right_arithmetic as shift_right_arithmetic, @@ -260,7 +261,6 @@ polygamma as polygamma, polygamma_p as polygamma_p, random_gamma_grad as random_gamma_grad, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta as zeta, zeta_p as zeta_p, @@ -377,6 +377,9 @@ ragged_all_to_all as ragged_all_to_all, ragged_all_to_all_p as ragged_all_to_all_p, ) +from jax._src.core import ( + pvary as pvary, +) from jax._src.lax.other import ( conv_general_dilated_local as conv_general_dilated_local, conv_general_dilated_patches as conv_general_dilated_patches @@ -392,3 +395,50 @@ from jax._src.pjit import with_sharding_constraint as with_sharding_constraint from jax._src.pjit import sharding_constraint_p as sharding_constraint_p from jax._src.dispatch import device_put_p as device_put_p + +import jax._src.lax.lax + +_deprecations = { + "infeed": ( + ( + "jax.lax.infeed was deprecated in JAX v0.6.0 and will be removed in" + " JAX v0.7.0." + ), + jax._src.lax.lax.infeed, + ), + "infeed_p": ( + ( + "jax.lax.infeed_p was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.lax.lax.infeed_p, + ), + "outfeed": ( + ( + "jax.lax.outfeed was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.lax.lax.outfeed, + ), + "outfeed_p": ( + ( + "jax.lax.outfeed_p was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.lax.lax.outfeed_p, + ), +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + infeed = jax._src.lax.lax.infeed + infeed_p = jax._src.lax.lax.infeed_p + outfeed = jax._src.lax.lax.outfeed + outfeed_p = jax._src.lax.lax.outfeed_p +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index b158d9b1ff51..95598c447262 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -27,15 +27,6 @@ "jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.", _deprecated_get_backend ), - # Finalized 2024-12-11; remove after 2025-3-11 - "xla_client": ( - "jax.lib.xla_bridge.xla_client was removed in JAX v0.4.38; use jax.lib.xla_client directly.", - None - ), - "default_backend": ( - "jax.lib.xla_bridge.default_backend was removed in JAX v0.4.38; use jax.default_backend.", - None - ), } import typing as _typing diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 86e7307c804b..cc4fa78eb576 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -12,116 +12,160 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from jax._src.lax.fft import FftType as _FftType from jax._src.lib import xla_client as _xc -get_topology_for_devices = _xc.get_topology_for_devices -heap_profile = _xc.heap_profile -mlir_api_version = _xc.mlir_api_version -Client = _xc.Client -CompileOptions = _xc.CompileOptions -DeviceAssignment = _xc.DeviceAssignment -Frame = _xc.Frame -HloSharding = _xc.HloSharding -OpSharding = _xc.OpSharding -Traceback = _xc.Traceback - _deprecations = { - # Finalized 2024-12-11; remove after 2025-3-11 - "_xla": ( - "jax.lib.xla_client._xla was removed in JAX v0.4.38; use jax.lib.xla_extension.", - None, - ), - "bfloat16": ( - "jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.", - None, - ), - # Finalized 2024-12-23; remove after 2024-03-23 - "Device": ( - "jax.lib.xla_client.Device is deprecated; use jax.Device instead.", - None, - ), - "XlaRuntimeError": ( + # Finalized 2025-03-25; remove after 2025-06-25 + "FftType": ( ( - "jax.lib.xla_client.XlaRuntimeError is deprecated; use" - " jax.errors.JaxRuntimeError." + "jax.lib.xla_client.FftType was removed in JAX v0.6.0; use" + " jax.lax.FftType." ), None, ), - # Added Oct 10 2024 - "FftType": ( - "jax.lib.xla_client.FftType is deprecated; use jax.lax.FftType.", - _FftType, - ), "PaddingType": ( ( - "jax.lib.xla_client.PaddingType is deprecated; this type is unused" - " by JAX so there is no replacement." + "jax.lib.xla_client.PaddingType was removed in JAX v0.6.0;" + " this type is unused by JAX so there is no replacement." ), - _xc.PaddingType, + None, ), - # Added Oct 11 2024 "dtype_to_etype": ( - "dtype_to_etype is deprecated; use StableHLO instead.", - _xc.dtype_to_etype, + "dtype_to_etype was removed in JAX v0.6.0; use StableHLO instead.", + None, + ), + "shape_from_pyval": ( + "shape_from_pyval was removed in JAX v0.6.0; use StableHLO instead.", + None, ), + # Added Oct 11 2024, finalized 2025-04-09 "ops": ( - "ops is deprecated; use StableHLO instead.", - _xc.ops, + "ops has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), "register_custom_call_target": ( - "register_custom_call_target is deprecated; use the JAX FFI instead " - "(https://jax.readthedocs.io/en/latest/ffi.html)", - _xc.register_custom_call_target, - ), - "shape_from_pyval": ( - "shape_from_pyval is deprecated; use StableHLO instead.", - _xc.shape_from_pyval, + ( + "register_custom_call_target has been removed in JAX v0.6.0; use" + " the JAX FFI instead (https://docs.jax.dev/en/latest/ffi.html)" + ), + None, ), "PrimitiveType": ( - "PrimitiveType is deprecated; use StableHLO instead.", - _xc.PrimitiveType, + "PrimitiveType has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), "Shape": ( - "Shape is deprecated; use StableHLO instead.", - _xc.Shape, + "Shape has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), "XlaBuilder": ( - "XlaBuilder is deprecated; use StableHLO instead.", - _xc.XlaBuilder, + "XlaBuilder has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), "XlaComputation": ( - "XlaComputation is deprecated; use StableHLO instead.", - _xc.XlaComputation, + "XlaComputation has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), - # Added Nov 20 2024 + # Added Nov 20 2024, finalized 2025-04-09 "ArrayImpl": ( - "jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.", - _xc.ArrayImpl, + ( + "jax.lib.xla_client.ArrayImpl has been removed in JAX v0.6.0; use" + " jax.Array instead." + ), + None, + ), + # Added April 4 2025. 
+ "get_topology_for_devices": ( + ( + "jax.lib.xla_client.get_topology_for_devices was deprecated in JAX" + " v0.6.0 and will be removed in JAX v0.7.0" + ), + _xc.get_topology_for_devices, + ), + "heap_profile": ( + ( + "jax.lib.xla_client.heap_profile was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.heap_profile, + ), + "mlir_api_version": ( + ( + "jax.lib.xla_client.mlir_api_version was deprecated in JAX v0.6.0" + " and will be removed in JAX v0.7.0" + ), + _xc.mlir_api_version, + ), + "Client": ( + ( + "jax.lib.xla_client.Client was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0" + ), + _xc.Client, + ), + "CompileOptions": ( + ( + "jax.lib.xla_client.CompileOptions was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.CompileOptions, + ), + "DeviceAssignment": ( + ( + "jax.lib.xla_client.DeviceAssignment was deprecated in JAX v0.6.0" + " and will be removed in JAX v0.7.0" + ), + _xc.DeviceAssignment, + ), + "Frame": ( + ( + "jax.lib.xla_client.Frame was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0" + ), + _xc.Frame, + ), + "HloSharding": ( + ( + "jax.lib.xla_client.HloSharding was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.HloSharding, + ), + "OpSharding": ( + ( + "jax.lib.xla_client.OpSharding was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.OpSharding, + ), + "Traceback": ( + ( + "jax.lib.xla_client.Traceback was deprecated in JAX v0.6.0 and will" + " be removed in JAX v0.7.0" + ), + _xc.Traceback, ), } import typing as _typing if _typing.TYPE_CHECKING: - dtype_to_etype = _xc.dtype_to_etype - ops = _xc.ops - register_custom_call_target = _xc.register_custom_call_target - shape_from_pyval = _xc.shape_from_pyval - ArrayImpl = _xc.ArrayImpl - Device = _xc.Device - FftType = _FftType - PaddingType = _xc.PaddingType - PrimitiveType = _xc.PrimitiveType Shape = _xc.Shape - XlaBuilder = _xc.XlaBuilder XlaComputation = _xc.XlaComputation - XlaRuntimeError = _xc.XlaRuntimeError + get_topology_for_devices = _xc.get_topology_for_devices + heap_profile = _xc.heap_profile + mlir_api_version = _xc.mlir_api_version + Client = _xc.Client + CompileOptions = _xc.CompileOptions + DeviceAssignment = _xc.DeviceAssignment + Frame = _xc.Frame + HloSharding = _xc.HloSharding + OpSharding = _xc.OpSharding + Traceback = _xc.Traceback else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing -del _FftType del _xc diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 52fe94e231d1..8f1b27070e98 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -14,42 +14,122 @@ from jax._src.lib import xla_extension as _xe -get_distributed_runtime_client = _xe.get_distributed_runtime_client -get_distributed_runtime_service = _xe.get_distributed_runtime_service -hlo_module_cost_analysis = _xe.hlo_module_cost_analysis -hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph -ifrt_proxy = _xe.ifrt_proxy -jax_jit = _xe.jax_jit -mlir = _xe.mlir -pmap_lib = _xe.pmap_lib -profiler = _xe.profiler -pytree = _xe.pytree -Device = _xe.Device -DistributedRuntimeClient = _xe.DistributedRuntimeClient -HloModule = _xe.HloModule -HloPrintOptions = _xe.HloPrintOptions -OpSharding = _xe.OpSharding -PjitFunctionCache = _xe.PjitFunctionCache -PjitFunction = _xe.PjitFunction -PmapFunction = _xe.PmapFunction - _deprecations = { - # Added Nov 
20 2024 "ArrayImpl": ( - "jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.", - _xe.ArrayImpl, + ( + "jax.lib.xla_extension.ArrayImpl has been removed; use jax.Array" + " instead." + ), + None, ), "XlaRuntimeError": ( - "jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.", - _xe.XlaRuntimeError, + ( + "jax.lib.xla_extension.XlaRuntimeError has been removed; use" + " jax.errors.JaxRuntimeError instead." + ), + None, + ), + # Deprecated March 26 2025. + "DistributedRuntimeClient": ( + ( + "jax.lib.xla_extension.DistributedRuntimeClient is" + " deprecated; use jax.distributed instead." + ), + _xe.DistributedRuntimeClient, + ), + "get_distributed_runtime_client": ( + ( + "jax.lib.xla_extension.get_distributed_runtime_client is" + " deprecated; use jax.distributed instead." + ), + _xe.get_distributed_runtime_client, + ), + "get_distributed_runtime_service": ( + ( + "jax.lib.xla_extension.get_distributed_runtime_service is" + " deprecated; use jax.distributed instead." + ), + _xe.get_distributed_runtime_service, + ), + "Device": ( + "jax.lib.xla_extension.Device is deprecated; use jax.Device instead.", + _xe.Device, + ), + "PjitFunctionCache": ( + "jax.lib.xla_extension.PjitFunctionCache is deprecated.", + _xe.PjitFunctionCache, + ), + "ifrt_proxy": ( + "jax.lib.xla_extension.ifrt_proxy is deprecated.", + _xe.ifrt_proxy, + ), + "jax_jit": ( + "jax.lib.xla_extension.jax_jit is deprecated.", + _xe.jax_jit, + ), + "mlir": ("jax.lib.xla_extension.mlir is deprecated.", _xe.mlir), + "pmap_lib": ("jax.lib.xla_extension.pmap_lib is deprecated.", _xe.pmap_lib), + "profiler": ( + "jax.lib.xla_extension.profiler is deprecated.", + _xe.profiler, + ), + "pytree": ( + "jax.lib.xla_extension.pytree is deprecated.", + _xe.pytree, + ), + "hlo_module_cost_analysis": ( + "jax.lib.xla_extension.hlo_module_cost_analysis is deprecated.", + _xe.hlo_module_cost_analysis, + ), + "hlo_module_to_dot_graph": ( + "jax.lib.xla_extension.hlo_module_to_dot_graph is deprecated.", + _xe.hlo_module_to_dot_graph, + ), + "HloModule": ( + "jax.lib.xla_extension.HloModule is deprecated.", + _xe.HloModule, + ), + "HloPrintOptions": ( + "jax.lib.xla_extension.HloPrintOptions is deprecated.", + _xe.HloPrintOptions, + ), + "OpSharding": ( + "jax.lib.xla_extension.OpSharding is deprecated.", + _xe.OpSharding, + ), + "PjitFunction": ( + "jax.lib.xla_extension.PjitFunction is deprecated.", + _xe.PjitFunction, + ), + "PmapFunction": ( + "jax.lib.xla_extension.PmapFunction is deprecated.", + _xe.PmapFunction, ), } import typing as _typing if _typing.TYPE_CHECKING: - ArrayImpl = _xe.ArrayImpl - XlaRuntimeError = _xe.XlaRuntimeError + Device = _xe.Device + DistributedRuntimeClient = _xe.DistributedRuntimeClient + HloModule = _xe.HloModule + HloPrintOptions = _xe.HloPrintOptions + OpSharding = _xe.OpSharding + PjitFunction = _xe.PjitFunction + PjitFunctionCache = _xe.PjitFunctionCache + PmapFunction = _xe.PmapFunction + + get_distributed_runtime_client = _xe.get_distributed_runtime_client + get_distributed_runtime_service = _xe.get_distributed_runtime_service + hlo_module_cost_analysis = _xe.hlo_module_cost_analysis + hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph + ifrt_proxy = _xe.ifrt_proxy + jax_jit = _xe.jax_jit + mlir = _xe.mlir + pmap_lib = _xe.pmap_lib + profiler = _xe.profiler + pytree = _xe.pytree + else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr diff --git a/jax/monitoring.py b/jax/monitoring.py index 4c9996da582c..f4ab8124f219 
100644 --- a/jax/monitoring.py +++ b/jax/monitoring.py @@ -26,7 +26,9 @@ record_event_duration_secs as record_event_duration_secs, record_event_time_span as record_event_time_span, record_event as record_event, + record_scalar as record_scalar, register_event_duration_secs_listener as register_event_duration_secs_listener, register_event_listener as register_event_listener, register_event_time_span_listener as register_event_time_span_listener, + register_scalar_listener as register_scalar_listener, ) diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 3f08e1c0fd12..651d9cf4e47f 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -35,8 +35,10 @@ standardize as standardize, one_hot as one_hot, relu as relu, + identity as identity, relu6 as relu6, dot_product_attention as dot_product_attention, + get_scaled_dot_general_config as get_scaled_dot_general_config, scaled_dot_general as scaled_dot_general, scaled_matmul as scaled_matmul, selu as selu, diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index cb291bdca79a..b6cfb1ff06ac 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -211,13 +211,18 @@ double as double, float16 as float16, float32 as float32, + float4_e2m1fn as float4_e2m1fn, float64 as float64, + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, float8_e4m3b11fnuz as float8_e4m3b11fnuz, float8_e4m3fn as float8_e4m3fn, float8_e4m3fnuz as float8_e4m3fnuz, float8_e5m2 as float8_e5m2, float8_e5m2fnuz as float8_e5m2fnuz, + float8_e8m0fnu as float8_e8m0fnu, float_ as float_, + int2 as int2, int4 as int4, int8 as int8, int16 as int16, @@ -226,6 +231,7 @@ int_ as int_, single as single, uint as uint, + uint2 as uint2, uint4 as uint4, uint8 as uint8, uint16 as uint16, @@ -295,26 +301,6 @@ unsignedinteger as unsignedinteger, ) -# TODO(slebedev): Remove the try-except once we upgrade to ml_dtypes 0.4.1. 
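
With the minimum ml_dtypes version now guaranteeing these scalar types, the jax/numpy hunk exports them unconditionally and the try/except import shims (deleted just below) are no longer needed. A small availability check, assuming only that the names exist; op coverage for the narrow types still varies by backend:

import numpy as np
import jax.numpy as jnp

# All of these are now unconditional exports rather than optional ones.
for scalar_type in (jnp.int2, jnp.uint2, jnp.float4_e2m1fn,
                    jnp.float8_e3m4, jnp.float8_e4m3, jnp.float8_e8m0fnu):
  print(np.dtype(scalar_type))  # each is a registered ml_dtypes dtype
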
-try: - from jax._src.numpy.scalar_types import ( - int2 as int2, - uint2 as uint2, - ) -except ImportError: - pass - -# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 -try: - from jax._src.numpy.scalar_types import ( - float8_e3m4 as float8_e3m4, - float8_e4m3 as float8_e4m3, - float8_e8m0fnu as float8_e8m0fnu, - float4_e2m1fn as float4_e2m1fn, - ) -except ImportError: - pass - from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, @@ -506,19 +492,3 @@ from jax._src.numpy.array_methods import register_jax_array_methods register_jax_array_methods() del register_jax_array_methods - - -_deprecations = { - # Finalized 2024-12-13; remove after 2024-3-13 - "round_": ( - "jnp.round_ was deprecated in JAX 0.4.38; use jnp.round instead.", - None - ), -} - -import typing -if not typing.TYPE_CHECKING: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index b73a3b95b9a5..df6454c9a1f1 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -15,7 +15,7 @@ from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClas from jax._src.numpy.array_api_metadata import ArrayNamespaceInfo from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, - DimSize, DuckTypedArray, Shape, StaticScalar, + DimSize, DuckTypedArray, Shape, StaticScalar, SupportsNdim, SupportsShape, SupportsSize, ) from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax.numpy import fft as fft, linalg as linalg @@ -583,7 +583,7 @@ def iscomplexobj(x: Any) -> builtins.bool: ... def isdtype(dtype: DTypeLike, kind: DType | str | tuple[DType | str, ...]) -> builtins.bool: ... def isfinite(x: ArrayLike, /) -> Array: ... def isin(element: ArrayLike, test_elements: ArrayLike, - assume_unique: builtins.bool = ..., invert: builtins.bool = ...) -> Array: ... + assume_unique: builtins.bool = ..., invert: builtins.bool = ..., method: str = ...) -> Array: ... def isinf(x: ArrayLike, /) -> Array: ... def isnan(x: ArrayLike, /) -> Array: ... def isneginf(x: ArrayLike, /) -> Array: ... @@ -728,7 +728,7 @@ def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., ddof: int = 0, keepdims: builtins.bool = False, where: ArrayLike | None = ...) -> Array: ... ndarray = Array -def ndim(a: ArrayLike) -> int: ... +def ndim(a: ArrayLike | SupportsNdim) -> int: ... def negative(x: ArrayLike, /) -> Array: ... newaxis = None def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -808,8 +808,7 @@ def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, total_repeat_length: int | None = ...) -> Array: ... def reshape( - a: ArrayLike, shape: DimSize | Shape = ..., - newshape: DimSize | Shape | None = ..., order: str = ... + a: ArrayLike, shape: DimSize | Shape, order: str = ..., *, copy: bool | None = ..., ) -> Array: ... def resize(a: ArrayLike, new_shape: Shape) -> Array: ... @@ -842,7 +841,7 @@ def setdiff1d( fill_value: ArrayLike | None = ..., ) -> Array: ... def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... -def shape(a: ArrayLike) -> tuple[int, ...]: ... +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: ... def sign(x: ArrayLike, /) -> Array: ... 
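
The __init__.pyi edits in this hunk widen jnp.ndim and jnp.shape (and, just below, jnp.size and tril_indices_from/triu_indices_from) to accept duck-typed inputs via the new SupportsNdim/SupportsShape/SupportsSize protocols. A small sketch of what that is meant to allow, assuming (as with the NumPy equivalents) that these functions read the corresponding attribute when it is present:

import jax.numpy as jnp

class DuckArray:
  # Hypothetical object exposing only the protocol attributes.
  ndim = 2
  shape = (3, 4)
  size = 12

duck = DuckArray()
print(jnp.ndim(duck))   # 2
print(jnp.shape(duck))  # (3, 4)
print(jnp.size(duck))   # 12
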
def signbit(x: ArrayLike, /) -> Array: ... signedinteger = _np.signedinteger @@ -850,7 +849,7 @@ def sin(x: ArrayLike, /) -> Array: ... def sinc(x: ArrayLike, /) -> Array: ... single: Any def sinh(x: ArrayLike, /) -> Array: ... -def size(a: ArrayLike, axis: int | None = None) -> int: ... +def size(a: ArrayLike | SupportsSize, axis: int | None = None) -> int: ... def sort( a: ArrayLike, axis: int | None = ..., @@ -930,14 +929,14 @@ def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) -> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def triu_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def trunc(x: ArrayLike, /) -> Array: ... uint: Any diff --git a/jax/random.py b/jax/random.py index 9db584895cf1..89d68a24ccaf 100644 --- a/jax/random.py +++ b/jax/random.py @@ -92,7 +92,7 @@ To learn more about this upgrade, and the design of key types, see `JEP 9263 - `_. + `_. Advanced -------- @@ -178,7 +178,7 @@ ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``. For more about ``jax_threefry_partitionable``, see -https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers +https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers **Summary:** diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 64bc0544000b..c8a2d5f81957 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -31,6 +31,7 @@ lu as lu, lu_factor as lu_factor, lu_solve as lu_solve, + pascal as pascal, polar as polar, qr as qr, rsf2csf as rsf2csf, diff --git a/jax/sharding.py b/jax/sharding.py index 55ff0f6aea0b..66692069d19b 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -20,8 +20,8 @@ NamedSharding as NamedSharding, SingleDeviceSharding as SingleDeviceSharding, PmapSharding as PmapSharding, - GSPMDSharding as GSPMDSharding, - PositionalSharding as PositionalSharding, + GSPMDSharding as _deprecated_GSPMDSharding, + PositionalSharding as _deprecated_PositionalSharding, use_mesh as use_mesh, set_mesh as set_mesh, ) @@ -36,16 +36,29 @@ ) _deprecations = { - # Finalized 2024-10-01; remove after 2025-01-01. - "XLACompatibleSharding": ( + # Added April 11, 2025. + "PositionalSharding": ( ( - "jax.sharding.XLACompatibleSharding was removed in JAX v0.4.34. " - "Use jax.sharding.Sharding instead." + "jax.sharding.PositionalSharding is deprecated. Use" + " jax.NamedSharding instead." ), - None, - ) + _deprecated_PositionalSharding, + ), + "GSPMDSharding": ( + ( + "jax.sharding.GSPMDSharding is deprecated. Use" + " jax.NamedSharding instead." 
+ ), + _deprecated_GSPMDSharding, + ), } -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr +import typing +if typing.TYPE_CHECKING: + PositionalSharding = _deprecated_PositionalSharding + GSPMDSharding = _deprecated_GSPMDSharding +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/jax/stages.py b/jax/stages.py index 3e7e461c385b..aa4c96168b3f 100644 --- a/jax/stages.py +++ b/jax/stages.py @@ -18,7 +18,7 @@ lowering and compilation *ahead of time*. This module defines types that represent the stages of this process. -For more, see the `AOT walkthrough `_. +For more, see the `AOT walkthrough `_. """ # Note: import as is required for names to be exported. diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py index 904ce509a87e..47b85382f8bf 100644 --- a/jax/tools/jax_to_ir.py +++ b/jax/tools/jax_to_ir.py @@ -240,16 +240,12 @@ def parse_shape_str(s): _DT = { 'pred': jnp.bool_, - 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, - 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, + 'u2': jnp.uint2, 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, + 's2': jnp.int2, 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, 'bf16': jnp.bfloat16, 'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64, 'c64': jnp.complex64, 'c128': jnp.complex128 } -if hasattr(jnp, 'int2'): - _DT['s2'] = jnp.int2 -if hasattr(jnp, 'uint2'): - _DT['u2'] = jnp.uint2 _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$") diff --git a/jax/tree_util.py b/jax/tree_util.py index 956d79b9b4ef..3d24c457b3f8 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -48,13 +48,13 @@ PyTreeDef as PyTreeDef, SequenceKey as SequenceKey, all_leaves as all_leaves, - build_tree as build_tree, + build_tree as _deprecated_build_tree, default_registry as default_registry, keystr as keystr, + register_dataclass as register_dataclass, register_pytree_node_class as register_pytree_node_class, register_pytree_node as register_pytree_node, register_pytree_with_keys_class as register_pytree_with_keys_class, - register_dataclass as register_dataclass, register_pytree_with_keys as register_pytree_with_keys, register_static as register_static, tree_all as tree_all, @@ -72,3 +72,23 @@ treedef_is_leaf as treedef_is_leaf, treedef_tuple as treedef_tuple, ) + +_deprecations = { + # Added March 21, 2025: + "build_tree": ( + ( + "jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten" + " instead." + ), + _deprecated_build_tree, + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + from jax._src.tree_util import build_tree as build_tree +else: + from jax._src.deprecations import deprecation_getattr + __getattr__ = deprecation_getattr(__name__, _deprecations) + del deprecation_getattr, _deprecations +del _typing diff --git a/jax/typing.py b/jax/typing.py index 89efa1f2ca66..0530c69e60ca 100644 --- a/jax/typing.py +++ b/jax/typing.py @@ -15,7 +15,7 @@ """ The JAX typing module is where JAX-specific static type annotations live. This submodule is a work in progress; to see the proposal behind the types exported -here, see https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html. 
+here, see https://docs.jax.dev/en/latest/jep/12049-type-annotations.html. The currently-available types are: @@ -67,7 +67,7 @@ def my_function(x: ArrayLike) -> Array: batch-wise transforms like :func:`~jax.vmap` or :func:`jax.pmap`. For more information on this, see `Non-array inputs NumPy vs JAX`_ -.. _Non-array inputs NumPy vs JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax +.. _Non-array inputs NumPy vs JAX: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax """ from jax._src.typing import ( ArrayLike as ArrayLike, diff --git a/jax/util.py b/jax/util.py index 8071f77dffe2..b2c9df205206 100644 --- a/jax/util.py +++ b/jax/util.py @@ -15,19 +15,135 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax._src.util import ( - HashableFunction as HashableFunction, - as_hashable_function as as_hashable_function, - cache as cache, - safe_map as safe_map, - safe_zip as safe_zip, - split_dict as split_dict, - split_list as split_list, - split_list_checked as split_list_checked, - split_merge as split_merge, - subvals as subvals, - toposort as toposort, - unzip2 as unzip2, - wrap_name as wrap_name, - wraps as wraps, -) +import jax._src.deprecations +import jax._src.util + + +_deprecations = { + "to_dlpack": ( + ( + "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0. Please use the newer DLPack API based on" + " __dlpack__ and __dlpack_device__ instead. Typically, you can pass" + " a JAX array directly to the `from_dlpack` function of another" + " framework without using `to_dlpack`." + ), + jax._src.dlpack.to_dlpack, + ), + "HashableFunction": ( + ( + "HashableFunction was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.util.HashableFunction, + ), + "as_hashable_function": ( + ( + "as_hashable_function was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0." + ), + jax._src.util.as_hashable_function, + ), + "cache": ( + "cache was deprecated in JAX v0.6.0 and will be removed in JAX v0.7.0.", + jax._src.util.cache, + ), + "safe_map": ( + ( + "safe_map was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.safe_map, + ), + "safe_zip": ( + ( + "safe_zip was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.safe_zip, + ), + "split_dict": ( + ( + "split_dict was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.split_dict, + ), + "split_list": ( + ( + "split_list was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.split_list, + ), + "split_list_checked": ( + ( + "split_list_checked was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0." + ), + jax._src.util.split_list_checked, + ), + "split_merge": ( + ( + "split_merge was deprecated in JAX v0.6.0 and will be removed in" + " JAX v0.7.0." + ), + jax._src.util.split_merge, + ), + "subvals": ( + ( + "subvals was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.subvals, + ), + "toposort": ( + ( + "toposort was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.toposort, + ), + "unzip2": ( + ( + "unzip2 was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." 
+ ), + jax._src.util.unzip2, + ), + "wrap_name": ( + ( + "wrap_name was deprecated in JAX v0.6.0 and will be removed in JAX" + " v0.7.0." + ), + jax._src.util.wrap_name, + ), + "wraps": ( + "wraps was deprecated in JAX v0.6.0 and will be removed in JAX v0.7.0.", + jax._src.util.wraps, + ), +} + + +import typing as _typing + +if _typing.TYPE_CHECKING: + HashableFunction = jax._src.util.HashableFunction + as_hashable_function = jax._src.util.as_hashable_function + cache = jax._src.util.cache + safe_map = jax._src.util.safe_map + safe_zip = jax._src.util.safe_zip + split_dict = jax._src.util.split_dict + split_list = jax._src.util.split_list + split_list_checked = jax._src.util.split_list_checked + split_merge = jax._src.util.split_merge + subvals = jax._src.util.subvals + toposort = jax._src.util.toposort + unzip2 = jax._src.util.unzip2 + wrap_name = jax._src.util.wrap_name + wraps = jax._src.util.wraps +else: + __getattr__ = jax._src.deprecations.deprecation_getattr( + __name__, _deprecations + ) +del _typing diff --git a/jax/version.py b/jax/version.py index be20aca06358..9301848b0cfb 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.5.3" +_version = "0.6.1" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -93,6 +93,12 @@ def _get_version_for_build() -> str: return _version_from_git_tree(_version) or _version_from_todays_date(_version) +def _is_prerelease() -> bool: + """Determine if this is a pre-release ("rc" wheels) build.""" + rc_version = os.getenv("WHEEL_VERSION_SUFFIX", "") + return True if rc_version.startswith("rc") else False + + def _write_version(fname: str) -> None: """Used by setup.py to write the specified version info into the source tree.""" release_version = _get_version_for_build() @@ -146,7 +152,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.5.1" +_minimum_jaxlib_version = "0.6.0" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index f6540e986024..13293de7181d 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -92,8 +92,10 @@ def initialize(): cuda_plugin_extension.register_custom_call_target, c_api ), ) - for _name, _value in cuda_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") + for _name, _value in cuda_plugin_extension.ffi_registrations().items(): + xla_client.register_custom_call_target( + _name, _value, platform='CUDA', api_version=1 + ) xla_client.register_custom_type_id_handler( "CUDA", functools.partial( diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index ce31684de46f..b9220cd29283 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -52,12 +52,12 @@ def has_ext_modules(self): python_requires=">=3.10", install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], extras_require={ - 'with_cuda': [ + 'with-cuda': [ "nvidia-cublas-cu12>=12.1.3.1", "nvidia-cuda-cupti-cu12>=12.1.105", "nvidia-cuda-nvcc-cu12>=12.6.85", "nvidia-cuda-runtime-cu12>=12.1.105", - "nvidia-cudnn-cu12>=9.1,<10.0", + "nvidia-cudnn-cu12>=9.8,<10.0", "nvidia-cufft-cu12>=11.0.2.54", "nvidia-cusolver-cu12>=11.4.5.107", 
"nvidia-cusparse-cu12>=12.1.0.106", diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index c48a681bf337..0b1b077acfcd 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -92,8 +92,10 @@ def initialize(): rocm_plugin_extension.register_custom_call_target, c_api ), ) - for _name, _value in rocm_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") + for _name, _value in rocm_plugin_extension.ffi_registrations().items(): + xla_client.register_custom_call_target( + _name, _value, platform='ROCM', api_version=1 + ) xla_client.register_custom_type_id_handler( "ROCM", functools.partial( diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a35eabc9a505..373f8cd17674 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -17,9 +17,17 @@ load( "//jaxlib:jax.bzl", "nanobind_extension", + "py_deps", "py_library_providing_imports_info", + "py_strict_test", "pytype_library", ) +load( + "//jaxlib:pywrap.bzl", + "nanobind_pywrap_extension", + "pywrap_binaries", + "pywrap_library", +) load("//jaxlib:symlink_files.bzl", "symlink_files") licenses(["notice"]) @@ -29,13 +37,6 @@ package( default_visibility = ["//jax:internal"], ) -# This makes xla_extension module accessible from jax._src.lib. -genrule( - name = "xla_extension_py", - outs = ["xla_extension.py"], - cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@", -) - py_library_providing_imports_info( name = "jaxlib", srcs = [ @@ -50,15 +51,17 @@ py_library_providing_imports_info( "init.py", "lapack.py", "plugin_support.py", + "xla_client.py", + "xla_extension.py", ":version", - ":xla_client", - ":xla_extension_py", ], data = [":ffi_headers"], lib_rule = pytype_library, deps = [ ":cpu_feature_guard", + ":jax", ":utils", + ":weakref_lru_cache", "//jaxlib/cpu:_lapack", "//jaxlib/mlir", "//jaxlib/mlir:arithmetic_dialect", @@ -81,7 +84,8 @@ py_library_providing_imports_info( "//jaxlib/mlir:vector_dialect", "//jaxlib/mosaic", "//jaxlib/triton", - "@xla//xla/python:xla_extension", + "//jaxlib/xla:xla_client", + "//jaxlib/xla:xla_extension", ], ) @@ -92,13 +96,6 @@ symlink_files( flatten = True, ) -symlink_files( - name = "xla_client", - srcs = ["@xla//xla/python:xla_client"], - dst = ".", - flatten = True, -) - symlink_files( name = "ffi_headers", srcs = ["@xla//xla/ffi/api:all_headers"], @@ -111,6 +108,45 @@ exports_files([ "setup.py", ]) +pywrap_library( + name = "jax", + common_lib_def_files_or_filters = { + "jaxlib/jax_common": "jax_common.json", + }, + common_lib_version_scripts = { + "jaxlib/jax_common": select({ + "@bazel_tools//src/conditions:windows": None, + "@bazel_tools//src/conditions:darwin": "libjax_common_darwin.lds", + "//conditions:default": "libjax_common.lds", + }), + }, + deps = [ + ":utils", + ":weakref_lru_cache", + "//jaxlib/mlir/_mlir_libs:_chlo", + "//jaxlib/mlir/_mlir_libs:_mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor", + "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses", + "//jaxlib/mlir/_mlir_libs:_mlirHlo", + "//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses", + "//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext", + "//jaxlib/mlir/_mlir_libs:_sdy", + "//jaxlib/mlir/_mlir_libs:_stablehlo", + "//jaxlib/mlir/_mlir_libs:_tpu_ext", + "//jaxlib/mlir/_mlir_libs:_triton_ext", + "//jaxlib/mlir/_mlir_libs:register_jax_dialects", + "//jaxlib/xla:xla_extension", + ], +) + +pywrap_binaries( + 
name = "jaxlib_binaries", + dep = ":jax", +) + cc_library( name = "absl_status_casters", hdrs = ["absl_status_casters.h"], @@ -167,51 +203,54 @@ cc_library( features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/base", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) -cc_library( - name = "pass_boilerplate", - hdrs = ["pass_boilerplate.h"], - # compatible with libtpu +# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong +# target architecture. +nanobind_extension( + name = "cpu_feature_guard", + srcs = ["cpu_feature_guard.c"], + module_name = "cpu_feature_guard", deps = [ - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", + "@xla//third_party/python_runtime:headers", ], ) -cc_library( - name = "handle_pool", - hdrs = ["handle_pool.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], +nanobind_pywrap_extension( + name = "weakref_lru_cache", + srcs = ["weakref_lru_cache.cc"], + pytype_srcs = ["weakref_lru_cache.pyi"], deps = [ "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/tsl/platform:logging", ], ) -# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong -# target architecture. -nanobind_extension( - name = "cpu_feature_guard", - srcs = ["cpu_feature_guard.c"], - module_name = "cpu_feature_guard", +py_strict_test( + name = "weakref_lru_cache_test", + srcs = ["weakref_lru_cache_test.py"], deps = [ - "@xla//third_party/python_runtime:headers", - ], + ":weakref_lru_cache", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), ) -nanobind_extension( +nanobind_pywrap_extension( name = "utils", srcs = ["utils.cc"], - module_name = "utils", deps = [ "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index 4275d8e48813..f8b9a023b480 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -17,39 +17,3 @@ from . import eig as eig def initialize() -> None: ... def registrations() -> dict: ... - - -# Old-style LAPACK Workspace Size Queries -def cgesdd_rwork_size(m: int, n: int, compute_uv: int) -> int: ... -def cgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def dgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def gesdd_iwork_size(m: int, n: int) -> int: ... -def heevd_rwork_size(n: int) -> int: ... -def heevd_work_size(n: int) -> int: ... -def lapack_cgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_cgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_chetrd_workspace(lda: int, n: int) -> int: ... -def lapack_cungqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_dgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_dgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_dorgqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_dsytrd_workspace(lda: int, n: int) -> int: ... -def lapack_sgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... 
-def lapack_sgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_sorgqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_ssytrd_workspace(lda: int, n: int) -> int: ... -def lapack_zgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_zgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_zhetrd_workspace(lda: int, n: int) -> int: ... -def lapack_zungqr_workspace(m: int, n: int, k: int) -> int: ... -def sgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def syevd_iwork_size(n: int) -> int: ... -def syevd_work_size(n: int) -> int: ... -def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... - - -# FFI Kernel LAPACK Workspace Size Queries -def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 6ed42496f2f2..a118c20a4490 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -42,70 +42,6 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ctrsm", XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ztrsm", Trsm>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgetrf", Getrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgetrf", Getrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cgetrf", - Getrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zgetrf", - Getrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgeqrf", Geqrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgeqrf", Geqrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cgeqrf", - Geqrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zgeqrf", - Geqrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sorgqr", Orgqr::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dorgqr", Orgqr::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cungqr", - Orgqr>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zungqr", - Orgqr>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_spotrf", Potrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dpotrf", Potrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cpotrf", - Potrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zpotrf", - Potrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgesdd", - RealGesdd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgesdd", - RealGesdd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cgesdd", ComplexGesdd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zgesdd", ComplexGesdd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_ssyevd", - RealSyevd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dsyevd", - RealSyevd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cheevd", ComplexHeevd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zheevd", ComplexHeevd>::Kernel, "Host"); 
-XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgeev", - RealGeev::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgeev", - RealGeev::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cgeev", ComplexGeev>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zgeev", ComplexGeev>::Kernel, "Host"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgees", RealGees::Kernel, "Host"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgees", diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index c104019777e5..1bb3f1f13405 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/kernel_nanobind_helpers.h" @@ -58,19 +58,11 @@ void GetLapackKernelsFromScipy() { auto lapack_ptr = [&](const char* name) { return nb::cast(lapack_capi[name]).data(); }; - AssignKernelFn>(lapack_ptr("sgetrf")); - AssignKernelFn>(lapack_ptr("dgetrf")); - AssignKernelFn>>(lapack_ptr("cgetrf")); - AssignKernelFn>>(lapack_ptr("zgetrf")); AssignKernelFn>(lapack_ptr("sgetrf")); AssignKernelFn>(lapack_ptr("dgetrf")); AssignKernelFn>(lapack_ptr("cgetrf")); AssignKernelFn>(lapack_ptr("zgetrf")); - AssignKernelFn>(lapack_ptr("sgeqrf")); - AssignKernelFn>(lapack_ptr("dgeqrf")); - AssignKernelFn>>(lapack_ptr("cgeqrf")); - AssignKernelFn>>(lapack_ptr("zgeqrf")); AssignKernelFn>(lapack_ptr("sgeqrf")); AssignKernelFn>(lapack_ptr("dgeqrf")); AssignKernelFn>(lapack_ptr("cgeqrf")); @@ -85,28 +77,16 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgeqp3")); - AssignKernelFn>(lapack_ptr("sorgqr")); - AssignKernelFn>(lapack_ptr("dorgqr")); - AssignKernelFn>>(lapack_ptr("cungqr")); - AssignKernelFn>>(lapack_ptr("zungqr")); AssignKernelFn>(lapack_ptr("sorgqr")); AssignKernelFn>(lapack_ptr("dorgqr")); AssignKernelFn>(lapack_ptr("cungqr")); AssignKernelFn>(lapack_ptr("zungqr")); - AssignKernelFn>(lapack_ptr("spotrf")); - AssignKernelFn>(lapack_ptr("dpotrf")); - AssignKernelFn>>(lapack_ptr("cpotrf")); - AssignKernelFn>>(lapack_ptr("zpotrf")); AssignKernelFn>(lapack_ptr("spotrf")); AssignKernelFn>(lapack_ptr("dpotrf")); AssignKernelFn>(lapack_ptr("cpotrf")); AssignKernelFn>(lapack_ptr("zpotrf")); - AssignKernelFn>(lapack_ptr("sgesdd")); - AssignKernelFn>(lapack_ptr("dgesdd")); - AssignKernelFn>>(lapack_ptr("cgesdd")); - AssignKernelFn>>(lapack_ptr("zgesdd")); AssignKernelFn>(lapack_ptr("sgesdd")); AssignKernelFn>(lapack_ptr("dgesdd")); AssignKernelFn>(lapack_ptr("cgesdd")); @@ -116,10 +96,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("cgesvd")); AssignKernelFn>(lapack_ptr("zgesvd")); - AssignKernelFn>(lapack_ptr("ssyevd")); - AssignKernelFn>(lapack_ptr("dsyevd")); - AssignKernelFn>>(lapack_ptr("cheevd")); - AssignKernelFn>>(lapack_ptr("zheevd")); AssignKernelFn>( lapack_ptr("ssyevd")); AssignKernelFn>( @@ -129,10 +105,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zheevd")); - AssignKernelFn>(lapack_ptr("sgeev")); - AssignKernelFn>(lapack_ptr("dgeev")); - AssignKernelFn>>(lapack_ptr("cgeev")); - AssignKernelFn>>(lapack_ptr("zgeev")); AssignKernelFn>(lapack_ptr("sgeev")); AssignKernelFn>(lapack_ptr("dgeev")); AssignKernelFn>( @@ -151,10 +123,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgees")); - AssignKernelFn>(lapack_ptr("sgehrd")); - AssignKernelFn>(lapack_ptr("dgehrd")); - 
AssignKernelFn>>(lapack_ptr("cgehrd")); - AssignKernelFn>>(lapack_ptr("zgehrd")); AssignKernelFn>( lapack_ptr("sgehrd")); AssignKernelFn>( @@ -186,63 +154,12 @@ nb::dict Registrations() { dict["blas_dtrsm"] = EncapsulateFunction(Trsm::Kernel); dict["blas_ctrsm"] = EncapsulateFunction(Trsm>::Kernel); dict["blas_ztrsm"] = EncapsulateFunction(Trsm>::Kernel); - dict["lapack_sgetrf"] = EncapsulateFunction(Getrf::Kernel); - dict["lapack_dgetrf"] = EncapsulateFunction(Getrf::Kernel); - dict["lapack_cgetrf"] = - EncapsulateFunction(Getrf>::Kernel); - dict["lapack_zgetrf"] = - EncapsulateFunction(Getrf>::Kernel); - dict["lapack_sgeqrf"] = EncapsulateFunction(Geqrf::Kernel); - dict["lapack_dgeqrf"] = EncapsulateFunction(Geqrf::Kernel); - dict["lapack_cgeqrf"] = - EncapsulateFunction(Geqrf>::Kernel); - dict["lapack_zgeqrf"] = - EncapsulateFunction(Geqrf>::Kernel); - dict["lapack_sorgqr"] = EncapsulateFunction(Orgqr::Kernel); - dict["lapack_dorgqr"] = EncapsulateFunction(Orgqr::Kernel); - dict["lapack_cungqr"] = - EncapsulateFunction(Orgqr>::Kernel); - dict["lapack_zungqr"] = - EncapsulateFunction(Orgqr>::Kernel); - dict["lapack_spotrf"] = EncapsulateFunction(Potrf::Kernel); - dict["lapack_dpotrf"] = EncapsulateFunction(Potrf::Kernel); - dict["lapack_cpotrf"] = - EncapsulateFunction(Potrf>::Kernel); - dict["lapack_zpotrf"] = - EncapsulateFunction(Potrf>::Kernel); - dict["lapack_sgesdd"] = EncapsulateFunction(RealGesdd::Kernel); - dict["lapack_dgesdd"] = EncapsulateFunction(RealGesdd::Kernel); - dict["lapack_cgesdd"] = - EncapsulateFunction(ComplexGesdd>::Kernel); - dict["lapack_zgesdd"] = - EncapsulateFunction(ComplexGesdd>::Kernel); - dict["lapack_ssyevd"] = EncapsulateFunction(RealSyevd::Kernel); - dict["lapack_dsyevd"] = EncapsulateFunction(RealSyevd::Kernel); - dict["lapack_cheevd"] = - EncapsulateFunction(ComplexHeevd>::Kernel); - dict["lapack_zheevd"] = - EncapsulateFunction(ComplexHeevd>::Kernel); - dict["lapack_sgeev"] = EncapsulateFunction(RealGeev::Kernel); - dict["lapack_dgeev"] = EncapsulateFunction(RealGeev::Kernel); - dict["lapack_cgeev"] = - EncapsulateFunction(ComplexGeev>::Kernel); - dict["lapack_zgeev"] = - EncapsulateFunction(ComplexGeev>::Kernel); - dict["lapack_sgees"] = EncapsulateFunction(RealGees::Kernel); dict["lapack_dgees"] = EncapsulateFunction(RealGees::Kernel); dict["lapack_cgees"] = EncapsulateFunction(ComplexGees>::Kernel); dict["lapack_zgees"] = EncapsulateFunction(ComplexGees>::Kernel); - - dict["lapack_sgehrd"] = EncapsulateFunction(Gehrd::Kernel); - dict["lapack_dgehrd"] = EncapsulateFunction(Gehrd::Kernel); - dict["lapack_cgehrd"] = - EncapsulateFunction(Gehrd>::Kernel); - dict["lapack_zgehrd"] = - EncapsulateFunction(Gehrd>::Kernel); - dict["lapack_ssytrd"] = EncapsulateFunction(Sytrd::Kernel); dict["lapack_dsytrd"] = EncapsulateFunction(Sytrd::Kernel); dict["lapack_chetrd"] = @@ -335,73 +252,6 @@ NB_MODULE(_lapack, m) { nb::enum_(schur, "Sort") .value("kNoSortEigenvalues", schur::Sort::kNoSortEigenvalues) .value("kSortEigenvalues", schur::Sort::kSortEigenvalues); - - // Old-style LAPACK Workspace Size Queries - m.def("lapack_sgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), - nb::arg("n")); - m.def("lapack_dgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), - nb::arg("n")); - m.def("lapack_cgeqrf_workspace", &Geqrf>::Workspace, - nb::arg("m"), nb::arg("n")); - m.def("lapack_zgeqrf_workspace", &Geqrf>::Workspace, - nb::arg("m"), nb::arg("n")); - m.def("lapack_sorgqr_workspace", &Orgqr::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - 
m.def("lapack_dorgqr_workspace", &Orgqr::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_cungqr_workspace", &Orgqr>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("lapack_zungqr_workspace", &Orgqr>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("gesdd_iwork_size", &GesddIworkSize, nb::arg("m"), nb::arg("n")); - m.def("sgesdd_work_size", &RealGesdd::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("dgesdd_work_size", &RealGesdd::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("cgesdd_rwork_size", &ComplexGesddRworkSize, nb::arg("m"), nb::arg("n"), - nb::arg("compute_uv")); - m.def("cgesdd_work_size", &ComplexGesdd>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("zgesdd_work_size", &ComplexGesdd>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("syevd_work_size", &SyevdWorkSize, nb::arg("n")); - m.def("syevd_iwork_size", &SyevdIworkSize, nb::arg("n")); - m.def("heevd_work_size", &HeevdWorkSize, nb::arg("n")); - m.def("heevd_rwork_size", &HeevdRworkSize, nb::arg("n")); - - m.def("lapack_sgehrd_workspace", &Gehrd::Workspace, nb::arg("lda"), - nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_dgehrd_workspace", &Gehrd::Workspace, nb::arg("lda"), - nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_cgehrd_workspace", &Gehrd>::Workspace, - nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_zgehrd_workspace", &Gehrd>::Workspace, - nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_ssytrd_workspace", &Sytrd::Workspace, nb::arg("lda"), - nb::arg("n")); - m.def("lapack_dsytrd_workspace", &Sytrd::Workspace, nb::arg("lda"), - nb::arg("n")); - m.def("lapack_chetrd_workspace", &Sytrd>::Workspace, - nb::arg("lda"), nb::arg("n")); - m.def("lapack_zhetrd_workspace", &Sytrd>::Workspace, - nb::arg("lda"), nb::arg("n")); - // FFI Kernel LAPACK Workspace Size Queries - m.def("lapack_sorgqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_dorgqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_cungqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_zungqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); } } // namespace diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index ddc93261eeb5..3b510708a8bb 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -149,8 +149,7 @@ ffi::Error TriMatrixEquationSolver::Kernel( ffi::Buffer x, ffi::Buffer y, // TODO(b/397715595): Remove RemainingArgs no earlier than 180 days after // the release of JAX 0.5.2. 
- ffi::RemainingArgs, - ffi::ResultBuffer y_out, MatrixParams::Side side, + ffi::RemainingArgs, ffi::ResultBuffer y_out, MatrixParams::Side side, MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, MatrixParams::Diag diag) { CopyIfDiffBuffer(y, y_out); @@ -189,42 +188,6 @@ template struct TriMatrixEquationSolver; //== LU Decomposition ==// -// lapack getrf - -template -typename Getrf::FnType* Getrf::fn = nullptr; - -template -void Getrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* ipiv = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - for (int i = 0; i < b; ++i) { - fn(&m, &n, a_out, &m, ipiv, info); - a_out += static_cast(m) * static_cast(n); - ipiv += std::min(m, n); - ++info; - } -} - -template struct Getrf; -template struct Getrf; -template struct Getrf>; -template struct Getrf>; - -// FFI Kernel - template ffi::Error LuDecomposition::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, @@ -261,55 +224,6 @@ template struct LuDecomposition; //== QR Factorization ==// -// lapack geqrf - -template -typename Geqrf::FnType* Geqrf::fn = nullptr; - -template -void Geqrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - int lwork = *(reinterpret_cast(data[3])); - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* tau = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&m, &n, a_out, &m, tau, work, &lwork, info); - a_out += static_cast(m) * static_cast(n); - tau += std::min(m, n); - ++info; - } -} - -template -int64_t Geqrf::Workspace(lapack_int m, lapack_int n) { - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&m, &n, nullptr, &m, nullptr, &work, &lwork, &info); - return info == 0 ? 
static_cast(std::real(work)) : -1; -} - -template struct Geqrf; -template struct Geqrf; -template struct Geqrf>; -template struct Geqrf>; - -// FFI Kernel - template ffi::Error QrFactorization::Kernel(ffi::Buffer x, ffi::ResultBuffer x_out, @@ -430,56 +344,6 @@ template struct PivotingQrFactorization; //== Orthogonal QR ==// //== Computes orthogonal matrix Q from QR Decomposition ==// -// lapack orgqr - -template -typename Orgqr::FnType* Orgqr::fn = nullptr; - -template -void Orgqr::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - int k = *(reinterpret_cast(data[3])); - int lwork = *(reinterpret_cast(data[4])); - const T* a_in = reinterpret_cast(data[5]); - T* tau = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* info = reinterpret_cast(out[1]); - T* work = reinterpret_cast(out[2]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&m, &n, &k, a_out, &m, tau, work, &lwork, info); - a_out += static_cast(m) * static_cast(n); - tau += k; - ++info; - } -} - -template -int64_t Orgqr::Workspace(int m, int n, int k) { - T work = 0; - int lwork = -1; - int info = 0; - fn(&m, &n, &k, nullptr, &m, nullptr, &work, &lwork, &info); - return info ? -1 : static_cast(std::real(work)); -} - -template struct Orgqr; -template struct Orgqr; -template struct Orgqr>; -template struct Orgqr>; - -// FFI Kernel - template ffi::Error OrthogonalQr::Kernel(ffi::Buffer x, ffi::Buffer tau, @@ -535,42 +399,6 @@ template struct OrthogonalQr; //== Cholesky Factorization ==// -// lapack potrf - -template -typename Potrf::FnType* Potrf::fn = nullptr; - -template -void Potrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - char uplo = lower ? 
'L' : 'U'; - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* info = reinterpret_cast(out[1]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&uplo, &n, a_out, &n, info); - a_out += static_cast(n) * static_cast(n); - ++info; - } -} - -template struct Potrf; -template struct Potrf; -template struct Potrf>; -template struct Potrf>; - -// FFI Kernel - template ffi::Error CholeskyFactorization::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, @@ -604,162 +432,6 @@ template struct CholeskyFactorization; //== Singular Value Decomposition (SVD) ==// //== using a divide and conquer method ==// -// lapack gesdd - -static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) { - if (!job_opt_compute_uv) { - return 'N'; - } else if (!job_opt_full_matrices) { - return 'S'; - } - return 'A'; -} - -lapack_int GesddIworkSize(int64_t m, int64_t n) { - return CastNoOverflow(8 * std::min(m, n), "gesdd iwork"); -} - -template -typename RealGesdd::FnType* RealGesdd::fn = nullptr; - -template -void RealGesdd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); - int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); - int b = *(reinterpret_cast(data[2])); - int m = *(reinterpret_cast(data[3])); - int n = *(reinterpret_cast(data[4])); - int lwork = *(reinterpret_cast(data[5])); - T* a_in = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* s = reinterpret_cast(out[1]); - T* u = reinterpret_cast(out[2]); - T* vt = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - T* work = reinterpret_cast(out[6]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - - int lda = m; - int ldu = m; - int tdu = job_opt_full_matrices ? m : std::min(m, n); - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - - for (int i = 0; i < b; ++i) { - fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, - info); - a_out += static_cast(m) * n; - s += std::min(m, n); - u += static_cast(m) * tdu; - vt += static_cast(ldvt) * n; - ++info; - } -} - -template -int64_t RealGesdd::Workspace(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - T work = 0; - int lwork = -1; - int info = 0; - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work, - &lwork, nullptr, &info); - return info ? 
-1 : static_cast(work); -} - -lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) { - int64_t mn = std::min(m, n); - if (compute_uv == 0) { - return CastNoOverflow(7 * mn, "complex gesdd rwork"); - } - int64_t mx = std::max(m, n); - return CastNoOverflow( - std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn), - "complex gesdd rwork"); -} - -template -typename ComplexGesdd::FnType* ComplexGesdd::fn = nullptr; - -template -void ComplexGesdd::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); - int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); - int b = *(reinterpret_cast(data[2])); - int m = *(reinterpret_cast(data[3])); - int n = *(reinterpret_cast(data[4])); - int lwork = *(reinterpret_cast(data[5])); - T* a_in = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* s = reinterpret_cast(out[1]); - T* u = reinterpret_cast(out[2]); - T* vt = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - typename T::value_type* rwork = - reinterpret_cast(out[6]); - T* work = reinterpret_cast(out[7]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - - int lda = m; - int ldu = m; - int tdu = job_opt_full_matrices ? m : std::min(m, n); - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - - for (int i = 0; i < b; ++i) { - fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, - iwork, info); - a_out += static_cast(m) * n; - s += std::min(m, n); - u += static_cast(m) * tdu; - vt += static_cast(ldvt) * n; - ++info; - } -} - -template -int64_t ComplexGesdd::Workspace(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - T work = 0; - int lwork = -1; - int info = 0; - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work, - &lwork, nullptr, nullptr, &info); - return info ? 
-1 : static_cast(work.real()); -} - -template struct RealGesdd; -template struct RealGesdd; -template struct ComplexGesdd>; -template struct ComplexGesdd>; - -// FFI Kernel - namespace internal { template @@ -949,16 +621,16 @@ static ffi::Error SvdQRKernel( for (int64_t i = 0; i < batch_count; ++i) { if constexpr (ffi::IsComplexType()) { - svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, - &x_leading_dim_v, singular_values_data, u_data, - &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data.get(), &workspace_dim_v, rwork.get(), - info_data); + svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, + x_out_data, &x_leading_dim_v, + singular_values_data, u_data, &u_leading_dim_v, + vt_data, &vt_leading_dim_v, work_data.get(), + &workspace_dim_v, rwork.get(), info_data); } else { - svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, - &x_leading_dim_v, singular_values_data, u_data, - &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data.get(), &workspace_dim_v, info_data); + svd::SVDQRType::fn( + &mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, + singular_values_data, u_data, &u_leading_dim_v, vt_data, + &vt_leading_dim_v, work_data.get(), &workspace_dim_v, info_data); } x_out_data += x_out_step; singular_values_data += singular_values_step; @@ -970,9 +642,8 @@ static ffi::Error SvdQRKernel( } template -static absl::StatusOr SvdQRGetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode) { +static absl::StatusOr SvdQRGetWorkspaceSize( + lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { ffi::NativeType optimal_size = {}; lapack_int info = 0; lapack_int workspace_query = -1; @@ -994,7 +665,8 @@ static absl::StatusOr SvdQRGetWorkspaceSize(lapack_int x_rows, &u_leading_dim_v, nullptr, &vt_leading_dim_v, &optimal_size, &workspace_query, &info); } - return info == 0 ? MaybeCastNoOverflow(std::real(optimal_size)) : -1; + return info == 0 ? MaybeCastNoOverflow(std::real(optimal_size)) + : -1; } } // namespace internal @@ -1053,7 +725,8 @@ ffi::Error SingularValueDecompositionQRComplex::Kernel( } template -absl::StatusOr SingularValueDecompositionQR::GetWorkspaceSize( +absl::StatusOr +SingularValueDecompositionQR::GetWorkspaceSize( lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { return internal::SvdQRGetWorkspaceSize(x_rows, x_cols, mode); } @@ -1077,7 +750,8 @@ absl::StatusOr svd::GetRealWorkspaceSize( 2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim)); } -absl::StatusOr svd::GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols) { +absl::StatusOr svd::GetRealWorkspaceSizeQR(int64_t x_rows, + int64_t x_cols) { return CastNoOverflow(5 * std::min(x_rows, x_cols)); } @@ -1098,109 +772,6 @@ template struct SingularValueDecompositionQRComplex; //== Eigenvalues and eigenvectors ==// -// lapack syevd/heevd - -// # Workspace sizes, taken from the LAPACK documentation. 
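The ?gesdd helpers removed above encode two LAPACK conventions: the jobz character ('N', 'S' or 'A') selects how much of U and VT to compute, and the integer/real scratch sizes come straight from the LAPACK documentation. A standalone restatement of those rules, reusing the formulas visible in the deleted code but with local names rather than jaxlib symbols:

#include <algorithm>
#include <cstdint>

// 'N': singular values only; 'S': thin U/VT; 'A': full U/VT.
char GesddJobzSketch(bool compute_uv, bool full_matrices) {
  if (!compute_uv) return 'N';
  return full_matrices ? 'A' : 'S';
}

// Integer workspace: LAPACK documents iwork of size 8*min(m, n).
int64_t GesddIworkSizeSketch(int64_t m, int64_t n) {
  return 8 * std::min(m, n);
}

// Real scratch (rwork) for the complex ?gesdd variants: 7*mn when only the
// singular values are requested, otherwise
// max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn).
int64_t ComplexGesddRworkSizeSketch(int64_t m, int64_t n, bool compute_uv) {
  int64_t mn = std::min(m, n);
  if (!compute_uv) return 7 * mn;
  int64_t mx = std::max(m, n);
  return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
}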
-lapack_int SyevdWorkSize(int64_t n) { - return CastNoOverflow(1 + 6 * n + 2 * n * n, "syevd lwork"); -} - -lapack_int SyevdIworkSize(int64_t n) { - return CastNoOverflow(3 + 5 * n, "syevd iwork"); -} - -template -typename RealSyevd::FnType* RealSyevd::fn = nullptr; - -template -void RealSyevd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* w_out = reinterpret_cast(out[1]); - int* info_out = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - int* iwork = reinterpret_cast(out[4]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char jobz = 'V'; - char uplo = lower ? 'L' : 'U'; - - lapack_int lwork = SyevdWorkSize(n); - lapack_int liwork = SyevdIworkSize(n); - for (int i = 0; i < b; ++i) { - fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork, - info_out); - a_out += static_cast(n) * n; - w_out += n; - ++info_out; - } -} - -// Workspace sizes, taken from the LAPACK documentation. -lapack_int HeevdWorkSize(int64_t n) { - return CastNoOverflow(1 + 2 * n + n * n, "heevd work"); -} - -lapack_int HeevdRworkSize(int64_t n) { - return CastNoOverflow(1 + 5 * n + 2 * n * n, "heevd rwork"); -} - -template -typename ComplexHeevd::FnType* ComplexHeevd::fn = nullptr; - -template -void ComplexHeevd::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* w_out = - reinterpret_cast(out[1]); - int* info_out = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - typename T::value_type* rwork = - reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char jobz = 'V'; - char uplo = lower ? 
'L' : 'U'; - - lapack_int lwork = HeevdWorkSize(n); - lapack_int lrwork = HeevdRworkSize(n); - lapack_int liwork = SyevdIworkSize(n); - for (int i = 0; i < b; ++i) { - fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, rwork, &lrwork, iwork, - &liwork, info_out); - a_out += static_cast(n) * n; - w_out += n; - ++info_out; - } -} - -template struct RealSyevd; -template struct RealSyevd; -template struct ComplexHeevd>; -template struct ComplexHeevd>; - -// FFI Kernel - absl::StatusOr eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) { switch (mode) { @@ -1339,155 +910,6 @@ template struct EigenvalueDecompositionSymmetric; template struct EigenvalueDecompositionHermitian; template struct EigenvalueDecompositionHermitian; -// lapack geev - -template -typename RealGeev::FnType* RealGeev::fn = nullptr; - -template -void RealGeev::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvl = *(reinterpret_cast(data[2])); - char jobvr = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_work = reinterpret_cast(out[0]); - T* vl_work = reinterpret_cast(out[1]); - T* vr_work = reinterpret_cast(out[2]); - - T* wr_out = reinterpret_cast(out[3]); - T* wi_out = reinterpret_cast(out[4]); - std::complex* vl_out = reinterpret_cast*>(out[5]); - std::complex* vr_out = reinterpret_cast*>(out[6]); - int* info_out = reinterpret_cast(out[7]); - - // TODO(phawkins): preallocate workspace using XLA. - T work_query; - int lwork = -1; - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int, - vr_work, &n_int, &work_query, &lwork, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query); - T* work = new T[lwork]; - - auto is_finite = [](T* a_work, int64_t n) { - for (int64_t j = 0; j < n; ++j) { - for (int64_t k = 0; k < n; ++k) { - if (!std::isfinite(a_work[j * n + k])) { - return false; - } - } - } - return true; - }; - for (int i = 0; i < b; ++i) { - size_t a_size = n * n * sizeof(T); - std::memcpy(a_work, a_in, a_size); - if (is_finite(a_work, n)) { - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, - &n_int, vr_work, &n_int, work, &lwork, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - if (info_out[0] == 0) { - UnpackEigenvectors(n, wi_out, vl_work, vl_out); - UnpackEigenvectors(n, wi_out, vr_work, vr_out); - } - } else { - *info_out = -4; - } - a_in += n * n; - wr_out += n; - wi_out += n; - vl_out += n * n; - vr_out += n * n; - ++info_out; - } - delete[] work; -} - -template -typename ComplexGeev::FnType* ComplexGeev::fn = nullptr; - -template -void ComplexGeev::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvl = *(reinterpret_cast(data[2])); - char jobvr = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_work = reinterpret_cast(out[0]); - typename T::value_type* 
r_work = - reinterpret_cast(out[1]); - - T* w_out = reinterpret_cast(out[2]); - T* vl_out = reinterpret_cast(out[3]); - T* vr_out = reinterpret_cast(out[4]); - int* info_out = reinterpret_cast(out[5]); - - // TODO(phawkins): preallocate workspace using XLA. - T work_query; - int lwork = -1; - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, - &n_int, &work_query, &lwork, r_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query.real()); - T* work = new T[lwork]; - - auto is_finite = [](T* a_work, int64_t n) { - for (int64_t j = 0; j < n; ++j) { - for (int64_t k = 0; k < n; ++k) { - T v = a_work[j * n + k]; - if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) { - return false; - } - } - } - return true; - }; - - for (int i = 0; i < b; ++i) { - size_t a_size = n * n * sizeof(T); - std::memcpy(a_work, a_in, a_size); - if (is_finite(a_work, n)) { - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, - &n_int, work, &lwork, r_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - } else { - *info_out = -4; - } - a_in += n * n; - w_out += n; - vl_out += n * n; - vr_out += n * n; - info_out += 1; - } - delete[] work; -} - -template struct RealGeev; -template struct RealGeev; -template struct ComplexGeev>; -template struct ComplexGeev>; - -// FFI Kernel - template ffi::Error EigenvalueDecomposition::Kernel( ffi::Buffer x, eig::ComputationMode compute_left, @@ -1968,60 +1390,6 @@ template struct SchurDecompositionComplex; //== Hessenberg Decomposition ==// -// lapack gehrd - -template -typename Gehrd::FnType* Gehrd::fn = nullptr; - -template -void Gehrd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t n = *reinterpret_cast(data[0]); - int32_t ilo = *reinterpret_cast(data[1]); - int32_t ihi = *reinterpret_cast(data[2]); - int32_t lda = *reinterpret_cast(data[3]); - int32_t batch = *reinterpret_cast(data[4]); - int32_t lwork = *reinterpret_cast(data[5]); - T* a = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* tau = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - - if (a_out != a) { - std::memcpy(a_out, a, - static_cast(batch) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - int64_t a_plus = static_cast(lda) * static_cast(n); - - for (int i = 0; i < batch; ++i) { - fn(&n, &ilo, &ihi, a_out, &lda, tau, work, &lwork, info); - a_out += a_plus; - tau += n - 1; - ++info; - } -} - -template -int64_t Gehrd::Workspace(lapack_int lda, lapack_int n, lapack_int ilo, - lapack_int ihi) { - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&n, &ilo, &ihi, nullptr, &lda, nullptr, &work, &lwork, &info); - return info == 0 ? 
static_cast(std::real(work)) : -1; -} - -template struct Gehrd; -template struct Gehrd; -template struct Gehrd>; -template struct Gehrd>; - -// FFI Kernel - template ffi::Error HessenbergDecomposition::Kernel( ffi::Buffer x, lapack_int low, lapack_int high, diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index e075ff29387f..71ba8b8a5e0c 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -154,19 +154,6 @@ struct TriMatrixEquationSolver { //== LU Decomposition ==// -// lapack getrf - -template -struct Getrf { - using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, - lapack_int* ipiv, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct LuDecomposition { using ValueType = ::xla::ffi::NativeType; @@ -182,21 +169,6 @@ struct LuDecomposition { //== QR Factorization ==// -// lapack geqrf - -template -struct Geqrf { - using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, - T* tau, T* work, lapack_int* lwork, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct QrFactorization { using ValueType = ::xla::ffi::NativeType; @@ -240,23 +212,8 @@ struct PivotingQrFactorization { static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); }; - //== Orthogonal QR ==// -// lapack orgqr - -template -struct Orgqr { - using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, T* a, - lapack_int* lda, T* tau, T* work, lapack_int* lwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct OrthogonalQr { using ValueType = ::xla::ffi::NativeType; @@ -276,16 +233,6 @@ struct OrthogonalQr { //== Cholesky Factorization ==// -// lapack potrf - -template -struct Potrf { - using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - template <::xla::ffi::DataType dtype> struct CholeskyFactorization { using ValueType = ::xla::ffi::NativeType; @@ -302,41 +249,6 @@ struct CholeskyFactorization { //== Singular Value Decomposition (SVD) ==// -// lapack gesdd - -lapack_int GesddIworkSize(int64_t m, int64_t n); - -template -struct RealGesdd { - using FnType = void(char* jobz, lapack_int* m, lapack_int* n, T* a, - lapack_int* lda, T* s, T* u, lapack_int* ldu, T* vt, - lapack_int* ldvt, T* work, lapack_int* lwork, - lapack_int* iwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv, - bool job_opt_full_matrices); -}; - -lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv); - -template -struct ComplexGesdd { - using FnType = void(char* jobz, lapack_int* m, lapack_int* n, T* a, - lapack_int* lda, typename T::value_type* s, T* u, - lapack_int* ldu, T* vt, lapack_int* ldvt, T* work, - lapack_int* lwork, typename T::value_type* rwork, - lapack_int* iwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, 
XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv, - bool job_opt_full_matrices); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct SingularValueDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -407,8 +319,8 @@ struct SingularValueDecompositionQR { ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode); + lapack_int x_cols, + svd::ComputationMode mode); }; template <::xla::ffi::DataType dtype> @@ -432,8 +344,8 @@ struct SingularValueDecompositionQRComplex { ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode); + lapack_int x_cols, + svd::ComputationMode mode); }; namespace svd { @@ -451,42 +363,13 @@ using SVDQRType = std::conditional_t<::xla::ffi::IsComplexType(), absl::StatusOr GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols); absl::StatusOr GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, ComputationMode mode); -absl::StatusOr GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols); +absl::StatusOr GetRealWorkspaceSizeQR(int64_t x_rows, + int64_t x_cols); } // namespace svd //== Eigenvalues and eigenvectors ==// -// lapack syevd/heevd - -lapack_int SyevdWorkSize(int64_t n); -lapack_int SyevdIworkSize(int64_t n); - -template -struct RealSyevd { - using FnType = void(char* jobz, char* uplo, lapack_int* n, T* a, - lapack_int* lda, T* w, T* work, lapack_int* lwork, - lapack_int* iwork, lapack_int* liwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -lapack_int HeevdWorkSize(int64_t n); -lapack_int HeevdRworkSize(int64_t n); - -template -struct ComplexHeevd { - using FnType = void(char* jobz, char* uplo, lapack_int* n, T* a, - lapack_int* lda, typename T::value_type* w, T* work, - lapack_int* lwork, typename T::value_type* rwork, - lapack_int* lrwork, lapack_int* iwork, lapack_int* liwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - namespace eig { // Eigenvalue Decomposition @@ -544,8 +427,6 @@ struct EigenvalueDecompositionHermitian { ::xla::ffi::ResultBuffer info, eig::ComputationMode mode); }; -// lapack geev - // LAPACK uses a packed representation to represent a mixture of real // eigenvectors and complex conjugate pairs. This helper unpacks the // representation into regular complex matrices. 
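The UnpackEigenvectors helper referenced above handles the packed output of the real ?geev routines: when an eigenvalue has a nonzero imaginary part, LAPACK stores the corresponding conjugate pair of eigenvectors as a real part in one column and an imaginary part in the next. A simplified, column-major sketch of that unpacking (an illustration of the convention, not the jaxlib implementation):

#include <complex>
#include <cstdint>

template <typename T>
void UnpackEigenvectorsSketch(int64_t n, const T* wi, const T* packed,
                              std::complex<T>* unpacked) {
  for (int64_t j = 0; j < n;) {
    if (wi[j] == T(0)) {
      // Real eigenvalue: column j already holds a real eigenvector.
      for (int64_t i = 0; i < n; ++i) {
        unpacked[j * n + i] = {packed[j * n + i], T(0)};
      }
      ++j;
    } else {
      // Conjugate pair: columns j and j+1 hold the real and imaginary parts.
      for (int64_t i = 0; i < n; ++i) {
        T re = packed[j * n + i];
        T im = packed[(j + 1) * n + i];
        unpacked[j * n + i] = {re, im};
        unpacked[(j + 1) * n + i] = {re, -im};
      }
      j += 2;
    }
  }
}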
@@ -574,28 +455,6 @@ static void UnpackEigenvectors(Int n, const T* eigenvals_imag, const T* packed, } } -template -struct RealGeev { - using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, - lapack_int* lda, T* wr, T* wi, T* vl, lapack_int* ldvl, - T* vr, lapack_int* ldvr, T* work, lapack_int* lwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -template -struct ComplexGeev { - using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, - lapack_int* lda, T* w, T* vl, lapack_int* ldvl, T* vr, - lapack_int* ldvr, T* work, lapack_int* lwork, - typename T::value_type* rwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct EigenvalueDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -737,32 +596,6 @@ struct SchurDecompositionComplex { //== Hessenberg Decomposition ==// //== Reduces a non-symmetric square matrix to upper Hessenberg form ==// -// lapack gehrd - -template -struct Gehrd { - using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a, - lapack_int* lda, T* tau, T* work, lapack_int* lwork, - lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int lda, lapack_int n, lapack_int ilo, - lapack_int ihi); -}; - -template -struct real_type { - typedef T type; -}; -template -struct real_type> { - typedef T type; -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct HessenbergDecomposition { using ValueType = ::xla::ffi::NativeType; @@ -785,6 +618,15 @@ struct HessenbergDecomposition { //== Tridiagonal Reduction ==// //== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==// +template +struct real_type { + typedef T type; +}; +template +struct real_type> { + typedef T type; +}; + // lapack sytrd/hetrd template diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 3c8ddf11cf29..e771aa0e37d1 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -118,114 +118,6 @@ static_assert( std::is_same_v::FnType, jax::Trsm>::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - 
std::is_same_v::FnType, - jax::Potrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGesdd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGesdd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::SingularValueDecompositionComplex::FnType, - jax::ComplexGesdd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::SingularValueDecompositionComplex::FnType, - jax::ComplexGesdd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionSymmetric::FnType, - jax::RealSyevd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionSymmetric::FnType, - jax::RealSyevd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionHermitian::FnType, - jax::ComplexHeevd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionHermitian::FnType, - jax::ComplexHeevd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGeev::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGeev::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionComplex::FnType, - jax::ComplexGeev>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionComplex::FnType, - jax::ComplexGeev>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); static_assert( std::is_same_v::FnType, jax::Sytrd::FnType>, @@ -258,22 +150,6 @@ static_assert( std::is_same_v::FnType, jax::ComplexGees>::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); #undef JAX_KERNEL_FNTYPE_MISMATCH_MSG @@ -283,51 +159,11 @@ static auto init = []() -> int { AssignKernelFn>>(ctrsm_); AssignKernelFn>>(ztrsm_); - AssignKernelFn>(sgetrf_); - AssignKernelFn>(dgetrf_); - AssignKernelFn>>(cgetrf_); - AssignKernelFn>>(zgetrf_); - - AssignKernelFn>(sgeqrf_); - AssignKernelFn>(dgeqrf_); - AssignKernelFn>>(cgeqrf_); - AssignKernelFn>>(zgeqrf_); - - AssignKernelFn>(sorgqr_); - AssignKernelFn>(dorgqr_); - AssignKernelFn>>(cungqr_); - AssignKernelFn>>(zungqr_); - - AssignKernelFn>(spotrf_); - AssignKernelFn>(dpotrf_); - AssignKernelFn>>(cpotrf_); - AssignKernelFn>>(zpotrf_); - - AssignKernelFn>(sgesdd_); - AssignKernelFn>(dgesdd_); - AssignKernelFn>>(cgesdd_); - AssignKernelFn>>(zgesdd_); - - AssignKernelFn>(ssyevd_); - AssignKernelFn>(dsyevd_); - AssignKernelFn>>(cheevd_); - AssignKernelFn>>(zheevd_); - - AssignKernelFn>(sgeev_); - AssignKernelFn>(dgeev_); - AssignKernelFn>>(cgeev_); - AssignKernelFn>>(zgeev_); - AssignKernelFn>(sgees_); 
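The static_asserts deleted above guard a simple invariant: the legacy kernel structs and their FFI replacements declare identical FnType aliases for the raw LAPACK symbol, so the same function pointer can be assigned to either through AssignKernelFn. A self-contained sketch of that pattern, using stand-in types and a hypothetical helper rather than jaxlib declarations:

#include <type_traits>

using GetrfFn = void(int* m, int* n, float* a, int* lda, int* ipiv, int* info);

struct LegacyGetrf {         // stand-in for the removed legacy kernel struct
  using FnType = GetrfFn;
  static FnType* fn;
};
struct FfiLuDecomposition {  // stand-in for its FFI-era replacement
  using FnType = GetrfFn;
  static FnType* fn;
};

GetrfFn* LegacyGetrf::fn = nullptr;
GetrfFn* FfiLuDecomposition::fn = nullptr;

static_assert(std::is_same_v<LegacyGetrf::FnType, FfiLuDecomposition::FnType>,
              "incompatible function types");

// AssignKernelFn-style helper: stash the resolved LAPACK symbol (obtained
// from SciPy's LAPACK capsules or from a linked LAPACK) on the kernel struct.
template <typename KernelType>
void AssignKernelFnSketch(void* func) {
  KernelType::fn = reinterpret_cast<typename KernelType::FnType*>(func);
}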
AssignKernelFn>(dgees_); AssignKernelFn>>(cgees_); AssignKernelFn>>(zgees_); - AssignKernelFn>(sgehrd_); - AssignKernelFn>(dgehrd_); - AssignKernelFn>>(cgehrd_); - AssignKernelFn>>(zgehrd_); - AssignKernelFn>(ssytrd_); AssignKernelFn>(dsytrd_); AssignKernelFn>>(chetrd_); diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index a9bd35b7768d..be7ac6116d2f 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -64,7 +64,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", @@ -89,7 +88,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -98,55 +97,6 @@ cc_library( ], ) -cc_library( - name = "cublas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":cuda_blas_handle_pool", - ":cuda_gpu_kernel_helpers", - ":cuda_make_batch_pointers", - ":cuda_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", - "@local_config_cuda//cuda:cuda_headers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_blas", - deps = [ - ":cublas_kernels", - ":cuda_vendor", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@nanobind", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "cudnn_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -155,11 +105,12 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", @@ -195,7 +146,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -308,8 +259,8 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -336,6 +287,7 @@ nanobind_extension( ":cuda_vendor", ":cusparse_kernels", "//jaxlib:absl_status_casters", + 
"//jaxlib:kernel_helpers", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -343,11 +295,13 @@ nanobind_extension( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusparse", "@xla//xla/tsl/python/lib/core:numpy", @@ -455,6 +409,7 @@ nanobind_extension( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", + ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", "@local_config_cuda//cuda:cuda_headers", "@nanobind", @@ -511,7 +466,6 @@ cc_library( srcs = ["//jaxlib/gpu:gpu_kernels.cc"], visibility = ["//visibility:public"], deps = [ - ":cublas_kernels", ":cuda_linalg_kernels", ":cuda_prng_kernels", ":cuda_vendor", @@ -545,8 +499,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", "@xla//xla/stream_executor/cuda:cuda_asm_compiler", "@xla//xla/tsl/cuda:cudart", @@ -586,7 +540,9 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@nanobind", ], ) @@ -644,7 +600,6 @@ nanobind_extension( py_library( name = "cuda_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", @@ -657,11 +612,49 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":cuda_vendor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:platform_util", + ], +) + nanobind_extension( name = "cuda_plugin_extension", srcs = ["cuda_plugin_extension.cc"], module_name = "cuda_plugin_extension", deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 8d8514bd2740..e753025ac714 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -16,17 +16,20 @@ limitations under the License. 
#include #include -#include "nanobind/nanobind.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" #include "xla/pjrt/status_casters.h" namespace nb = nanobind; namespace xla { namespace { + static std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { @@ -38,10 +41,25 @@ static std::string ToString(CUresult result) { } return absl::StrCat(error_name, ": ", error_string); } + +nb::dict FfiRegistrations() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + jax::EncapsulateFfiHandler(jax::cuda::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + jax::EncapsulateFfiHandler(jax::cuda::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; + return dict; +} + } // namespace NB_MODULE(cuda_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("ffi_registrations", &FfiRegistrations); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/cuda/versions.cc b/jaxlib/cuda/versions.cc index 8d6577f46709..d9f9f4c86865 100644 --- a/jaxlib/cuda/versions.cc +++ b/jaxlib/cuda/versions.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/cuda/versions_helpers.h" - #include "nanobind/nanobind.h" +#include "jaxlib/cuda/versions_helpers.h" #include "jaxlib/gpu/vendor.h" namespace jax::cuda { diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc index d42199d37467..508a92c326cb 100644 --- a/jaxlib/cuda/versions_helpers.cc +++ b/jaxlib/cuda/versions_helpers.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/cuda/versions_helpers.h" #include +#include #include #include "absl/base/dynamic_annotations.h" diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 5c6d80093df5..634a48fcffc7 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #ifndef JAXLIB_FFI_HELPERS_H_ #define JAXLIB_FFI_HELPERS_H_ diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index b5292746dd10..e153e0588cf6 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -30,11 +30,8 @@ package( ) exports_files(srcs = [ - "blas.cc", "blas_handle_pool.cc", "blas_handle_pool.h", - "blas_kernels.cc", - "blas_kernels.h", "ffi_wrapper.h", "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", @@ -52,6 +49,8 @@ exports_files(srcs = [ "prng_kernels.cc", "prng_kernels.cu.cc", "prng_kernels.h", + "py_client_gpu.cc", + "py_client_gpu.h", "rnn.cc", "rnn_kernels.cc", "rnn_kernels.h", @@ -82,6 +81,7 @@ proto_library( cc_proto_library( name = "triton_cc_proto", + compatible_with = None, deps = [":triton_proto"], ) @@ -91,6 +91,21 @@ xla_py_proto_library( deps = [":triton_proto"], ) +cc_library( + name = "handle_pool", + hdrs = ["handle_pool.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_plugin_extension", srcs = ["gpu_plugin_extension.cc"], @@ -115,7 +130,7 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", - "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/platform:statusor", "@xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc deleted file mode 100644 index e8761bd32ac9..000000000000 --- a/jaxlib/gpu/blas.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/tsl/python/lib/core/numpy.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { -namespace { - -namespace nb = nanobind; - -// Converts a NumPy dtype to a Type. -BlasType DtypeToBlasType(const dtype& np_type) { - static auto* types = new absl::flat_hash_map, BlasType>({ - {{'f', 4}, BlasType::F32}, - {{'f', 8}, BlasType::F64}, - {{'c', 8}, BlasType::C64}, - {{'c', 16}, BlasType::C128}, - }); - auto it = types->find({np_type.kind(), np_type.itemsize()}); - if (it == types->end()) { - nb::str repr = nb::repr(np_type); - throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", repr.c_str())); - } - return it->second; -} - -// Returns the descriptor for a GetrfBatched operation. 
-std::pair BuildGetrfBatchedDescriptor(const dtype& dtype, - int b, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})}; -} - -// Returns the descriptor for a GetrfBatched operation. -std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, - int b, int m, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})}; -} - -nb::dict Registrations() { - nb::dict dict; - dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); - dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); - return dict; -} - -NB_MODULE(_blas, m) { - tsl::ImportNumpy(); - - m.def("registrations", &Registrations); - m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor); - m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); -} - -} // namespace -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_handle_pool.cc b/jaxlib/gpu/blas_handle_pool.cc index 2ce204453039..ff381b802ab2 100644 --- a/jaxlib/gpu/blas_handle_pool.cc +++ b/jaxlib/gpu/blas_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { diff --git a/jaxlib/gpu/blas_handle_pool.h b/jaxlib/gpu/blas_handle_pool.h index b3cdbaa88867..43724baab45e 100644 --- a/jaxlib/gpu/blas_handle_pool.h +++ b/jaxlib/gpu/blas_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc deleted file mode 100644 index ac30aa9cc520..000000000000 --- a/jaxlib/gpu/blas_kernels.cc +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "jaxlib/gpu/blas_kernels.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/blas_handle_pool.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/make_batch_pointers.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -namespace { - -int SizeOfBlasType(BlasType type) { - switch (type) { - case BlasType::F32: - return sizeof(float); - case BlasType::F64: - return sizeof(double); - case BlasType::C64: - return sizeof(gpublasComplex); - case BlasType::C128: - return sizeof(gpublasDoubleComplex); - } -} - -} // namespace - -// Batched LU decomposition: getrfbatched - -static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.n * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch, - SizeOfBlasType(d.type) * d.n * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::F64: { - double** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasDgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasZgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - } - return absl::OkStatus(); -} - -void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GetrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Batched QR decomposition: geqrfbatched - -static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.m * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - std::vector info(d.batch); - 
MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch, - SizeOfBlasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch, - SizeOfBlasType(d.type) * std::min(d.m, d.n)); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** a_batch_ptrs = static_cast(buffers[3]); - float** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::F64: { - double** a_batch_ptrs = static_cast(buffers[3]); - double** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** a_batch_ptrs = static_cast(buffers[3]); - gpublasComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** a_batch_ptrs = - static_cast(buffers[3]); - gpublasDoubleComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - } - auto it = - std::find_if(info.begin(), info.end(), [](int i) { return i != 0; }); - - if (it != info.end()) { - return absl::InvalidArgumentError( - absl::StrFormat("QR decomposition failed with status %d for batch " - "element %d", - *it, std::distance(info.begin(), it))); - } - - return absl::OkStatus(); -} - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_kernels.h b/jaxlib/gpu/blas_kernels.h deleted file mode 100644 index 724565ea73d1..000000000000 --- a/jaxlib/gpu/blas_kernels.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_GPU_BLAS_KERNELS_H_ -#define JAXLIB_GPU_BLAS_KERNELS_H_ - -#include - -#include "jaxlib/gpu/vendor.h" -#include "xla/service/custom_call_status.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { - -// Set of types known to Cusolver. 
-enum class BlasType { - F32, - F64, - C64, - C128, -}; - -// Batched LU decomposition: getrfbatched - -struct GetrfBatchedDescriptor { - BlasType type; - int batch, n; -}; - -void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Batched QR decomposition: geqrfbatched - -struct GeqrfBatchedDescriptor { - BlasType type; - int batch, m, n; -}; - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax - -#endif // JAXLIB_GPU_BLAS_KERNELS_H_ diff --git a/jaxlib/gpu/gpu_kernel_helpers.cc b/jaxlib/gpu/gpu_kernel_helpers.cc index 5a434f4b6ad5..5b509ad9912d 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.cc +++ b/jaxlib/gpu/gpu_kernel_helpers.cc @@ -15,12 +15,15 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include +#include + #include "absl/base/optimization.h" #include "absl/log/check.h" -#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "jaxlib/gpu/vendor.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/gpu_kernel_helpers.h b/jaxlib/gpu/gpu_kernel_helpers.h index aecb8a4fdcf1..0326d7f44620 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.h +++ b/jaxlib/gpu/gpu_kernel_helpers.h @@ -16,11 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ #define JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ -#include +#include #include "absl/base/optimization.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #define JAX_AS_STATUS(expr) \ diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 242078357254..620f9cf45199 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -16,7 +16,6 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. 
-#include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/linalg_kernels.h" #include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/gpu/rnn_kernels.h" @@ -33,24 +32,17 @@ namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, - "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", SyrkFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA", CsrlsvqrFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", OrgqrFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index b56cb8337f1b..cca615cfb260 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -20,13 +20,13 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" @@ -35,7 +35,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/py_client_gpu.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" @@ -202,13 +202,6 @@ absl::Status RegisterCustomTypeId(const PJRT_Api* c_api, return absl::OkStatus(); } -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(xla::XlaPythonGpuCallback); - return dict; -} - } // namespace void BuildGpuPluginExtension(nanobind::module_& m) { @@ -264,7 +257,6 @@ void BuildGpuPluginExtension(nanobind::module_& m) { type_name_size, std::move(type_id))); }, nb::arg("c_api"), nb::arg("type_name"), nb::arg("type_id")); - m.def("registrations", &Registrations); } } // namespace xla diff --git a/jaxlib/handle_pool.h b/jaxlib/gpu/handle_pool.h similarity index 96% rename from jaxlib/handle_pool.h rename to jaxlib/gpu/handle_pool.h index 9201d8d579c5..9189bb174b06 100644 --- a/jaxlib/handle_pool.h +++ b/jaxlib/gpu/handle_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef JAXLIB_HANDLE_POOL_H_ -#define JAXLIB_HANDLE_POOL_H_ +#ifndef JAXLIB_GPU_HANDLE_POOL_H_ +#define JAXLIB_GPU_HANDLE_POOL_H_ #include #include @@ -107,4 +107,4 @@ void HandlePool::Return(HandleType handle, } // namespace jax -#endif // JAXLIB_HANDLE_POOL_H_ +#endif // JAXLIB_GPU_HANDLE_POOL_H_ diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc index 94975a5b969f..71c320a60f02 100644 --- a/jaxlib/gpu/hybrid.cc +++ b/jaxlib/gpu/hybrid.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/gpu/hybrid_kernels.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index 2293bef89b7d..b48e64f2181d 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -90,8 +90,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CholeskyUpdateFfi, CholeskyUpdateFfiImpl, namespace { ffi::Error LuPivotsToPermutationImpl( - gpuStream_t stream, ffi::Dictionary /* unused */, - ffi::Buffer pivots, + gpuStream_t stream, ffi::Buffer pivots, ffi::Result> permutation) { FFI_ASSIGN_OR_RETURN((auto [batch_size, pivot_size]), SplitBatch1D(pivots.dimensions())); @@ -119,10 +118,6 @@ ffi::Error LuPivotsToPermutationImpl( XLA_FFI_DEFINE_HANDLER_SYMBOL(LuPivotsToPermutation, LuPivotsToPermutationImpl, ffi::Ffi::Bind() .Ctx>() - // TODO(b/358275922): remove Attrs (and the - // unused Dictionary above) 12 weeks after - // release of jaxlib v0.4.32. - .Attrs() .Arg>() .Ret>()); diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc index 3a24e355ead0..1d05fa8adcac 100644 --- a/jaxlib/gpu/make_batch_pointers.cu.cc +++ b/jaxlib/gpu/make_batch_pointers.cu.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/gpu/make_batch_pointers.h" #include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng.cc b/jaxlib/gpu/prng.cc index 1ce428d7f9dc..007e51b76de7 100644 --- a/jaxlib/gpu/prng.cc +++ b/jaxlib/gpu/prng.cc @@ -15,6 +15,7 @@ limitations under the License. #include "nanobind/nanobind.h" #include "jaxlib/gpu/prng_kernels.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" namespace jax { diff --git a/jaxlib/gpu/prng_kernels.cc b/jaxlib/gpu/prng_kernels.cc index f5d6abef83f8..1dac1e47bd44 100644 --- a/jaxlib/gpu/prng_kernels.cc +++ b/jaxlib/gpu/prng_kernels.cc @@ -17,16 +17,12 @@ limitations under the License. #include #include -#include #include "absl/algorithm/container.h" -#include "absl/status/status.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/ffi_helpers.h" -#include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/prng_kernels.cu.cc b/jaxlib/gpu/prng_kernels.cu.cc index d4aaec62320d..e42165f95d15 100644 --- a/jaxlib/gpu/prng_kernels.cu.cc +++ b/jaxlib/gpu/prng_kernels.cu.cc @@ -15,8 +15,7 @@ limitations under the License. 
#include "jaxlib/gpu/prng_kernels.h" -#include -#include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng_kernels.h b/jaxlib/gpu/prng_kernels.h index c98fd485700d..4d64d2b4a4e4 100644 --- a/jaxlib/gpu/prng_kernels.h +++ b/jaxlib/gpu/prng_kernels.h @@ -16,12 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_PRNG_KERNELS_H_ #define JAXLIB_GPU_PRNG_KERNELS_H_ -#include #include #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc new file mode 100644 index 000000000000..580e0130c3a8 --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.cc @@ -0,0 +1,229 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/py_client_gpu.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +struct GpuTransposePlanCache { + static xla::ffi::TypeId id; + explicit GpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; +xla::ffi::TypeId GpuTransposePlanCache::id = {}; + +XLA_FFI_REGISTER_TYPE(xla::ffi::GetXlaFfiApi(), "GpuTransposePlanCache", + &GpuTransposePlanCache::id); + +static xla::ffi::ErrorOr> +GpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kGpuTransposePlanCacheInstantiate, GpuTransposePlanCacheInstantiate, + xla::ffi::Ffi::BindInstantiate().Attr("index")); +xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, + xla::FfiLoadedHostCallbacks* callbacks, + GpuTransposePlanCache* transpose_cache, + uint64_t index, + xla::ffi::RemainingArgs args, + xla::ffi::RemainingRets rets) { + size_t arity = args.size(); + std::vector host_input_buffers(arity); + // Copy input GPU buffers to host + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. 
+ if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == xla::TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } + void* buf = new char[arg->size_bytes()]; + host_input_buffers[i] = buf; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. + auto gpu_res = + gpuMemcpyAsync(buf, arg.value().untyped_data(), arg->size_bytes(), + gpuMemcpyDeviceToHost, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + if (ptype == xla::TOKEN) { + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); + continue; + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, + host_input_buffers[i], base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); + } + + xla::EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + nb::tuple result_tuple; + try { + auto result_object = callback(*nb::borrow(host_input_arrays)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return xla::ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + xla::LeaveHostCallback(); + + std::vector temp_buffers; + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == xla::TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + auto array = xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. 
+ auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = xla::ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return xla::ffi::Error::Internal( + maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = xla::ByteStridesForShape(expected_shape); + if (strides == expected_strides) { + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + continue; + } + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + plan->Execute(array.data(), temp); + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + nb::gil_scoped_release release; + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); + } + return xla::ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback, + xla::ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + {kGpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonGpuCallback}); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_partitioned_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + {kGpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonGpuCallback}); +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h new file mode 100644 index 000000000000..4d48858ad278 --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.h @@ -0,0 +1,30 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ +#define JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ + + +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +XLA_FFI_DECLARE_HANDLER_SYMBOL(kGpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index eaa815d33e68..32e0842e3038 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/rnn_kernels.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index e9820bc31f1e..d06535a668ac 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -16,14 +16,19 @@ limitations under the License. #include "jaxlib/gpu/rnn_kernels.h" #include +#include +#include #include #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index e95b7788382a..36d8c25c6a9f 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -17,6 +17,7 @@ limitations under the License. #define JAXLIB_GPU_RNN_KERNELS_H_ #include +#include #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 357a38eecfd5..3c76598e5285 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/solver_kernels.h" @@ -54,84 +54,6 @@ SolverType DtypeToSolverType(const dtype& np_type) { return it->second; } -// getrf: LU decomposition - -// Returns the workspace size and a descriptor for a getrf operation. 
-std::pair BuildGetrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})}; -} - -// geqrf: QR decomposition - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildGeqrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})}; -} - #ifdef JAX_GPU_CUDA // csrlsvqr: Linear system solve via Sparse QR @@ -145,49 +67,6 @@ nb::bytes BuildCsrlsvqrDescriptor(const dtype& dtype, int n, int nnzA, #endif // JAX_GPU_CUDA -// orgqr/ungqr: apply elementary Householder transformations - -// Returns the workspace size and a descriptor for a geqrf operation. 
-std::pair BuildOrgqrDescriptor(const dtype& dtype, int b, int m, - int n, int k) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - } - return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})}; -} - // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd // Returns the workspace size and a descriptor for a syevd operation. @@ -462,9 +341,6 @@ std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "solver_getrf"] = EncapsulateFunction(Getrf); - dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf); - dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr); dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd); dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj); dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd); @@ -496,9 +372,6 @@ nb::dict Registrations() { NB_MODULE(_solver, m) { tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_getrf_descriptor", &BuildGetrfDescriptor); - m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor); - m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor); m.def("build_syevd_descriptor", &BuildSyevdDescriptor); m.def("build_syevj_descriptor", &BuildSyevjDescriptor); m.def("build_gesvd_descriptor", &BuildGesvdDescriptor); diff --git a/jaxlib/gpu/solver_handle_pool.cc b/jaxlib/gpu/solver_handle_pool.cc index c55ea923b21b..416ccf9d1bbc 100644 --- a/jaxlib/gpu/solver_handle_pool.cc +++ b/jaxlib/gpu/solver_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/solver_handle_pool.h b/jaxlib/gpu/solver_handle_pool.h index c46c062b3054..4e369ea85520 100644 --- a/jaxlib/gpu/solver_handle_pool.h +++ b/jaxlib/gpu/solver_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index 8c22dfcdbca7..040b5a137bc6 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -50,175 +50,6 @@ static int SizeOfSolverType(SolverType type) { } } -// getrf: LU decomposition - -static absl::Status Getrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgetrf( - handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Getrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// geqrf: QR decomposition - -static absl::Status Geqrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = 
static_cast(buffers[1]); - float* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - gpuComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - gpuDoubleComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Geqrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - #ifdef JAX_GPU_CUDA // csrlsvqr: Linear system solve via Sparse QR @@ -320,92 +151,6 @@ void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, #endif // JAX_GPU_CUDA -// orgqr/ungqr: apply elementary Householder transformations - -static absl::Status Orgqr_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const OrgqrDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[2] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[2], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[2]); - float* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[2]); - double* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[2]); - gpuComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCungqr( 
- handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[2]); - gpuDoubleComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZungqr( - handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Orgqr_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd static absl::Status Syevd_(gpuStream_t stream, void** buffers, diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h index 51082f2fe812..a68aaf1ca233 100644 --- a/jaxlib/gpu/solver_kernels.h +++ b/jaxlib/gpu/solver_kernels.h @@ -33,26 +33,6 @@ enum class SolverType { C128, }; -// getrf: LU decomposition - -struct GetrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// geqrf: QR decomposition - -struct GeqrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - #ifdef JAX_GPU_CUDA // csrlsvpr: Linear system solve via Sparse QR @@ -68,16 +48,6 @@ void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, #endif // JAX_GPU_CUDA -// orgqr/ungqr: apply elementary Householder transformations - -struct OrgqrDescriptor { - SolverType type; - int batch, m, n, k, lwork; -}; - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd struct SyevdDescriptor { diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index 429c8018dc7a..592c0f454a55 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -13,24 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include +#include #include +#include #include -#include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "absl/base/casts.h" -#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_helpers.h" #include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/service/custom_call_status.h" #include "xla/tsl/python/lib/core/numpy.h" namespace nb = nanobind; diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 5b620a05236d..a9c08317e066 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -15,11 +15,9 @@ limitations under the License. #include "jaxlib/gpu/sparse_kernels.h" -#include -#include -#include -#include -#include +#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -27,8 +25,8 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 323431812758..d735c320307c 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -16,15 +16,12 @@ limitations under the License. #ifndef JAXLIB_GPU_SPARSE_KERNELS_H_ #define JAXLIB_GPU_SPARSE_KERNELS_H_ -#include +#include #include -#include -#include -#include #include "absl/status/statusor.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" #include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 500034af3ebb..b3f313e4f7ea 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -1,17 +1,35 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include #include -#include #include #include #include +#include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "nanobind/stl/string.h" -#include "nanobind/stl/string_view.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 22397ff908bc..9e0dc6c855ac 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_kernels.h" #include @@ -25,6 +40,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" @@ -37,7 +53,8 @@ #endif // JAX_GPU_CUDA #ifdef JAX_GPU_HIP -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #endif // JAX_GPU_HIP #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index c3457093c4f8..3ab3e9143fb8 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -1,8 +1,23 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_H_ #define JAXLIB_GPU_TRITON_H_ +#include #include -#include #include #include #include @@ -10,7 +25,6 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/triton_utils.cc b/jaxlib/gpu/triton_utils.cc index b3a0779118de..fd63435da177 100644 --- a/jaxlib/gpu/triton_utils.cc +++ b/jaxlib/gpu/triton_utils.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_utils.h" #include @@ -9,6 +24,7 @@ #include "absl/strings/string_view.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" +#include "jaxlib/gpu/vendor.h" namespace jax::JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/triton_utils.h b/jaxlib/gpu/triton_utils.h index 0c286391e296..a79c098373d1 100644 --- a/jaxlib/gpu/triton_utils.h +++ b/jaxlib/gpu/triton_utils.h @@ -1,9 +1,23 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_UTILS_H_ #define JAXLIB_GPU_TRITON_UTILS_H_ #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 7334d4690b59..5deb8d4c650a 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -20,6 +20,7 @@ limitations under the License. #ifndef JAXLIB_GPU_VENDOR_H_ #define JAXLIB_GPU_VENDOR_H_ +#include #if defined(JAX_GPU_CUDA) // IWYU pragma: begin_exports @@ -29,7 +30,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_fp8.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "cuda_runtime_api.h" #include "third_party/gpus/cuda/include/cufft.h" #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolver_common.h" @@ -48,6 +49,7 @@ limitations under the License. 
#define JAX_GPU_NAMESPACE cuda #define JAX_GPU_PREFIX "cu" +#define JAX_GPU_PLUGIN_NAME "cuda" typedef cuComplex gpuComplex; typedef cuDoubleComplex gpuDoubleComplex; @@ -413,6 +415,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #define JAX_GPU_NAMESPACE hip #define JAX_GPU_PREFIX "hip" +#define JAX_GPU_PLUGIN_NAME "rocm" #define JAX_GPU_HAVE_SPARSE 1 #define JAX_GPU_HAVE_64_BIT 0 diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 6f74d5813ce4..b112534c0575 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -12,79 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -from functools import partial -import itertools +from typing import Any -import jaxlib.mlir.ir as ir - -from jaxlib import xla_client - -from .hlo_helpers import custom_call from .plugin_support import import_from_plugin _cuda_prng = import_from_plugin("cuda", "_prng") _hip_prng = import_from_plugin("rocm", "_prng") -if _cuda_prng: - for _name, _value in _cuda_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) - -if _hip_prng: - for _name, _value in _hip_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) - - -def _threefry2x32_lowering(prng, platform: str, keys, data, - length: int | ir.Value | None = None, - output_shape: ir.Value | None = None, - forward_compatibility_mode: bool = False): - """ThreeFry2x32 kernel for GPU. - - In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape` - is a 1D tensor describing the shape of the two outputs. - """ - del forward_compatibility_mode - assert len(keys) == 2, keys - assert len(data) == 2, data - assert (ir.RankedTensorType(keys[0].type).element_type == - ir.IntegerType.get_unsigned(32)), keys[0].type - - typ = keys[0].type - dims = ir.RankedTensorType(typ).shape - - for x in itertools.chain(keys, data): - assert x.type == typ, (x.type, typ) - ndims = len(dims) - layout = tuple(range(ndims - 1, -1, -1)) - operand_layouts = [layout] * 4 - operands = [keys[0], keys[1], data[0], data[1]] - - opaque = {} # Use if not forward_compatibility_mode to trigger the FFI (v4). - if isinstance(length, int): - result_shapes = None - else: - assert output_shape is not None - # We also need to pass separately the shapes of the outputs. 
- result_shapes = [output_shape, output_shape] - - custom_call_target = f"{platform}_threefry2x32_ffi" - return custom_call( - custom_call_target, - api_version=4, - result_types=[typ, typ], - operands=operands, - backend_config=opaque, - operand_layouts=operand_layouts, - result_layouts=[layout] * 2, - result_shapes=result_shapes).results - - -cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu") -rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip") +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cuda_prng), ("ROCM", _hip_prng)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items()) + return registrations diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index efb58f9a4164..c846c63e2ff8 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -16,21 +16,15 @@ from .plugin_support import import_from_plugin -_cublas = import_from_plugin("cuda", "_blas") _cusolver = import_from_plugin("cuda", "_solver") _cuhybrid = import_from_plugin("cuda", "_hybrid") -_hipblas = import_from_plugin("rocm", "_blas") _hipsolver = import_from_plugin("rocm", "_solver") _hiphybrid = import_from_plugin("rocm", "_hybrid") def registrations() -> dict[str, list[tuple[str, Any, int]]]: registrations = {"CUDA": [], "ROCM": []} - for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]: - if module: - registrations[platform].extend( - (*i, 0) for i in module.registrations().items()) for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]: if module: registrations[platform].extend( diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index d8645041c946..cc2b2ad08e55 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -11,25 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" -cusparse wrappers for performing sparse matrix computations in JAX -""" -import math -from functools import partial from typing import Any -import jaxlib.mlir.ir as ir - -import numpy as np - -from .hlo_helpers import custom_call, mk_result_types_and_shapes - from .plugin_support import import_from_plugin _cusparse = import_from_plugin("cuda", "_sparse") _hipsparse = import_from_plugin("rocm", "_sparse") +cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) +rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) + def registrations() -> dict[str, list[tuple[str, Any, int]]]: registrations = {"CUDA": [], "ROCM": []} for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]: @@ -38,346 +30,3 @@ def registrations() -> dict[str, list[tuple[str, Any, int]]]: (name, value, int(name.endswith("_ffi"))) for name, value in module.registrations().items()) return registrations # pytype: disable=bad-return-type - - -cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) -rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) - - -def _validate_csr_hlo(data, indices, indptr, shape): - data_type = ir.RankedTensorType(data.type) - indices_type = ir.RankedTensorType(indices.type) - indptr_type = ir.RankedTensorType(indptr.type) - - nnz, = data_type.shape - assert indices_type.shape == [nnz] - assert indptr_type.element_type == indices_type.element_type - assert indptr_type.shape == [shape[0] + 1] - return data_type.element_type, indices_type.element_type, nnz - -def _validate_coo_hlo(data, row, col): - data_type = ir.RankedTensorType(data.type) - row_type = ir.RankedTensorType(row.type) - col_type = ir.RankedTensorType(col.type) - - nnz, = data_type.shape - assert row_type.shape == [nnz] - assert col_type.element_type == row_type.element_type - assert col_type.shape == [nnz] - return data_type.element_type, row_type.element_type, nnz - - -def _csr_todense_hlo(platform, gpu_sparse, data, indices, indptr, *, shape, - data_dtype, index_dtype): - """CSR to dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_csr_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_todense = partial(_csr_todense_hlo, "cu", _cusparse) -rocm_csr_todense = partial(_csr_todense_hlo, "hip", _hipsparse) - - -def _csr_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, index_dtype, - data_dtype, index_type): - """CSR from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = mat_type.shape - - buffer_size, opaque = gpu_sparse.build_csr_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([rows + 1], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - 
result_layouts=[[0]] * 4).results - return out[:3] - -cuda_csr_fromdense = partial(_csr_fromdense_hlo, "cu", _cusparse) -rocm_csr_fromdense = partial(_csr_fromdense_hlo, "hip", _hipsparse) - - -def _csr_matvec_hlo(platform, gpu_sparse, data, indices, indptr, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - data_dtype, index_dtype, x_dtype): - """CSR matrix/vector multiply.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_csr_matvec = partial(_csr_matvec_hlo, "cu", _cusparse) -rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse) - - -def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, B_dtype): - """CSR from dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor( - data_dtype, B_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matmat_ffi", - result_types=[ - ir.RankedTensorType.get([out_size, Ccols], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_matmat = partial(_csr_matmat_hlo, "cu", _cusparse) -rocm_csr_matmat = partial(_csr_matmat_hlo, "hip", _hipsparse) - - -def _coo_todense_hlo(platform, gpu_sparse, data, row, col, *, shape, - data_dtype, index_dtype): - """COO to dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_coo_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_coo_todense = partial(_coo_todense_hlo, "cu", _cusparse) -rocm_coo_todense = partial(_coo_todense_hlo, "hip", _hipsparse) - - -def _coo_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, data_dtype, - index_dtype, index_type): - """COO from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = 
mat_type.shape - - buffer_size, opaque = gpu_sparse.build_coo_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - result_layouts=[[0]] * 4).results - return out[:3] - -cuda_coo_fromdense = partial(_coo_fromdense_hlo, "cu", _cusparse) -rocm_coo_fromdense = partial(_coo_fromdense_hlo, "hip", _hipsparse) - - -def _coo_matvec_hlo(platform, gpu_sparse, data, row, col, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, x_dtype): - """COO matrix/vector multiply.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_coo_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_coo_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_coo_matvec = partial(_coo_matvec_hlo, "cu", _cusparse) -rocm_coo_matvec = partial(_coo_matvec_hlo, "hip", _hipsparse) - - -def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - x_dtype, data_dtype, index_dtype): - """COO from dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - is_batched_matmat = False - batch_count = 1 - if len(shape) == 2: - rows, cols = shape - elif len(shape) == 3: - is_batched_matmat = True - batch_count, rows, cols = shape - # Redefine nnz as nnz per batch. - nnz = nnz // batch_count - - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - # TODO(tianjianlu): use batch stride to trigger different mode of batch - # computation. Currently batch_stride = 0 is not allowed because of the issue - # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 - # Set batch stride to be the matrix size for now. 
- lhs_batch_stride = nnz - B_rows = rows if transpose else cols - rhs_batch_stride = B_rows * Ccols - - buffer_size, opaque = gpu_sparse.build_coo_matmat_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride, - rhs_batch_stride) - out_size = cols if transpose else rows - - if is_batched_matmat: - out_shape = [batch_count, out_size, Ccols] - out_layout = [2, 1, 0] - else: - out_shape = [out_size, Ccols] - out_layout = [1, 0] - - out = custom_call( - f"{platform}sparse_coo_matmat_ffi", - result_types=[ - ir.RankedTensorType.get(out_shape, compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[out_layout, [0]]).results - return out[0] - -cuda_coo_matmat = partial(_coo_matmat_hlo, "cu", _cusparse) -rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse) - - -def _gtsv2_hlo( - platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t, b_shape_vals=None): - """Calls `cusparsegtsv2(dl, d, du, B, m, n, ldb)`.""" - assert len(b_shape_vals) >= 2 - batch_dim_vals = b_shape_vals[:-2] - batch_size = math.prod(batch_dim_vals) - num_bd = len(b_shape_vals) - 2 - f32 = (t == np.float32) - if f32: - buffer_size = gpu_sparse.gtsv2_f32_buffer_size(m, n, ldb) - else: - buffer_size = gpu_sparse.gtsv2_f64_buffer_size(m, n, ldb) - - b_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - d_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1)) - b_type = ir.RankedTensorType(B.type) - - shape_type_pairs = [ - (batch_dim_vals + (ldb, n), b_type.element_type), - ((buffer_size,), ir.IntegerType.get_signless(8)) - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - opaque = gpu_sparse.build_gtsv2_descriptor(batch_size, m, n, ldb) - out = custom_call( - f"{platform}sparse_gtsv2_" + ("f32" if f32 else "f64") + "_ffi", - result_types=result_types, - operands=[dl, d, du, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[d_layout] * 3 + [b_layout], - result_layouts=[b_layout, [0]], - operand_output_aliases={3: 0}, - result_shapes=result_shapes).results - return out[0] - -cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse) -rocm_gtsv2 = partial(_gtsv2_hlo, "hip", _hipsparse) diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index 0d57a04f1aa7..11ff844ae53f 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -19,11 +19,22 @@ from collections.abc import Callable, Sequence from functools import partial from typing import Union +import warnings import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np +# TODO(danfm): This module isn't covered by JAX's compatibility policy, so no +# formal deprecation period is required, but there are enough users that we +# should keep this warning for at least one full release cycle. +# Deprecation added 2025-03-19 after the release of v0.5.3. Remove this whole +# module after the release of v0.5.4 or later. +warnings.warn( + "The jaxlib.hlo_helpers submodule is deprecated. 
Instead, use jax.ffi if " + "possible or, for lower-level operations, jax.interpreters.mlir.", + DeprecationWarning, +) _dtype_to_ir_type_factory : dict[np.dtype, Callable[[], ir.Type]] = { np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1), diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 89f1545995d5..3c234f5f8c37 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -31,7 +31,7 @@ load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_c cc_proto_library = _cc_proto_library cuda_library = _cuda_library rocm_library = _rocm_library -pytype_test = native.py_test +proto_library = native.proto_library nanobind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured if_rocm_is_configured = _if_rocm_is_configured @@ -64,28 +64,45 @@ PLATFORM_TAGS_DICT = { ("Windows", "AMD64"): ("win", "amd64"), } -# TODO(vam): remove this once zstandard builds against Python 3.13 +_GPU_PYPI_WHEEL_DEPS = [ + "//:jax_wheel_with_internal_test_util", + "@pypi_jaxlib//:pkg", + "@pypi_jax_cuda12_plugin//:pkg", + "@pypi_jax_cuda12_pjrt//:pkg", +] + +_CPU_PYPI_WHEEL_DEPS = [ + "//:jax_wheel_with_internal_test_util", + "@pypi_jaxlib//:pkg", +] + +# TODO(vam): remove this once zstandard builds against Python >3.13 def get_zstandard(): - if HERMETIC_PYTHON_VERSION == "3.13" or HERMETIC_PYTHON_VERSION == "3.13-ft": + if HERMETIC_PYTHON_VERSION in ("3.13", "3.13-ft", "3.14", "3.14-ft"): return [] return ["@pypi_zstandard//:pkg"] +def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]): + if HERMETIC_PYTHON_VERSION in excluded_py_versions: + return [] + return [package] + _py_deps = { "absl/logging": ["@pypi_absl_py//:pkg"], "absl/testing": ["@pypi_absl_py//:pkg"], "absl/flags": ["@pypi_absl_py//:pkg"], - "cloudpickle": ["@pypi_cloudpickle//:pkg"], - "colorama": ["@pypi_colorama//:pkg"], - "epath": ["@pypi_etils//:pkg"], # etils.epath - "filelock": ["@pypi_filelock//:pkg"], + "cloudpickle": get_optional_dep("@pypi_cloudpickle//:pkg"), + "colorama": get_optional_dep("@pypi_colorama//:pkg"), + "epath": get_optional_dep("@pypi_etils//:pkg"), # etils.epath + "filelock": get_optional_dep("@pypi_filelock//:pkg"), "flatbuffers": ["@pypi_flatbuffers//:pkg"], "hypothesis": ["@pypi_hypothesis//:pkg"], "magma": [], - "matplotlib": ["@pypi_matplotlib//:pkg"], + "matplotlib": get_optional_dep("@pypi_matplotlib//:pkg"), "mpmath": [], "opt_einsum": ["@pypi_opt_einsum//:pkg"], - "pil": ["@pypi_pillow//:pkg"], - "portpicker": ["@pypi_portpicker//:pkg"], + "pil": get_optional_dep("@pypi_pillow//:pkg"), + "portpicker": get_optional_dep("@pypi_portpicker//:pkg"), "ml_dtypes": ["@pypi_ml_dtypes//:pkg"], "numpy": ["@pypi_numpy//:pkg"], "scipy": ["@pypi_scipy//:pkg"], @@ -132,6 +149,9 @@ def pytype_strict_library(name, pytype_srcs = [], **kwargs): new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} native.py_library(name = name, data = data, **new_kwargs) +py_strict_library = native.py_library +py_strict_test = native.py_test + def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} @@ -140,119 +160,63 @@ def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pyt def py_extension(name, srcs, copts, deps, linkopts = []): nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name) -def windows_cc_shared_mlir_library(name, 
out, deps = [], srcs = [], exported_symbol_prefixes = []): - """Workaround DLL building issue. - - 1. cc_binary with linkshared enabled cannot produce DLL with symbol - correctly exported. - 2. Even if the DLL is correctly built, the resulting target cannot be - correctly consumed by other targets. - - Args: - name: the name of the output target - out: the name of the output DLL filename - deps: deps - srcs: srcs - """ - - # create a dummy library to get the *.def file - dummy_library_name = name + ".dummy.dll" - native.cc_binary( - name = dummy_library_name, - linkshared = 1, - linkstatic = 1, - deps = deps, - target_compatible_with = ["@platforms//os:windows"], - ) - - # .def file with all symbols, not usable - full_def_name = name + ".full.def" - native.filegroup( - name = full_def_name, - srcs = [dummy_library_name], - output_group = "def_file", - target_compatible_with = ["@platforms//os:windows"], - ) - - # say filtered_symbol_prefixes == ["mlir", "chlo"], then construct the regex - # pattern as "^\\s*(mlir|clho)" to use grep - pattern = "^\\s*(" + "|".join(exported_symbol_prefixes) + ")" - - # filtered def_file, only the needed symbols are included - filtered_def_name = name + ".filtered.def" - filtered_def_file = out + ".def" - native.genrule( - name = filtered_def_name, - srcs = [full_def_name], - outs = [filtered_def_file], - cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep -E '{}' $(location :{}) >> $@""".format(out, pattern, full_def_name), - target_compatible_with = ["@platforms//os:windows"], - ) - - # create the desired library - native.cc_binary( - name = out, # this name must be correct, it will be the filename - linkshared = 1, - deps = deps, - win_def_file = filtered_def_file, - target_compatible_with = ["@platforms//os:windows"], - ) - - # however, the created cc_library (a shared library) cannot be correctly - # consumed by other cc_*... - interface_library_file = out + ".if.lib" - native.filegroup( - name = interface_library_file, - srcs = [out], - output_group = "interface_library", - target_compatible_with = ["@platforms//os:windows"], - ) - - # but this one can be correctly consumed, this is our final product - native.cc_import( - name = name, - interface_library = interface_library_file, - shared_library = out, - target_compatible_with = ["@platforms//os:windows"], - ) - ALL_BACKENDS = ["cpu", "gpu", "tpu"] def if_building_jaxlib( if_building, - if_not_building = [ - "@pypi_jaxlib//:pkg", - "@pypi_jax_cuda12_plugin//:pkg", - "@pypi_jax_cuda12_pjrt//:pkg", - ], - if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"], - if_py_import = [ - "//jaxlib/tools:jaxlib_py_import", - "//jaxlib/tools:jax_cuda_plugin_py_import", - "//jaxlib/tools:jax_cuda_pjrt_py_import", - ], - if_py_import_for_cpu = [ - "//jaxlib/tools:jaxlib_py_import", - ]): + if_not_building = _GPU_PYPI_WHEEL_DEPS, + if_not_building_for_cpu = _CPU_PYPI_WHEEL_DEPS): """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources. This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase. 
Args: if_building: the source code targets to depend on in case we don't depend on the jaxlib wheels - if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of + if_not_building: the wheels to depend on including gpu-specific plugins in case of gpu-enabled builds - if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds - if_py_import: the py_import targets to depend on in case of gpu-enabled builds - if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds + if_not_building_for_cpu: the wheels to depend on in case of cpu-only builds """ return select({ "//jax:enable_jaxlib_build": if_building, "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu, "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building, - "//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu, - "//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import, + "//conditions:default": [], + }) + +def _get_test_deps(deps, backend_independent): + gpu_build_deps = [ + "//jaxlib/cuda:gpu_only_test_deps", + "//jaxlib/rocm:gpu_only_test_deps", + "//jax_plugins:gpu_plugin_only_test_deps", + ] + + gpu_py_imports = [ + "//:jax_py_import", + "//jaxlib/tools:jaxlib_py_import", + "//jaxlib/tools:jax_cuda_plugin_py_import", + "//jaxlib/tools:jax_cuda_pjrt_py_import", + ] + cpu_py_imports = [ + "//:jax_py_import", + "//jaxlib/tools:jaxlib_py_import", + ] + + if backend_independent: + jaxlib_build_deps = deps + gpu_pypi_wheel_deps = _CPU_PYPI_WHEEL_DEPS + gpu_py_import_deps = cpu_py_imports + else: + jaxlib_build_deps = gpu_build_deps + deps + gpu_pypi_wheel_deps = _GPU_PYPI_WHEEL_DEPS + gpu_py_import_deps = gpu_py_imports + + return select({ + "//jax:enable_jaxlib_build": jaxlib_build_deps, + "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": _CPU_PYPI_WHEEL_DEPS, + "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": gpu_pypi_wheel_deps, + "//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports, + "//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_import_deps, }) # buildifier: disable=function-docstring @@ -305,14 +269,10 @@ def jax_multiplatform_test( srcs = srcs, args = test_args, env = env, - deps = [ + deps = _get_test_deps([ "//jax", "//jax:test_util", - ] + deps + if_building_jaxlib([ - "//jaxlib/cuda:gpu_only_test_deps", - "//jaxlib/rocm:gpu_only_test_deps", - "//jax_plugins:gpu_plugin_only_test_deps", - ]), + ] + deps, backend_independent = False), data = data, shard_count = test_shards, tags = test_tags, @@ -362,7 +322,7 @@ def _get_full_wheel_name( free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "", ) -def _get_source_distribution_name(package_name, wheel_version): +def _get_source_package_name(package_name, wheel_version): return "{package_name}-{wheel_version}.tar.gz".format( package_name = package_name, wheel_version = wheel_version, @@ -394,37 +354,47 @@ def _jax_wheel_impl(ctx): no_abi = ctx.attr.no_abi platform_independent = ctx.attr.platform_independent build_wheel_only = ctx.attr.build_wheel_only + build_source_package_only = ctx.attr.build_source_package_only editable = ctx.attr.editable platform_name = ctx.attr.platform_name + + output_dir_path = "" + outputs = [] if editable: output_dir = ctx.actions.declare_directory(output_path + "/" + ctx.attr.wheel_name) - wheel_dir = output_dir.path + output_dir_path = output_dir.path outputs = [output_dir] args.add("--editable") else: - wheel_name = _get_full_wheel_name( - package_name = 
ctx.attr.wheel_name, - no_abi = no_abi, - platform_independent = platform_independent, - platform_name = platform_name, - cpu_name = cpu, - wheel_version = full_wheel_version, - py_freethreaded = py_freethreaded, - ) - wheel_file = ctx.actions.declare_file(output_path + - "/" + wheel_name) - wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")] - outputs = [wheel_file] - if not build_wheel_only: - source_distribution_name = _get_source_distribution_name( + if build_wheel_only: + wheel_name = _get_full_wheel_name( package_name = ctx.attr.wheel_name, + no_abi = no_abi, + platform_independent = platform_independent, + platform_name = platform_name, + cpu_name = cpu, wheel_version = full_wheel_version, + py_freethreaded = py_freethreaded, ) - source_distribution_file = ctx.actions.declare_file(output_path + - "/" + source_distribution_name) - outputs.append(source_distribution_file) - - args.add("--output_path", wheel_dir) # required argument + wheel_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + output_dir_path = wheel_file.path[:wheel_file.path.rfind("/")] + outputs = [wheel_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-wheel-only", "True") + if build_source_package_only: + source_package_name = _get_source_package_name( + package_name = ctx.attr.wheel_name, + wheel_version = full_wheel_version, + ) + source_package_file = ctx.actions.declare_file(output_path + + "/" + source_package_name) + output_dir_path = source_package_file.path[:source_package_file.path.rfind("/")] + outputs = [source_package_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-source-package-only", "True") + + args.add("--output_path", output_dir_path) # required argument if not platform_independent: args.add("--cpu", cpu) args.add("--jaxlib_git_hash", git_hash) # required argument @@ -472,16 +442,17 @@ _jax_wheel = rule( "wheel_name": attr.string(mandatory = True), "no_abi": attr.bool(default = False), "platform_independent": attr.bool(default = False), - "build_wheel_only": attr.bool(default = True), + "build_wheel_only": attr.bool(mandatory = True, default = True), + "build_source_package_only": attr.bool(mandatory = True, default = False), "editable": attr.bool(default = False), - "cpu": attr.string(mandatory = True), - "platform_name": attr.string(mandatory = True), + "cpu": attr.string(), + "platform_name": attr.string(), "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), "source_files": attr.label_list(allow_files = True), "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. - "platform_version": attr.string(mandatory = True, default = ""), + "platform_version": attr.string(), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), @@ -498,7 +469,6 @@ def jax_wheel( wheel_name, no_abi = False, platform_independent = False, - build_wheel_only = True, editable = False, enable_cuda = False, enable_rocm = False, @@ -509,11 +479,10 @@ def jax_wheel( Common artifact attributes are grouped within a single macro. 
Args: - name: the name of the wheel + name: the target name wheel_binary: the binary to use to build the wheel wheel_name: the name of the wheel no_abi: whether to build a wheel without ABI - build_wheel_only: whether to build a wheel without source distribution editable: whether to build an editable wheel platform_independent: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel @@ -522,7 +491,7 @@ def jax_wheel( source_files: the source files to include in the wheel Returns: - A directory containing the wheel + A wheel file or a wheel directory. """ _jax_wheel( name = name, @@ -530,7 +499,8 @@ def jax_wheel( wheel_name = wheel_name, no_abi = no_abi, platform_independent = platform_independent, - build_wheel_only = build_wheel_only, + build_wheel_only = True, + build_source_package_only = False, editable = editable, enable_cuda = enable_cuda, enable_rocm = enable_rocm, @@ -554,6 +524,34 @@ def jax_wheel( source_files = source_files, ) +def jax_source_package( + name, + source_package_binary, + source_package_name, + source_files = []): + """Create jax source package. + + Common artifact attributes are grouped within a single macro. + + Args: + name: the target name + source_package_binary: the binary to use to build the package + source_package_name: the name of the source package + source_files: the source files to include in the package + + Returns: + A jax source package file. + """ + _jax_wheel( + name = name, + wheel_binary = source_package_binary, + wheel_name = source_package_name, + build_source_package_only = True, + build_wheel_only = False, + platform_independent = True, + source_files = source_files, + ) + jax_test_file_visibility = [] jax_export_file_visibility = [] @@ -568,4 +566,22 @@ def jax_py_test( env = dict(env) if "PYTHONWARNINGS" not in env: env["PYTHONWARNINGS"] = "error" + deps = kwargs.get("deps", []) + test_deps = _get_test_deps(deps, backend_independent = True) + kwargs["deps"] = test_deps py_test(name = name, env = env, **kwargs) + +def pytype_test(name, **kwargs): + deps = kwargs.get("deps", []) + test_deps = _get_test_deps(deps, backend_independent = True) + kwargs["deps"] = test_deps + native.py_test(name = name, **kwargs) + +def if_oss(oss_value, google_value = []): + """Returns one of the arguments based on the non-configurable build env. + + Specifically, it does not return a `select`, and can be used to e.g. + compute elements of list attributes. + """ + _ = (google_value, oss_value) # buildifier: disable=unused-variable + return oss_value diff --git a/jaxlib/jax_common.json b/jaxlib/jax_common.json new file mode 100644 index 000000000000..61a2c9313897 --- /dev/null +++ b/jaxlib/jax_common.json @@ -0,0 +1,8 @@ +{ + "global": [ + "Wrapped_PyInit_*" + ], + "local": [ + "*" + ] +} diff --git a/jaxlib/kernel_helpers.h b/jaxlib/kernel_helpers.h index dac0355fbde6..5a053f833ce4 100644 --- a/jaxlib/kernel_helpers.h +++ b/jaxlib/kernel_helpers.h @@ -17,10 +17,10 @@ limitations under the License. #define JAXLIB_KERNEL_HELPERS_H_ #include -#include #include #include "absl/base/casts.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" namespace jax { diff --git a/jaxlib/kernel_nanobind_helpers.h b/jaxlib/kernel_nanobind_helpers.h index fde37e695349..127d89f702c8 100644 --- a/jaxlib/kernel_nanobind_helpers.h +++ b/jaxlib/kernel_nanobind_helpers.h @@ -19,8 +19,8 @@ limitations under the License. 
#include #include -#include "nanobind/nanobind.h" #include "absl/base/casts.h" +#include "nanobind/nanobind.h" #include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT diff --git a/jaxlib/libjax_common.lds b/jaxlib/libjax_common.lds new file mode 100644 index 000000000000..6130415a8d26 --- /dev/null +++ b/jaxlib/libjax_common.lds @@ -0,0 +1,7 @@ +{ + global: + Wrapped_PyInit_*; + + local: + *; +}; diff --git a/jaxlib/libjax_common_darwin.lds b/jaxlib/libjax_common_darwin.lds new file mode 100644 index 000000000000..aed9a1d7512a --- /dev/null +++ b/jaxlib/libjax_common_darwin.lds @@ -0,0 +1 @@ +*Wrapped_PyInit_* diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index fb94837cff37..25f2162685b9 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -15,10 +15,8 @@ load( "//jaxlib:jax.bzl", "if_windows", - "nanobind_extension", - "py_extension", - "windows_cc_shared_mlir_library", ) +load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") load("//jaxlib:symlink_files.bzl", "symlink_inputs") package( @@ -33,134 +31,107 @@ COPTS = [ "-frtti", ] -LINKOPTS = select({ - "@xla//xla/tsl:macos": [ - "-Wl,-rpath,@loader_path/", - "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text", - ], - "@xla//xla/tsl:windows": [], - "//conditions:default": [ - "-Wl,-rpath,$$ORIGIN/", - ], -}) - -py_extension( +nanobind_pywrap_extension( name = "_mlir", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI", + "@llvm-project//mlir:MLIRBindingsPythonCore", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirGPUPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", + "@llvm-project//mlir:CAPIGPU", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsNVGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPINVGPU", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsLLVM", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( 
+nanobind_pywrap_extension( name = "_mlirDialectsSparseTensor", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirSparseTensorPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mosaic_gpu_ext", srcs = ["mosaic_gpu_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", + "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", ], @@ -171,17 +142,15 @@ py_extension( # :jaxlib_mlir_capi_shared_library). This ensures that the RPATH works correctly # across platforms. It's not clear if Windows supports RPATH-like functionality # across different directories at all. -py_extension( +nanobind_pywrap_extension( name = "_tpu_ext", srcs = ["tpu_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic:tpu_dialect_capi_headers", + "//jaxlib/mosaic:tpu_dialect_capi", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", "@xla//xla/python:nb_numpy", @@ -190,7 +159,7 @@ py_extension( ) # This target contains the extension and it's Python dependencies, which are not -# supported by the `py_extension`/`nanobind_extension` macros. +# supported by the `nanobind_pywrap_extension`/`nanobind_extension` macros. 
py_library( name = "_tpu_ext_lib", deps = [ @@ -200,19 +169,21 @@ py_library( ], ) -nanobind_extension( +nanobind_pywrap_extension( name = "_triton_ext", srcs = ["triton_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, pytype_srcs = ["_triton_ext.pyi"], deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/triton:triton_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", - ], + ] + if_windows( + [], + [ + "//jaxlib/triton:triton_dialect_capi", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", + ], + ), ) symlink_inputs( @@ -229,55 +200,28 @@ symlink_inputs( ], ) -cc_library( - name = "jaxlib_mlir_capi_shims", - srcs = ["jaxlib_mlir_capi_shims.cc"], - hdrs = ["jaxlib_mlir_capi_shims.h"], - deps = [ - "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:GPUPipelines", - "@llvm-project//mlir:GPUToLLVMIRTranslation", - "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:NVVMTarget", - "@llvm-project//mlir:NVVMToLLVMIRTranslation", - ], - alwayslink = 1, -) - -cc_library( - name = "jaxlib_mlir_capi_shims_hdrs", - hdrs = ["jaxlib_mlir_capi_shims.h"], - deps = [ - "@llvm-project//mlir:CAPIIRHeaders", - ], -) - # JAX-specific registrations. -py_extension( +nanobind_pywrap_extension( name = "register_jax_dialects", srcs = ["register_jax_dialects.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/gpu:mlir_capi_headers", - "@llvm-project//mlir:CAPIArithHeaders", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", - "@llvm-project//mlir:CAPIMathHeaders", - "@llvm-project//mlir:CAPIMemRefHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", - "@llvm-project//mlir:CAPINVVMHeaders", - "@llvm-project//mlir:CAPISCFHeaders", - "@llvm-project//mlir:CAPITransformsHeaders", - "@llvm-project//mlir:CAPIVectorHeaders", + "//jaxlib/mosaic/gpu:mlir_capi", + "@llvm-project//mlir:CAPIArith", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", + "@llvm-project//mlir:CAPIMath", + "@llvm-project//mlir:CAPIMemRef", + "@llvm-project//mlir:CAPINVGPU", + "@llvm-project//mlir:CAPINVVM", + "@llvm-project//mlir:CAPISCF", + "@llvm-project//mlir:CAPITransforms", + "@llvm-project//mlir:CAPIVector", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@shardy//shardy/integrations/c:sdy_capi", ], ) @@ -285,20 +229,18 @@ py_extension( # MHLO Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_mlirHlo", srcs = [ "@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@xla//xla/mlir_hlo:CAPIHeaders", + "@xla//xla/mlir_hlo:CAPI", ], ) @@ -306,21 +248,19 @@ py_extension( # Shardy Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_sdy", srcs = [ 
"@shardy//shardy/integrations/python/ir:sdy_module.cc", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@shardy//shardy/integrations/c:sdy_capi", ], ) @@ -328,115 +268,33 @@ py_extension( # Stablehlo Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_chlo", srcs = [ "@stablehlo//:chlo_py_api_files", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@stablehlo//:chlo_capi_headers", + "@stablehlo//:chlo_capi", ], ) -py_extension( +nanobind_pywrap_extension( name = "_stablehlo", srcs = [ "@stablehlo//:stablehlo_py_api_files", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@stablehlo//:stablehlo_capi_headers", + "@stablehlo//:stablehlo_capi", ], -) - -# Shared C++ extension library - -cc_library( - name = "jaxlib_mlir_capi_shared_library", - srcs = select({ - "@xla//xla/tsl:windows": [":jaxlib_mlir_capi.dll"], - "@xla//xla/tsl:macos": [":libjaxlib_mlir_capi.dylib"], - "//conditions:default": [":libjaxlib_mlir_capi.so"], - }), - deps = select({ - "@xla//xla/tsl:windows": [":jaxlib_mlir_capi_dll"], - "//conditions:default": [], - }), -) - -cc_library( - name = "jaxlib_mlir_capi_objects", - deps = [ - "//jaxlib/mosaic:tpu_dialect_capi_objects", - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects", - "//jaxlib/mosaic/gpu:mlir_capi_objects", - "@llvm-project//mlir:CAPIArithObjects", - "@llvm-project//mlir:CAPIGPUObjects", - "@llvm-project//mlir:CAPIIRObjects", - "@llvm-project//mlir:CAPILLVMObjects", - "@llvm-project//mlir:CAPIMathObjects", - "@llvm-project//mlir:CAPIMemRefObjects", - "@llvm-project//mlir:CAPINVGPUObjects", - "@llvm-project//mlir:CAPINVVMObjects", - "@llvm-project//mlir:CAPISCFObjects", - "@llvm-project//mlir:CAPISparseTensorObjects", - "@llvm-project//mlir:CAPITransformsObjects", - "@llvm-project//mlir:CAPIVectorObjects", - "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", - "@shardy//shardy/integrations/c:sdy_capi_objects", - "@stablehlo//:chlo_capi_objects", - "@stablehlo//:stablehlo_capi_objects", - "@xla//xla/mlir_hlo:CAPIObjects", - ] + if_windows( - [], - [ - "//jaxlib/triton:triton_dialect_capi_objects", - ], - ), -) - -cc_binary( - name = "libjaxlib_mlir_capi.so", - linkopts = [ - "-Wl,-soname=libjaxlib_mlir_capi.so", - "-Wl,-rpath='$$ORIGIN'", - ], - linkshared = 1, - deps = [":jaxlib_mlir_capi_objects"], -) - -cc_binary( - name = "libjaxlib_mlir_capi.dylib", - linkopts = [ - "-Wl,-rpath,@loader_path/", - "-Wl,-install_name,@loader_path/libjaxlib_mlir_capi.dylib", - ], - linkshared = 1, - deps = [":jaxlib_mlir_capi_objects"], -) - -windows_cc_shared_mlir_library( - name = "jaxlib_mlir_capi_dll", - out = "jaxlib_mlir_capi.dll", - exported_symbol_prefixes = [ - "mlir", - "chlo", - "sdy", - "stablehlo", - ], - deps = 
[":jaxlib_mlir_capi_objects"], -) +) \ No newline at end of file diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index c73084abc99d..2751719fc61d 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -138,25 +138,4 @@ NB_MODULE(_mosaic_gpu_ext, m) { .def_property_readonly("swizzle", [](MlirAttribute self) { return mlirMosaicGpuSwizzleTransformAttrGetSwizzle(self); }); - - mlir::python::nanobind_adaptors::mlir_attribute_subclass( - m, "LayoutAttr", mlirMosaicGpuIsALayoutAttr) - .def_classmethod( - "get", - [](nb::object cls, int32_t num_dimensions, - std::vector& transforms, MlirContext ctx) { - return cls(mlirMosaicGpuLayoutAttrGet( - ctx, num_dimensions, transforms.data(), transforms.size())); - }, - nb::arg("cls"), nb::arg("num_dimensions"), nb::arg("transforms"), - nb::arg("context").none() = nb::none(), - "Creates a LayoutAttr with the given transforms.") - .def_property_readonly("transforms", [](MlirAttribute self) { - std::vector result; - for (int i = 0; i < mlirMosaicGpuLayoutAttrGetTransformsSize(self); - ++i) { - result.push_back(mlirMosaicGpuLayoutAttrGetTransform(self, i)); - } - return result; - }); } diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 9da841acc7de..b8432bf615c9 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -1,28 +1,43 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Registers MLIR dialects used by JAX. // This module is called by mlir/__init__.py during initialization. 
#include -#include "mlir-c/Dialect/Arith.h" -#include "mlir-c/Dialect/Func.h" -#include "mlir-c/Dialect/GPU.h" -#include "mlir-c/Dialect/LLVM.h" -#include "mlir-c/Dialect/Math.h" -#include "mlir-c/Dialect/MemRef.h" -#include "mlir-c/Dialect/NVGPU.h" -#include "mlir-c/Dialect/NVVM.h" -#include "mlir-c/Dialect/SCF.h" -#include "mlir-c/Dialect/Vector.h" +#include "mlir-c/Dialect/Arith.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Func.h" // IWYU pragma: keep +#include "mlir-c/Dialect/GPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/LLVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Math.h" // IWYU pragma: keep +#include "mlir-c/Dialect/MemRef.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVGPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/SCF.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Vector.h" // IWYU pragma: keep +#include "mlir-c/IR.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep #include "shardy/integrations/c/passes.h" #include "jaxlib/mosaic/gpu/integrations/c/passes.h" - namespace nb = nanobind; -#define REGISTER_DIALECT(name) \ - MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ - mlirDialectHandleInsertDialect(name##_dialect, registry) +#define REGISTER_DIALECT(name) \ + MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ + mlirDialectHandleInsertDialect(name##_dialect, registry) NB_MODULE(register_jax_dialects, m) { m.doc() = "Registers upstream MLIR dialects used by JAX."; diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 2b5ec898ad3e..8f751693e451 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,9 +26,8 @@ limitations under the License. #include #include -#include "llvm/ADT/ArrayRef.h" +#include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir-c/AffineMap.h" @@ -41,15 +39,14 @@ limitations under the License. #include "mlir-c/Support.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep // clang-format off -#include "mlir-c/Bindings/Python/Interop.h" // clang-format on +#include "absl/log/check.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "nanobind/stl/pair.h" // IWYU pragma: keep -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/variant.h" // IWYU pragma: keep -#include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "absl/log/check.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h" #include "xla/python/nb_numpy.h" #include "xla/tsl/python/lib/core/numpy.h" diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 2a13c40d963f..687ceec4cd33 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#ifndef _WIN32 + +#include #include #include "mlir-c/IR.h" @@ -73,3 +76,11 @@ NB_MODULE(_triton_ext, m) { return encoding; }); } + +#else // _WIN32 + +#include "nanobind/nanobind.h" + +NB_MODULE(_triton_ext, m) {} + +#endif // _WIN32 diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 4cc2530dd7ca..a4123d0654bf 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -60,9 +60,9 @@ cc_library( ]), # compatible with libtpu deps = [ + ":pass_boilerplate", + ":serde", ":tpu_inc_gen", - "//jaxlib:pass_boilerplate", - "//jaxlib/mosaic:serde", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", @@ -95,75 +95,37 @@ cc_library( "@xla//xla:shape_util", "@xla//xla:util", "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:statusor", ] + mosaic_extension_deps, ) gentbl_cc_library( name = "tpu_inc_gen", # compatible with libtpu - tbl_outs = [ - ( - ["-gen-op-decls"], - "dialect/tpu/tpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "dialect/tpu/tpu_ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "dialect/tpu/tpu_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "dialect/tpu/tpu_dialect.cc.inc", - ), - ( - ["-gen-enum-decls"], - "dialect/tpu/tpu_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "dialect/tpu/tpu_enums.cc.inc", - ), - ( - ["-gen-attrdef-decls"], - "dialect/tpu/tpu_attr_defs.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "dialect/tpu/tpu_attr_defs.cc.inc", - ), - ( - ["-gen-typedef-decls"], - "dialect/tpu/tpu_type_defs.h.inc", - ), - ( - ["-gen-typedef-defs"], - "dialect/tpu/tpu_type_defs.cc.inc", - ), - ( - [ - "-gen-pass-decls", - "-name=TPU", - ], - "dialect/tpu/tpu_passes.h.inc", - ), - ( - [ - "-gen-pass-capi-header", - "--prefix=TPU", - ], - "dialect/tpu/integrations/c/tpu_passes.capi.h.inc", - ), - ( - [ - "-gen-pass-capi-impl", - "--prefix=TPU", - ], - "dialect/tpu/integrations/c/tpu_passes.capi.cc.inc", - ), - ], + tbl_outs = { + "dialect/tpu/tpu_ops.h.inc": ["-gen-op-decls"], + "dialect/tpu/tpu_ops.cc.inc": ["-gen-op-defs"], + "dialect/tpu/tpu_dialect.h.inc": ["-gen-dialect-decls"], + "dialect/tpu/tpu_dialect.cc.inc": ["-gen-dialect-defs"], + "dialect/tpu/tpu_enums.h.inc": ["-gen-enum-decls"], + "dialect/tpu/tpu_enums.cc.inc": ["-gen-enum-defs"], + "dialect/tpu/tpu_attr_defs.h.inc": ["-gen-attrdef-decls"], + "dialect/tpu/tpu_attr_defs.cc.inc": ["-gen-attrdef-defs"], + "dialect/tpu/tpu_type_defs.h.inc": ["-gen-typedef-decls"], + "dialect/tpu/tpu_type_defs.cc.inc": ["-gen-typedef-defs"], + "dialect/tpu/tpu_passes.h.inc": [ + "-gen-pass-decls", + "-name=TPU", + ], + "dialect/tpu/integrations/c/tpu_passes.capi.h.inc": [ + "-gen-pass-capi-header", + "--prefix=TPU", + ], + "dialect/tpu/integrations/c/tpu_passes.capi.cc.inc": [ + "-gen-pass-capi-impl", + "--prefix=TPU", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "dialect/tpu/tpu.td", deps = [":tpu_td_files"], @@ -279,6 +241,17 @@ filegroup( # compatible with libtpu ) +cc_library( + name = "pass_boilerplate", + hdrs = ["pass_boilerplate.h"], + # compatible with libtpu + deps = [ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "serde", srcs = ["serde.cc"], diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index e21c8756a4e2..854955a60493 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -39,66 +39,36 @@ 
td_library( gentbl_cc_library( name = "mosaic_gpu_inc_gen", - tbl_outs = [ - ( - [ - "-gen-dialect-decls", - "-dialect=mosaic_gpu", - ], - "mosaic_gpu_dialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=mosaic_gpu", - ], - "mosaic_gpu_dialect.cc.inc", - ), - ( - ["-gen-op-decls"], - "mosaic_gpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "mosaic_gpu_ops.cc.inc", - ), - ( - [ - "-gen-typedef-decls", - "--typedefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_types.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "--typedefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_types.cc.inc", - ), - ( - ["-gen-enum-decls"], - "mosaic_gpu_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "mosaic_gpu_enums.cc.inc", - ), - ( - [ - "-gen-attrdef-decls", - "--attrdefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_attrdefs.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "--attrdefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_attrdefs.cc.inc", - ), - ], + tbl_outs = { + "mosaic_gpu_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=mosaic_gpu", + ], + "mosaic_gpu_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=mosaic_gpu", + ], + "mosaic_gpu_ops.h.inc": ["-gen-op-decls"], + "mosaic_gpu_ops.cc.inc": ["-gen-op-defs"], + "mosaic_gpu_types.h.inc": [ + "-gen-typedef-decls", + "--typedefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_types.cc.inc": [ + "-gen-typedef-defs", + "--typedefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_enums.h.inc": ["-gen-enum-decls"], + "mosaic_gpu_enums.cc.inc": ["-gen-enum-defs"], + "mosaic_gpu_attrdefs.h.inc": [ + "-gen-attrdef-decls", + "--attrdefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_attrdefs.cc.inc": [ + "-gen-attrdef-defs", + "--attrdefs-dialect=mosaic_gpu", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mosaic_gpu.td", deps = [ @@ -119,6 +89,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMCommonConversion", @@ -127,7 +98,7 @@ cc_library( "@llvm-project//mlir:MemRefUtils", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:statusor", + "@xla//xla/tsl/platform:statusor", ], ) @@ -151,7 +122,7 @@ cc_test( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:errors", + "@xla//xla/tsl/platform:errors", ], ) diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc index eac1d104f07f..523b14e425c9 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc @@ -1,7 +1,21 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include -#include #include "mlir-c/IR.h" #include "mlir/CAPI/IR.h" @@ -82,36 +96,3 @@ int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr) { .getSwizzle() .getValue()); } - -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -bool mlirMosaicGpuIsALayoutAttr(MlirAttribute attr) { - return mlir::isa(unwrap(attr)); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGet(MlirContext ctx, - int32_t num_dimensions, - MlirAttribute* transforms, - int32_t transforms_size) { - std::vector unwrapped_transforms; - unwrapped_transforms.reserve(transforms_size); - for (int i = 0; i < transforms_size; ++i) { - unwrapped_transforms.push_back(unwrap(transforms[i])); - } - return wrap(mosaic_gpu::LayoutAttr::get(unwrap(ctx), num_dimensions, - unwrapped_transforms)); -} - -int32_t mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr) { - return mlir::cast(unwrap(attr)) - .getTransforms() - .size(); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, - int32_t index) { - return wrap( - mlir::cast(unwrap(attr)).getTransforms()[index]); -} \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h index 3b8425b6b142..3221b9220e5d 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h @@ -69,22 +69,6 @@ mlirMosaicGpuSwizzleTransformAttrGet(MlirContext ctx, int32_t swizzle); MLIR_CAPI_EXPORTED int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr); -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -MLIR_CAPI_EXPORTED bool mlirMosaicGpuIsALayoutAttr(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions, - MlirAttribute* transforms, int32_t transforms_size); - -MLIR_CAPI_EXPORTED int32_t -mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, int32_t index); - #ifdef __cplusplus } #endif diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h index bb6cf6e3af4a..5fd0ce7a4f7a 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/CAPI/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index a1e7b571d20e..073697df58ef 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -18,6 +18,11 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" @@ -26,13 +31,17 @@ limitations under the License. #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -43,15 +52,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" -#include "absl/algorithm/container.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.cc.inc" @@ -371,12 +372,39 @@ llvm::LogicalResult WGMMAOp::verify() { return llvm::success(); } -mlir::AffineMap LayoutAttr::getAffineMap() const { - // This always returns an identity map. It's technically not correct, but we - // don't actually use it anywhere. It's only called during verification of the - // layout attribute and needs to be semi-valid. - return mlir::AffineMap::getMultiDimIdentityMap(getNumDimensions(), - getContext()); +llvm::LogicalResult CustomPrimitiveOp::verify() { + int num_vector_operands = 0; + int num_smem_ref_operands = 0; + mlir::Attribute smem = mlir::gpu::AddressSpaceAttr::get( + getContext(), mlir::gpu::AddressSpace::Workgroup); + for (auto operand : getOperands()) { + if (mlir::isa(operand.getType())) { + ++num_vector_operands; + } + + if (auto ref_ty = mlir::dyn_cast(operand.getType())) { + if (ref_ty.getMemorySpace() == smem) { + ++num_smem_ref_operands; + } + } + } + + if (num_vector_operands != getInLayouts().size()) { + return emitOpError( + "Custom primitive must have a layout for each vector operand."); + } + + if (num_smem_ref_operands != getInTransforms().size()) { + return emitOpError( + "Custom primitive must have transforms for each memref operand in " + "smem."); + } + + if (getResults().size() != getOutLayouts().size()) { + return emitOpError("Custom primitive must have a layout for each result."); + } + + return llvm::success(); } void MosaicGPUDialect::initialize() { diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index b4f13c50bd8c..474ed93806a1 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -19,17 +19,15 @@ limitations under the License. 
#include #include -#include "llvm/ADT/StringRef.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 0882986fcf5e..cbc0ef9703aa 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -142,16 +142,15 @@ def MosaicGPU_WGSplatFragLayout : AttrDef { - let summary = "1D array that is a row that can be tiled by supported WGMMA shapes."; +def MosaicGPU_Replicated : AttrDef { + let summary = "Indicates a replicated dimension in a tiled layout."; let description = [{ - This layout is used to handle rows that are fragmented across all threads - in a warpgroup that is executing a WGMMA operation. The length of the array - must be divisible by 64. + See mosaic/gpu/fragmented_array.py -> Replicated for more details. }]; - let mnemonic = "WGMMARowFragLayout"; - let assemblyFormat = ""; + let parameters = (ins "int":$times); + let mnemonic = "Replicated"; + let assemblyFormat = "`<` `times` `=` $times `>`"; } def MosaicGPU_TiledLayout : AttrDef { @@ -162,7 +161,7 @@ def MosaicGPU_TiledLayout : AttrDef { let parameters = (ins "::mlir::ArrayAttr":$tiling, - "int":$warp_dim, + "::mlir::Attribute":$warp_dim, "::mlir::ArrayAttr":$lane_dims, "int":$vector_dim ); @@ -225,27 +224,6 @@ def SwizzleTransformAttr : MosaicGPU_Attr<"SwizzleTransform", "swizzle"> { let assemblyFormat = "`<` $swizzle `>`"; } -def LayoutAttr : MosaicGPU_Attr<"Layout", "layout", - [DeclareAttrInterfaceMethods]> { - let parameters = (ins - TypeParameter<"int32_t", "number of dimensions">:$num_dimensions, - ArrayRefParameter<"mlir::Attribute", "transforms">:$transforms - ); - - let summary = "Specifies a layout of a memref in SMEM."; - let description = [{ - This layout attribute is used to specify the layout of a memref in SMEM. - It is composed of a number of transforms, which are applied in the order - they are provided. The transforms can be any combination of: - - TileTransformAttr - - TransposeTransformAttr - - SwizzleTransformAttr - - The num_dimensions parameter must match the rank of the memref shape. 
- }]; - let assemblyFormat = "`<` $num_dimensions `,` $transforms `>`"; -} - def MosaicGPU_AsyncLoadOp : Op { let summary = "Schedules an async load of a MemRef from GMEM to SMEM"; @@ -265,28 +243,20 @@ def MosaicGPU_AsyncLoadOp : Op { + let summary = "Casts a vector to a new layout."; + let description = [{Casts a vector value to a new strided or tiled layout.}]; + let arguments = (ins + AnyVectorOfAnyRank:$x, + + // Attributes + AnyAttrOf<[ + MosaicGPU_WGStridedFragLayout, + MosaicGPU_TiledLayout + ]>:$new_layout + ); + + let results = (outs AnyVectorOfAnyRank); + + let assemblyFormat = "`x` `(` $x `:` type($x) `)` attr-dict"; + + let extraClassDeclaration = [{ + static llvm::LogicalResult inferReturnTypes( + mlir::MLIRContext *, + std::optional location, + mlir::ValueRange operands, + mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, + mlir::RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + if (operands.empty()) { + return ::mlir::emitOptionalError( + location, "expected non-empty operands"); + } + inferredReturnTypes.assign({operands[0].getType()}); + return ::mlir::success(); + } + }]; +} + def MosaicGPU_SliceSMEMOp : Op { let summary = "Constructs an SMEM MemRef with the requested type that begins at the specified SMEM offset address."; @@ -394,19 +395,14 @@ def MosaicGPU_WGMMAOp : Op { This operation supports larger inputs than the PTX-level WGMMA operation and will schedule as many PTX-level WGMMA operations as needed to - accomplish the calculation. The `b` matrix, and optionally `a`, needs to be - provided as a 2-dimensional memref. All memrefs may have transforms that - define swizzling, tiling, and transposition. + accomplish the calculation. The `b` matrix, and optionally `a`, need to be + provided as a 2-dimensional memref. The inputs should have the following shapes: - a: [groups_m * 64, groups_k * s] - b: [groups_k * s, groups_n * s] - accumulator: [groups_m * 64, groups_n * s] - Where: - - `s == swizzle/element_bytediwth` (for `kNoSwizzle`, `swizzle` is 16.) - and the tilings are [64, s] for `a` and [s, s] for `b`. - - `a` and/or `b` may be transposed if the corresponding attribute is set - to `true`. + where `s == swizzle / element_bytewidth`. The output has an identical shape and type as the input accumulator. 
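For concreteness, a minimal worked example of the shape rule described above for the WGMMA op, assuming a 128-byte swizzle, bf16 (2-byte) elements, and groups_m=2, groups_k=3, groups_n=4; these concrete values are illustrative assumptions and are not taken from this change:

// Sketch only: checks the a/b/accumulator shape rule with assumed values.
#include <cstdint>

int main() {
  constexpr int64_t swizzle = 128;          // bytes (assumed)
  constexpr int64_t element_bytewidth = 2;  // bf16 (assumed)
  constexpr int64_t s = swizzle / element_bytewidth;
  static_assert(s == 64);
  constexpr int64_t groups_m = 2, groups_k = 3, groups_n = 4;
  // a:           [groups_m * 64, groups_k * s] -> [128, 192]
  // b:           [groups_k * s,  groups_n * s] -> [192, 256]
  // accumulator: [groups_m * 64, groups_n * s] -> [128, 256]
  static_assert(groups_m * 64 == 128 && groups_k * s == 192 &&
                groups_n * s == 256);
  return 0;
}

Under these assumed values the accumulator (and therefore the result) keeps the [groups_m * 64, groups_n * s] shape regardless of how many PTX-level WGMMA instructions the lowering emits.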
@@ -429,10 +425,7 @@ def MosaicGPU_WGMMAOp : Op { AnyTypeOf<[ MemRefOf<[MosaicGPU_WGMMASupportedType]>, VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a, - MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b, - - DefaultValuedOptionalAttr:$transpose_a, - DefaultValuedOptionalAttr:$transpose_b + MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b ); let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>); @@ -465,4 +458,59 @@ def MosaicGPU_WGMMAOp : Op { let hasVerifier = 1; } +def MosaicGPU_OptimizationBarrierOp : Op { + let summary = "Prevents MLIR from moving operations across the barrier."; + + let arguments = (ins + Variadic:$operands + ); + let results = (outs Variadic); + + let extraClassDeclaration = [{ + static llvm::LogicalResult inferReturnTypes( + mlir::MLIRContext *, + std::optional location, + mlir::ValueRange operands, + mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, + mlir::RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + if (operands.empty()) { + return ::mlir::emitOptionalError( + location, "expected non-empty operands"); + } + ::mlir::TypeRange operand_types = operands.getTypes(); + inferredReturnTypes.assign(operand_types.begin(), operand_types.end()); + return ::mlir::success(); + } + }]; +} + +def MosaicGPU_CustomPrimitiveOp : Op { + let summary = "Allows defining a custom Mosaic GPU primitive."; + let description = [{ + Allows defining a custom Mosaic GPU primitive. + + Custom primitives should carry input and output layouts for each of their + vector operands and outputs, and input transforms for each of their memref + operands that live in SMEM. + + Custom primitives can only return vectors. + }]; + + let arguments = ( + ins Variadic:$operands, + // Attributes + ArrayAttr:$in_layouts, + ArrayAttr:$in_transforms, + ArrayAttr:$out_layouts + ); + + let results = (outs Variadic>); + let regions = (region AnyRegion:$body); + + let hasVerifier = 1; +} + #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index 527aa7c7ce25..5458ba7fac88 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -25,26 +26,25 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "llvm/include/llvm/ADT/ArrayRef.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h" -#include "mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/Verifier.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "tsl/platform/errors.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "xla/tsl/platform/errors.h" namespace mosaic_gpu { namespace { diff --git a/jaxlib/mosaic/dialect/tpu/array_util.cc b/jaxlib/mosaic/dialect/tpu/array_util.cc index 4c1e79667c0f..f7d559fb08bc 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Support/LLVM.h" namespace mlir::tpu::internal { diff --git a/jaxlib/mosaic/dialect/tpu/array_util.h b/jaxlib/mosaic/dialect/tpu/array_util.h index 1b755dbf8495..ab8e98d17836 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.h +++ b/jaxlib/mosaic/dialect/tpu/array_util.h @@ -20,7 +20,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" diff --git a/jaxlib/mosaic/dialect/tpu/array_util_test.cc b/jaxlib/mosaic/dialect/tpu/array_util_test.cc index 18c2f94fa8b6..bcbf417a967b 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util_test.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include #include -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index 772e87beff71..dee4f5de43d8 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -21,8 +21,13 @@ limitations under the License. #include #include #include +#include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemAlloc.h" #include "llvm/Support/raw_ostream.h" #include "mlir-c/IR.h" @@ -31,16 +36,14 @@ limitations under the License. #include "mlir/CAPI/Registration.h" #include "mlir/CAPI/Utils.h" #include "mlir/CAPI/Wrap.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" @@ -410,7 +413,7 @@ MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass() { mlir::tpu::registerMosaicSerdePass(); } -#include "mlir/CAPI/Pass.h" // IWYU pragma: keep +#include "mlir/CAPI/Pass.h" // IWYU pragma: keep #include "mlir/CAPI/Support.h" // IWYU pragma: keep extern "C" { diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 172f2e91b41f..7ae8681e6980 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,6 +26,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" @@ -41,7 +41,6 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 2c45be62fa7d..8261d09697e3 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -18,15 +18,16 @@ limitations under the License. #include #include +#include #include -#include #include #include #include +#include "absl/log/check.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" @@ -38,7 +39,6 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" namespace mlir::tpu { @@ -233,6 +233,11 @@ class VectorLayout { implicit_dim_(implicit_dim) { // TODO(b/275751535): Allow more bitwidths. 
CHECK(llvm::has_single_bit(bitwidth_) && bitwidth_ <= 32); + CHECK_GT(tiling_[0], 0); + CHECK_GT(tiling_[1], 0); + CHECK_GE(offsets_[0].value_or(0), 0); + CHECK_GE(offsets_[1].value_or(0), 0); + CHECK_LT(offsets_[0].value_or(0), tiling_[0]); } static int num_implicit_dims(const ImplicitDim implicit_dim) { diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 4b5ed34934d7..e574889626aa 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -176,6 +176,9 @@ class TPU_Op traits = []> : Op { } +def DefaultMemWrite : MemoryEffects<[MemWrite]>; +def DefaultMemRead : MemoryEffects<[MemRead]>; + def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [ I32EnumAttrCase<"SUM", 0, "sum">, I32EnumAttrCase<"MAX", 1, "max">, @@ -198,7 +201,7 @@ def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure, SameOperandsAndResultType]> { }]; } -def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> { +def TPU_StoreOp : TPU_Op<"store", [DefaultMemWrite, AttrSizedOperandSegments]> { let arguments = (ins TPU_Vreg:$valueToStore, AnyType:$base, @@ -213,7 +216,7 @@ def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> { }]; } -def TPU_LoadOp : TPU_Op<"load"> { +def TPU_LoadOp : TPU_Op<"load", [DefaultMemRead]> { let arguments = (ins AnyType:$base, Variadic:$indices, @@ -227,7 +230,7 @@ def TPU_LoadOp : TPU_Op<"load"> { } // TODO(jevinjiang): migrate tpu.strided_store to general vector store op. -def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> { +def TPU_VectorStoreOp :TPU_Op<"vector_store", [DefaultMemWrite, AttrSizedOperandSegments]> { let arguments = (ins AnyVectorOfNonZeroRank:$valueToStore, AnyMemRef:$base, @@ -242,7 +245,7 @@ def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> { let hasVerifier = 1; } -def TPU_StridedLoadOp : TPU_Op<"strided_load"> { +def TPU_StridedLoadOp : TPU_Op<"strided_load", [DefaultMemRead]> { let arguments = (ins AnyMemRef:$base, Variadic:$indices, @@ -255,7 +258,7 @@ def TPU_StridedLoadOp : TPU_Op<"strided_load"> { let hasVerifier = 1; } -def TPU_StridedStoreOp : TPU_Op<"strided_store"> { +def TPU_StridedStoreOp : TPU_Op<"strided_store", [DefaultMemWrite]> { let arguments = (ins AnyVectorOfNonZeroRank:$valueToStore, AnyMemRef:$base, @@ -269,7 +272,7 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> { let hasVerifier = 1; } -def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load"> { +def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load", [DefaultMemRead]> { let arguments = (ins AnyMemRef:$base, Variadic:$indices, @@ -284,7 +287,7 @@ def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load"> { let hasCanonicalizeMethod = 1; } -def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> { +def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store", [DefaultMemWrite]> { let arguments = (ins TPU_Vreg:$valueToStore, AnyMemRef:$base, @@ -302,6 +305,11 @@ def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> { // TODO(jevinjiang): deprecate to use dynamic_rotate. 
def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { + let description = [{ + Rotates the given vector by the given amount in the given dimension, i.e., + for a 2D vector of shape (m, n), rotating dim 0 by `amount` will shift a row + at index `i` to index `(i + amount) % m` + }]; let arguments = (ins AnyVectorOfNonZeroRank:$value, SI32Attr:$amount, @@ -752,7 +760,9 @@ def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { AnyMemRef:$target, MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore, Optional:$device_id, // For remote DMAs - Optional:$core_id // For megacore + Optional:$core_id, // For megacore + // Smaller number means higher priority. 0 is the highest and the default. + DefaultValuedAttr:$priority ); let hasVerifier = 1; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 59ca5d7a3437..e0e061fbd6dd 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -15,27 +15,23 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include -#include #include -#include #include -#include -#include +#include "absl/hash/hash.h" +#include "absl/log/log.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep. +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/hash/hash.h" -#include "absl/log/log.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 0800a9e75087..2afaf08f29ed 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -23,16 +23,14 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" +#include "mlir/Support/LogicalResult.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc" -#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" -#include "xla/layout.h" +#include "xla/layout.h" // IWYU pragma: keep namespace mlir::tpu { class TPUDialect; @@ -64,11 +62,11 @@ struct ApplyVectorLayoutContext { // mxu_shape = {contracting_size, non_contracting_size} std::array mxu_shape = {128, 128}; int64_t max_sublanes_in_scratch = 0; - int64_t vmem_banks = -1; // -1 means "unspecified". + int64_t vmem_banks = -1; // -1 means "unspecified". int32_t max_shuffle_sublane_offset = -1; // -1 means "unspecified". 
}; -std::pair mightCommunicateBetweenChips(Operation* op); +std::pair mightCommunicateBetweenChips(Operation *op); std::unique_ptr> createInferMemRefLayoutPass( int hardware_generation = -1, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index c73accb09b26..341ead8431b4 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -19,26 +19,27 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/strings/str_format.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/IRMapping.h" -#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" +#include "xla/layout.h" namespace mlir { namespace tpu { @@ -954,13 +955,24 @@ LogicalResult EnqueueDMAOp::verify() { "device_id or core_id is specified"); } } + bool is_remote = getDeviceId() || getCoreId(); if (getSourceSemaphore()) { - if (!getDeviceId() && !getCoreId()) { + if (!is_remote) { return emitOpError( "DMA destination device_id or core_id must be specified when source " "semaphore is specified"); } } + int priority = getPriority(); + if (priority < 0 || priority > 1) { + return emitOpError( + "Not implemented: only support priority 0 or 1, but got ") + << priority; + } + if (priority != 0 && is_remote) { + return emitOpError( + "Not implemented: non-zero priority is not supported for remote DMA"); + } return success(); } @@ -1084,7 +1096,7 @@ LogicalResult ConcatenateOp::verify() { if (getOperands().size() < 2) { return emitOpError("Expected at least 2 operands for concatenate op."); } - auto first_type = getOperand(0).getType().cast(); + auto first_type = cast(getOperand(0).getType()); auto first_shape = first_type.getShape(); auto first_dtype = first_type.getElementType(); for (auto operand : getOperands()) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 1997ffe34535..d6452cba8b8d 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1,3 +1,18 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" #include @@ -13,15 +28,22 @@ #include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" -#include "llvm/Support/Compiler.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -33,9 +55,11 @@ #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" @@ -45,21 +69,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "llvm/include/llvm/ADT/APInt.h" -#include "llvm/include/llvm/Support/LogicalResult.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/array_util.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" @@ -3421,20 +3430,25 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, if (tiling[1] != ctx.target_shape[1]) { return op.emitOpError("Not implemented: unsupported tiling"); } - int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); + const int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); + const int64_t sublanes_per_tile = + layout_in.sublanesPerTile(ctx.target_shape); if (needs_physical_broadcast == std::array{true, false}) { // Sublane broadcast const int packing = layout_in.packing(); - if (num_tiles != 1) { - return op.emitOpError( - "Not implemented: Only native tiling supported"); - } TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 2), 1); TPU_ASSERT_OP(offsets_in[0].has_value()); const int64_t sublane_offset = *offsets_in[0] / packing; const int64_t subelement_offset = *offsets_in[0] % packing; - const DenseI32ArrayAttr indices = 
builder.getDenseI32ArrayAttr( - SmallVector(ctx.target_shape[0], sublane_offset)); + SmallVector pattern; + pattern.reserve(ctx.target_shape[0]); + for (int32_t t = 0; t < num_tiles; ++t) { + for (int32_t i = 0; i < sublanes_per_tile; ++i) { + pattern.push_back(sublanes_per_tile * t + sublane_offset); + } + } + const DenseI32ArrayAttr sublane_pattern = + builder.getDenseI32ArrayAttr(pattern); const absl::Status status = src_tiles.EachStatus([&](const absl::Span src_idx, Value *const src_vreg) { @@ -3451,8 +3465,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, return absl::InternalError(""); } } - dst_vreg = builder.create(dst_vreg.getType(), - dst_vreg, indices, 0); + dst_vreg = builder.create( + dst_vreg.getType(), dst_vreg, sublane_pattern, 0); SmallVector dst_starts(dst_tiles_implicit_shape.size()); SmallVector dst_limits(dst_tiles_implicit_shape.size()); for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { @@ -3474,8 +3488,6 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, std::array{false, true}) { // Lane broadcast TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 1), 1); TPU_ASSERT_OP(offsets_in[1].has_value()); - const int64_t sublanes_per_tile = - layout_in.sublanesPerTile(ctx.target_shape); const int64_t offset = *offsets_in[1]; const int64_t lane_offset = offset % ctx.target_shape[1]; const int64_t tile_offset = offset / ctx.target_shape[1]; @@ -3728,10 +3740,6 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(layouts_out.size(), 1); TPU_ASSERT_OP(layouts_in.front().has_value()); const VectorLayout &layout_in = *layouts_in.front(); - if (layout_in.bitwidth() != 32) { - return op.emitOpError( - "Not implemented: Only 32-bit vector.extract supported"); - } const VectorType res_vty = dyn_cast(extract_op.getResult().getType()); if (res_vty != nullptr) { @@ -3760,6 +3768,10 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, op.erase(); return success(); } else { + if (layout_in.bitwidth() != 32) { + return op.emitOpError( + "Not implemented: Only 32-bit vector.extract supported"); + } // TODO(b/367459476): Support non-zero offsets. if (layout_in.offsets() != LayoutOffsets{0, 0}) { return op.emitOpError("Not implemented: Unsupported layout"); @@ -5496,6 +5508,60 @@ void rotateLanes(OpBuilder &builder, xla::Array &vregs, rotateVregs(builder, vregs, amount, 1); } +// Rotate a vreg by a certain amount of rows, and get the low or high bits of +// each sublane after rotation. +// +// For these purposes, the vreg is considered to have shape (row_packing * +// target_shape[0], target_shape[1]) +// +// Args: +// vreg: The vreg to rotate +// rotate_amount: The amount to rotate the vreg by. +// rows_per_sublane: The number of rows in a sublane. +// is_high: If true, get the high bits of each sublane, otherwise get low bits. +// +// Returns: +// The rotated vreg. +Value rotateVregRows(OpBuilder &builder, Location loc, Value vreg, + const int64_t rotate_amount, + const int64_t rows_per_sublane, const bool is_high, + const std::array target_shape) { + CHECK_LE(0, rotate_amount); + CHECK_LT(0, rows_per_sublane); + const int64_t bits_per_row = 32 / rows_per_sublane; + const int64_t sublane_rotate_amount = + rotate_amount / rows_per_sublane + (is_high ? 
0 : 1); + const int64_t within_sublane_rotate_amount = rotate_amount % rows_per_sublane; + vreg = builder.create(vreg.getLoc(), vreg, + /*amount=*/sublane_rotate_amount, + /*dimension=*/0, /*stride=*/nullptr, + /*stride_dimension=*/nullptr); + if (within_sublane_rotate_amount != 0) { + const VectorType vreg_ty = cast(vreg.getType()); + const VectorType i32_vreg_ty = + getNativeVregType(builder.getI32Type(), target_shape); + vreg = builder.create(loc, i32_vreg_ty, vreg); + if (is_high) { + auto shift_amt = builder.create( + loc, + DenseElementsAttr::get( + i32_vreg_ty, static_cast(bits_per_row * + within_sublane_rotate_amount))); + vreg = builder.create(loc, vreg, shift_amt); + } else { + auto shift_amt = builder.create( + loc, + DenseElementsAttr::get( + i32_vreg_ty, static_cast( + bits_per_row * (rows_per_sublane - + within_sublane_rotate_amount)))); + vreg = builder.create(loc, vreg, shift_amt); + } + vreg = builder.create(loc, vreg_ty, vreg); + } + return vreg; +} + // Relayout src_vregs from layout src to layout dst, where dst is the same as // src except that the column offset is dst_col_offset. FailureOr> doColumnShiftRelayout( @@ -6637,6 +6703,59 @@ FailureOr>> changeImplicitDim( src_candidate.tileArrayImplicitShape(vty.getShape(), target_shape)); return std::make_pair(src_candidate, vregs); } + const int64_t sublanes_per_tile = src.sublanesPerTile(target_shape); + CHECK_GT(sublanes_per_tile, 0); + if (src.tiling()[0] % sublanes_per_tile != 0) { + // Tilings such as 32-bit (4, 256) are not used and not supported. + return emitError( + loc, "Not implemented: Rows within tile span multiple sublanes"); + } + const int64_t rows_per_sublane = src.tiling()[0] / sublanes_per_tile; + // Add second minor implicit dim + if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone && + dst_implicit_dim == VectorLayout::ImplicitDim::kSecondMinor) { + // TODO(tlongeri): Detect replicated source 2nd minor as a no-op above + const int64_t src_offset = src.offsets()[0].value_or(0); + // TODO(tlongeri): Do broadcast (different path) for replicated output + const int64_t dst_offset = dst_offset_hints[0].value_or(0); + VectorLayout dst(src.bitwidth(), {dst_offset, src.offsets()[1]}, + src.tiling(), dst_implicit_dim); + xla::Array new_vregs( + dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + DCHECK_EQ(*(new_vregs.dimensions().end() - 2), 1); + // Define src_idx outside loop to avoid reallocation + SmallVector src_idx; + new_vregs.Each([&](const absl::Span idx, Value *new_vreg) { + // Shift the desired row from the source vreg to the desired offset for + // the destination vreg. This is done with rotates and, for packed types + // with multiple rows per sublane, bitshifts. + // Note that the offset of the source row varies but the destination + // offset is always the same. + const int64_t dst_offset_in_sublane = dst_offset % rows_per_sublane; + // src_row_with_offset is the row of the padded implicit shape that we + // will place in the destination vreg. The first dst vreg along the + // non-implicit 2nd minor has the source row at offset src_offset, the + // second has the source row at offset src_offset+1, etc. 
+ const int64_t src_row_with_offset = *(idx.end() - 3) + src_offset; + src_idx.assign(idx.begin(), idx.end() - 3); + src_idx.push_back(src_row_with_offset / src.tiling()[0]); + src_idx.push_back(idx.back()); + Value vreg = vregs(src_idx); + const int64_t src_offset_in_vreg = src_row_with_offset % src.tiling()[0]; + const int64_t src_offset_in_sublane = + src_row_with_offset % rows_per_sublane; + int64_t row_rotate_amt = dst_offset - src_offset_in_vreg; + if (row_rotate_amt < 0) { + row_rotate_amt += rows_per_sublane * target_shape[0]; + } + *new_vreg = rotateVregRows( + builder, loc, vreg, row_rotate_amt, rows_per_sublane, + /*is_high=*/src_offset_in_sublane <= dst_offset_in_sublane, + ctx.target_shape); + }); + return std::make_pair(dst, new_vregs); + } + // Remove second minor implicit dim, for values that have (m, 128) tiling (for // m that is a power of 2). if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && @@ -6663,7 +6782,6 @@ FailureOr>> changeImplicitDim( // For example, extended offsets allow us to skip copies of low sublanes // in tiles with idx.back() == 0. const int tiles_per_vreg = src.tilesPerVreg(target_shape); - const int sublanes_per_tile = src.sublanesPerTile(target_shape); src_idx[dst_2nd_minor_idx] = src.tiling()[0] * idx[dst_2nd_minor_idx] + dst_sl_start - dst_sublane_offset; for (int dst_sl_idx = dst_sl_start; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h index ed72a21028eb..bbf23a9f3844 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h index 33c9e7421004..72bd8ca370c8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h @@ -1,11 +1,26 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 5efbdb9cb437..247f47431745 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include @@ -7,37 +22,33 @@ #include #include -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -// It requires these headers, but does not include them. 
-// NOLINTNEXTLINE(misc-include-cleaner) -#include "mlir/Dialect/MemRef/IR/MemRef.h" -// NOLINTNEXTLINE(misc-include-cleaner) +#include "absl/log/check.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" // IWYU pragma: keep +#include "mlir/Dialect/SCF/IR/SCF.h" // IWYU pragma: keep +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h" -#include "mlir/include/mlir/IR/AffineExpr.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Block.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Region.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/vreg_util.h" @@ -246,7 +257,7 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, auto matmul_res = dot_dim_matmul(sliced_lhs.getResult(), sliced_rhs.getResult(), sliced_acc.getResult()); - auto res_ty = matmul_res.getType().cast(); + auto res_ty = cast(matmul_res.getType()); auto res_shape = res_ty.getShape(); // reshape to 1x[prior_shape] auto reshape_shape = llvm::to_vector(res_shape); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc index 89e3a8bb9f70..7e99dd15611b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc @@ -17,13 +17,16 @@ limitations under the License. 
#include #include +#include "absl/log/check.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ValueRange.h" #include "mlir/IR/Visitors.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc index e7528533938f..d2c149a47150 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc @@ -1,7 +1,22 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Operation.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc index c9c4a97e6222..e34ef7fcb261 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc @@ -1,11 +1,26 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" #include #include -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 0926f8a3c7b5..bfb9be87dfd0 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" #include @@ -6,6 +21,7 @@ #include #include +#include "absl/log/check.h" #include "llvm/ADT/bit.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -14,7 +30,6 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" @@ -23,7 +38,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" @@ -62,18 +76,21 @@ int getTilingFactor(const int src_sublane, const int hardware_generation, const int max_normal_tiling = tiling_sublane; int large_tiling = [&] { + if (bitwidth == 2) { + return target_sublane_count * 16; + } if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) { - return tiling_sublane * 8; + return target_sublane_count * 8; } if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) { - return tiling_sublane * 4; + return target_sublane_count * 4; } // 16-bit values are generally always possible to relayout on the fly in v6, // so we allow large 2nd minor tiling whenever possible. We can't do this // for kernel arguments, because the layout of those is controlled by XLA. if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor || (!is_kernel_argument && hardware_generation >= 6))) { - return tiling_sublane * 2; + return target_sublane_count * 2; } return tiling_sublane; }(); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h index f2ab7c624eb1..a6dd8ad1dbd3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 0081feba985b..c81701d9a398 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -19,35 +19,29 @@ limitations under the License. #include #include #include -#include -#include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Pass/Pass.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" @@ -142,10 +136,10 @@ class VectorLayoutInferer { bool has_vector_io = false; for (auto op : any_op.getOperands()) { - has_vector_io |= op.getType().isa(); + has_vector_io |= isa(op.getType()); } for (auto r : any_op.getResults()) { - has_vector_io |= r.getType().isa(); + has_vector_io |= isa(r.getType()); } if (!has_vector_io && any_op.getRegions().empty()) { SmallVector in_layout(any_op.getNumOperands(), kNoLayout); @@ -1095,13 +1089,11 @@ class VectorLayoutInferer { } auto src_tiled_ishape = layout.getImplicitTiledDims(src_ty.getShape(), 1); auto dst_tiled_ishape = layout.getImplicitTiledDims(res_ty.getShape(), 1); - // Since we can only do sublane broadcasts in the (8, 128) tiling, we - // should always use that when sublane broadcasting is required. 
if (src_tiled_ishape[0] != dst_tiled_ishape[0] && layout.offsets()[0] != std::nullopt) { + // TODO(tlongeri): Remove this. We support non-native tiling now, but + // things may still break downstream due to missing relayouts. LayoutOffsets offsets = layout.offsets(); - // At the moment relayout can only produce replicated sublanes when - // converting to (8, 128) if the input was in (1, 128) tiling if (layout.tiling()[0] == 1 && layout.bitwidth() == kNativeBitwidth) { offsets[0] = std::nullopt; } @@ -1301,7 +1293,7 @@ class VectorLayoutInferer { (*(offsets.end() - 1) + *input_layout->offsets()[1]) % vreg_slice[1]; } for (auto stride : strides_attr) { - TPU_CHECK_OP(stride.cast().getInt() == 1, + TPU_CHECK_OP(cast(stride).getInt() == 1, "Only trivial strides supported."); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h index d240f27fd42d..a81e982f8e1a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h @@ -1,11 +1,26 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ #include #include -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index 949a26a4f593..0d310ff45b30 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -19,32 +19,32 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" -#include "mlir/include/mlir/IR/AffineMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Matchers.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index b73ea0f1250f..1cfb797c5478 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "absl/log/check.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc index b88504e35068..6ddf8bd5ce66 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc @@ -1,22 +1,36 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include #include +#include "absl/log/check.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" -#include "absl/log/check.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/Support/MathExtras.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 0981c263d252..e08149fe44fc 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -18,19 +18,17 @@ limitations under the License. 
#include #include +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/serde.h" @@ -42,7 +40,7 @@ constexpr StringRef kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; // When this is bumped, we should file a TODO to update the forward-compatible // version in tpu_custom_call.py in a month! -constexpr int kVersion = 3; +constexpr int kVersion = 4; using SerdeRuleType = jaxlib::mosaic::SerdeRuleType; @@ -64,6 +62,11 @@ LogicalResult enqueue_dma_upgrade(Operation* op, int version) { << op->getNumOperands(); } } + if (version < 4) { + op->setAttr("priority", + mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), 0)); + } return success(); } @@ -71,6 +74,9 @@ LogicalResult enqueue_dma_downgrade(Operation* op, int version) { if (version < 2) { return op->emitError("Downgrade to version ") << version << " unsupported"; } + if (version < 4) { + op->removeAttr("priority"); + } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index 8685918d3b39..e5617ef151f7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -1,15 +1,31 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 651cef85f740..141f52ec125b 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -22,16 +22,17 @@ limitations under the License. #include #include -#include "llvm/Support/MathExtras.h" #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/Support/raw_ostream.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" @@ -45,18 +46,18 @@ std::ostream &operator<<(std::ostream &os, Print p) { return os; } -SmallVector ComputeTileStrides(MemRefType memref_ty, +SmallVector ComputeTileStrides(absl::Span shape, absl::Span tiling) { - SmallVector tile_strides(memref_ty.getRank()); + SmallVector tile_strides(shape.size()); int64_t stride = 1; - for (int64_t i = 0; i < memref_ty.getRank(); ++i) { - int64_t idx = memref_ty.getRank() - 1 - i; + for (int64_t i = 0; i < shape.size(); ++i) { + int64_t idx = shape.size() - 1 - i; int64_t tiling_idx = tiling.size() - 1 - i; tile_strides[idx] = stride; if (tiling_idx >= 0) { - stride *= llvm::divideCeil(memref_ty.getShape()[idx], tiling[tiling_idx]); + stride *= llvm::divideCeil(shape[idx], tiling[tiling_idx]); } else { - stride *= memref_ty.getShape()[idx]; + stride *= shape[idx]; } } return tile_strides; diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 2e19cb820b5b..eed0df14f707 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
 #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_
 #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_
 
@@ -10,22 +25,21 @@
 #include 
 #include 
 
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "llvm/Support/Compiler.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
-#include "absl/status/status.h"
-#include "absl/types/span.h"
-#include "mlir/include/mlir/IR/Attributes.h"
-#include "mlir/include/mlir/IR/BuiltinTypes.h"
-#include "mlir/include/mlir/IR/Diagnostics.h"
-#include "mlir/include/mlir/IR/Value.h"
 #include "jaxlib/mosaic/dialect/tpu/layout.h"
 #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
-#include "tsl/platform/statusor.h"
+#include "xla/tsl/platform/statusor.h"
 
 // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with
 // MLIR diagnostics?
@@ -192,8 +206,15 @@ std::string shapeToString(const T &shape) {
   return os.str();
 }
-SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
+SmallVector<int64_t> ComputeTileStrides(absl::Span<const int64_t> shape,
                                         absl::Span<const int64_t> tiling);
+
+inline SmallVector<int64_t> ComputeTileStrides(
+    MemRefType memref_ty, absl::Span<const int64_t> tiling) {
+  absl::Span<const int64_t> shape(memref_ty.getShape().data(),
+                                  memref_ty.getShape().size());
+  return ComputeTileStrides(shape, tiling);
+}
 
 // Assuming MKN matmul - This function must only be called after
 // canonicalization passes.
 //
diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc
index 1f59ee13a311..72e0bf7f0caf 100644
--- a/jaxlib/mosaic/dialect/tpu/vreg_util.cc
+++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc
@@ -19,16 +19,16 @@ limitations under the License.
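The util.h/util.cc change above reworks ComputeTileStrides to take a plain shape span, with a thin MemRefType overload forwarding to it. The snippet below is a minimal standalone sketch of the same computation, not the jaxlib code itself: it uses std::vector in place of llvm::SmallVector and absl::Span so it compiles without LLVM or Abseil, and ComputeTileStridesSketch is an illustrative name.

#include <cstdint>
#include <vector>

// Row-major strides measured in whole tiles: dimensions covered by the tiling
// (the trailing ones) are first rounded up to a whole number of tiles.
std::vector<int64_t> ComputeTileStridesSketch(const std::vector<int64_t>& shape,
                                              const std::vector<int64_t>& tiling) {
  std::vector<int64_t> tile_strides(shape.size());
  int64_t stride = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    const size_t idx = shape.size() - 1 - i;
    const int64_t tiling_idx =
        static_cast<int64_t>(tiling.size()) - 1 - static_cast<int64_t>(i);
    tile_strides[idx] = stride;
    if (tiling_idx >= 0) {
      // divideCeil: round the dimension up to a whole number of tiles.
      stride *= (shape[idx] + tiling[tiling_idx] - 1) / tiling[tiling_idx];
    } else {
      stride *= shape[idx];
    }
  }
  return tile_strides;
}

// Example: shape {3, 17, 260} with tiling {8, 128} gives {9, 3, 1}, since 260
// rounds up to 3 tiles of 128 and 17 rounds up to 3 tiles of 8.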
#include #include -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc index ea3063361e1a..8a6d437ab73c 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc @@ -21,20 +21,20 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/TypeSwitch.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/DebugStringHelper.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 9249ae256901..0eb24781379e 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -52,7 +52,7 @@ cc_library( "serde.h", ], deps = [ - "//jaxlib:pass_boilerplate", + "//jaxlib/mosaic:pass_boilerplate", "//jaxlib/mosaic:serde", "@llvm-project//llvm:Support", "@llvm-project//mlir:DataLayoutInterfaces", @@ -111,22 +111,38 @@ cc_library( cc_library( name = "runtime", srcs = ["runtime.cc"], + # Linker may prune these symbols if they are not explicitly exported. 
+ linkopts = ["-Wl,--export-dynamic-symbol='mosaic_gpu_*'"], deps = [ + ":mosaic_gpu_comm", "@local_config_cuda//cuda:cuda_headers", ], + alwayslink = True, +) + +cc_library( + name = "mosaic_gpu_comm", + hdrs = ["mosaic_gpu_comm.h"], + deps = [ + "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", + ], ) cc_library( name = "custom_call", srcs = ["custom_call.cc"], deps = [ + ":mosaic_gpu_comm", ":passes", ":target", "//jaxlib/cuda:cuda_vendor", "//jaxlib/mosaic/dialect/gpu:mosaic_gpu", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -170,6 +186,8 @@ cc_library( "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", ], @@ -205,6 +223,7 @@ cc_binary( "notap", ], deps = [ + "//jaxlib/mosaic/gpu:mosaic_gpu_comm", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cudart", ], diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 402e099c8d6b..a933b72ad55a 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -31,72 +32,90 @@ limitations under the License. #include #include +#include "absl/base/call_once.h" #include "absl/base/optimization.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "llvm/include/llvm/Support/CodeGen.h" -#include "llvm/include/llvm/Support/TargetSelect.h" -#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" -#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/include/mlir/Conversion/Passes.h" -#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include 
"mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/include/mlir/ExecutionEngine/OptUtils.h" -#include "mlir/include/mlir/IR/AsmState.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/Parser/Parser.h" -#include "mlir/include/mlir/Pass/PassManager.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Transforms/Passes.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Transforms/Passes.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" +#include "jaxlib/mosaic/gpu/mosaic_gpu_comm.h" #include "jaxlib/mosaic/gpu/passes.h" #include "jaxlib/mosaic/gpu/serde.h" #include "jaxlib/mosaic/gpu/target.h" +#include "xla/ffi/ffi.h" +#include 
"xla/ffi/ffi_api.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" namespace { +namespace ffi = xla::ffi; + using MosaicInitFunc = void(void****); using MosaicHostFunc = void(void**); +void EnsureLLVMNVPTXTargetIsRegistered() { + static absl::once_flag register_nvptx_target_flag; + absl::call_once(register_nvptx_target_flag, []() { + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); + }); +} + absl::StatusOr> GetSmAndPtxIsaVersion() { // Assumes driver has been initialized and a context exists. XLA already has // some utilities to query this, but we try to stay runtime-agnostic, so we @@ -115,13 +134,18 @@ absl::StatusOr> GetSmAndPtxIsaVersion() { device) != CUDA_SUCCESS) { return absl::InternalError("Failed to get minor compute capability"); } + EnsureLLVMNVPTXTargetIsRegistered(); return mosaic::gpu::GetSmAndPtxIsaVersion(major, minor); } + mlir::FailureOr GetPassPipeline( mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target, - const std::string& sm, const std::string& ptx_isa) { - static bool register_once = []() { + const std::string& sm, const std::string& ptx_isa, const std::string& nvshmem_path) { + static absl::once_flag register_passes_flag; + absl::call_once(register_passes_flag, []() { + EnsureLLVMNVPTXTargetIsRegistered(); + llvm::InitializeNativeTarget(); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); @@ -149,8 +173,7 @@ mlir::FailureOr GetPassPipeline( mosaic::gpu::registerByvalInsertionPass(); mlir::arith::registerArithExpandOpsPass(); return true; - }(); - (void)register_once; + }); return mlir::parsePassPipeline(absl::StrCat( R"( builtin.module( @@ -178,8 +201,8 @@ mlir::FailureOr GetPassPipeline( gpu.module(mosaic-byval-insertion), gpu.module(reconcile-unrealized-casts), mosaic-convert-gpu-to-llvm, - gpu-module-to-binary{format=)", - mlir::gpu::stringifyCompilationTarget(target).str(), R"(}, + gpu-module-to-binary{format=)" + + mlir::gpu::stringifyCompilationTarget(target).str() + (!nvshmem_path.empty() ? R"( l=)" + nvshmem_path : "") + R"(}, convert-math-to-llvm{approximate-log1p=true}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, @@ -288,7 +311,7 @@ class TemporaryDirectory { }; void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, - const std::string& ptx_isa) { + const std::string& ptx_isa, const std::string& nvshmem_path) { bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; bool dump_ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr; bool dump_sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr; @@ -299,7 +322,8 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, module = module.clone(); // Prevent accidental modification. 
absl::Cleanup module_destroyer = [module] { module->erase(); }; auto passes = GetPassPipeline( - module.getContext(), mlir::gpu::CompilationTarget::Assembly, sm, ptx_isa); + module.getContext(), mlir::gpu::CompilationTarget::Assembly, + sm, ptx_isa, nvshmem_path); if (mlir::failed(passes) || mlir::failed(RunPasses(std::move(*passes), module))) { return; @@ -357,7 +381,29 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, } } -absl::StatusOr> Compile( +bool is_nvshmem_used(mlir::ModuleOp module) { + constexpr std::string_view prefix1 = "nvshmem_"; + constexpr std::string_view prefix2 = "nvshmemx_"; + for (mlir::LLVM::LLVMFuncOp llvm_func : module.getOps()) { + const auto& func_name = llvm_func.getName(); + if (!func_name.starts_with(prefix1) && !func_name.starts_with(prefix2)) { + continue; + } + auto uses = mlir::SymbolTable::getSymbolUses(llvm_func, module.getOperation()); + if (uses && !uses->empty()) { + return true; + } + } + return false; +} + +absl::StatusOr get_nvshmem_llvm_lib_path() { + const char * nvshmem_path_ptr = getenv("MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"); + if (!nvshmem_path_ptr) return absl::InternalError("Failed to get MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"); + return nvshmem_path_ptr; +} + +absl::StatusOr, bool>> Compile( mlir::ModuleOp module) { auto sm_and_ptx_isa = GetSmAndPtxIsaVersion(); if (!sm_and_ptx_isa.ok()) { @@ -365,9 +411,16 @@ absl::StatusOr> Compile( } const std::string sm = sm_and_ptx_isa.value().first; const std::string ptx_isa = sm_and_ptx_isa.value().second; - DumpCompilationOutput(module, sm, ptx_isa); + bool is_comm_used = is_nvshmem_used(module); + std::string nvshmem_path = ""; + if (is_comm_used) { + TF_ASSIGN_OR_RETURN(nvshmem_path, get_nvshmem_llvm_lib_path()); + } + DumpCompilationOutput(module, sm, ptx_isa, nvshmem_path); auto passes = GetPassPipeline( - module.getContext(), mlir::gpu::CompilationTarget::Binary, sm, ptx_isa); + module.getContext(), + mlir::gpu::CompilationTarget::Binary, + sm, ptx_isa, nvshmem_path); if (mlir::failed(passes)) { return absl::InternalError("Failed to construct pass pipeline"); } @@ -391,23 +444,25 @@ absl::StatusOr> Compile( if (!maybe_execution_engine) { return absl::InternalError("Failed to compile kernel"); } - return std::move(*maybe_execution_engine); + return std::make_pair(std::move(*maybe_execution_engine), is_comm_used); } class CompiledKernel { public: CompiledKernel(std::unique_ptr engine, void* ctx, - MosaicHostFunc* host_launch) - : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch) {} + MosaicHostFunc* host_launch, bool is_comm_used) + : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch), + is_comm_used_(is_comm_used) {} - std::tuple GetHostLaunch() { - return std::make_tuple(ctx_, host_launch_); + std::tuple GetHostLaunch() { + return std::make_tuple(ctx_, host_launch_, is_comm_used_); } private: std::unique_ptr engine_; void* ctx_; // TODO(apaszke): Destroy this properly MosaicHostFunc* host_launch_; + bool is_comm_used_; }; using KernelHash = std::array; @@ -476,7 +531,8 @@ absl::StatusOr CompileAndInit(const char* module) { if (!maybe_engine.ok()) { return maybe_engine.status(); } - mlir::ExecutionEngine* execution_engine = maybe_engine->get(); + mlir::ExecutionEngine* execution_engine = maybe_engine.value().first.get(); + bool is_comm_used = maybe_engine.value().second; auto host_and_init_func_names = GetHostAndInitFuncNames(*module_op); if (!host_and_init_func_names.ok()) { @@ -495,14 +551,15 @@ absl::StatusOr CompileAndInit(const char* module) { 
void** kernel_ptr_ptr = &kernel_ptr; void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); - return CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(*host)); + return CompiledKernel(std::move(maybe_engine.value().first), kernel_ptr, + reinterpret_cast(*host), + is_comm_used); } // Each compiled kernel has a unique init func, and each kernel is used from // a single HLO module. So it should be safe to not include the CUDA context // in the key. -absl::StatusOr> CachedCompileAndInit( +absl::StatusOr CachedCompileAndInit( CacheKey key, const char* module) { auto cache_and_mutex = GetKernelCache(); auto* cache = cache_and_mutex.first; @@ -513,7 +570,7 @@ absl::StatusOr> CachedCompileAndInit( absl::ReaderMutexLock lock(mutex); auto it = cache->find(key); if (ABSL_PREDICT_TRUE(it != cache->end())) - return it->second.GetHostLaunch(); + return &it->second; } absl::MutexLock lock(mutex); @@ -525,11 +582,12 @@ absl::StatusOr> CachedCompileAndInit( } cache->insert_or_assign(key, std::move(*compiled)); } - return cache->at(key).GetHostLaunch(); + return &cache->at(key); } void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { + // Forward-compatible version using the legacy FFI API if (reinterpret_cast(opaque) % alignof(KernelHash)) { fprintf(stderr, "Misaligned opaque pointer\n"); abort(); @@ -541,20 +599,92 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, abort(); } CacheKey key(hash, reinterpret_cast(ctx)); - auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); - if (!ctx_and_kernel.ok()) { + auto compiled_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); + if (!compiled_kernel.ok()) { XlaCustomCallStatusSetFailure(status, - ctx_and_kernel.status().message().data(), - ctx_and_kernel.status().message().size()); + compiled_kernel.status().message().data(), + compiled_kernel.status().message().size()); return; } - void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers}; - std::get<1>(*ctx_and_kernel)(args); + auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); } XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, "CUDA"); +absl::Status MosaicGpuExecute(gpuStream_t stream, ffi::RemainingArgs inputs, + ffi::RemainingRets results, + absl::string_view kernel_hash, + absl::string_view module, + bool use_custom_barrier, + xla::RunId run_id) { + // Updated version using the new FFI API supporting custom barrier + // for distributed kernels + if (use_custom_barrier) { + fprintf(stderr, "Custom barrier is not supported on GPUs.\n"); + abort(); + } + if (reinterpret_cast(kernel_hash.data()) % + alignof(KernelHash) || + kernel_hash.size() != sizeof(KernelHash)) { + fprintf(stderr, "Misaligned opaque pointer\n"); + abort(); + } + auto hash = *reinterpret_cast(kernel_hash.data()); + CUcontext ctx; + if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) { + fprintf(stderr, "Failed to get current CUDA context\n"); + abort(); + } + CacheKey key(hash, reinterpret_cast(ctx)); + TF_ASSIGN_OR_RETURN(auto compiled_kernel, CachedCompileAndInit(key, module.data())); + auto ctx_kernel_comm = compiled_kernel->GetHostLaunch(); + bool 
is_comm_used = std::get<2>(ctx_kernel_comm); + + std::vector buffers; + buffers.reserve(inputs.size() + results.size()); + for (int i = 0; i < inputs.size(); ++i) { + buffers.push_back(inputs.get(i)->untyped_data()); + } + for (int i = 0; i < results.size(); ++i) { + buffers.push_back((*results.get(i))->untyped_data()); + } + void **buffers_ptr = buffers.data(); + void *args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers_ptr}; + + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, + ffi::Ffi::Bind() + .Ctx>() + .RemainingArgs() + .RemainingRets() + .Attr("kernel_hash") + .Attr("module") + .Attr("use_custom_barrier") + .Ctx()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu_v2", "CUDA", + { + /*instantiate=*/nullptr, + /*prepare=*/nullptr, + /*initialize=*/nullptr, + /*execute=*/kMosaicGpuExecute, + }); + } // namespace extern "C" { @@ -565,7 +695,7 @@ void** MosaicGpuCompile(const char* module) { if (!compiled.ok()) { return nullptr; } - auto [ctx, launch] = compiled->GetHostLaunch(); + auto [ctx, launch, is_comm_used] = compiled->GetHostLaunch(); auto tuple_ptr = std::unique_ptr(new void*[3]); if (!tuple_ptr) { return nullptr; diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 0331d800ec50..f3f982f07481 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -31,29 +31,29 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Location.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/IR/TypeRange.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" 
+#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" namespace mosaic { namespace gpu { diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_comm.h b/jaxlib/mosaic/gpu/mosaic_gpu_comm.h new file mode 100644 index 000000000000..1aa15f9307c7 --- /dev/null +++ b/jaxlib/mosaic/gpu/mosaic_gpu_comm.h @@ -0,0 +1,86 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_COMM_H_ +#define JAXLIB_MOSAIC_GPU_COMM_H_ + +#include +#include +#include + +#include "third_party/gpus/cuda/include/cuda.h" +#include "cuda_runtime_api.h" + +#define NVSHMEM_SUCCESS 0 +#define NVSHMEM_LIB_SONAME "libnvshmem_host.so.3" + +namespace mosaic { +namespace gpu { + +#define NVSHMEM_SET_FN(FnName) \ + FnName = reinterpret_cast(dlsym(library, #FnName)); \ + if (!FnName) { \ + fprintf(stderr, #FnName " not available in this library."); \ + abort(); \ + } + +class NvshmemApi { + public: + // Returns a default NvshmemApi for a current process. + // NvshmemApi follows the Singleton design pattern + static NvshmemApi& Default() { + static NvshmemApi instance; + return instance; + } + + int cumodule_int(CUmodule module) { + std::lock_guard lock(mutex_); + return nvshmemx_cumodule_init(module); + } + + void barrier_all_on_stream(cudaStream_t stream) { + nvshmemx_barrier_all_on_stream(stream); + } + + NvshmemApi(NvshmemApi const&) = delete; + void operator=(NvshmemApi const&) = delete; + + private: + NvshmemApi() { + const char* env_value = getenv("NVSHMEM_LIBRARY_PATH"); + const char* libnvshmem_path = + env_value && *env_value != 0 ? env_value : NVSHMEM_LIB_SONAME; + void* library = dlopen(libnvshmem_path, RTLD_LAZY); + if (library == nullptr) { + fprintf(stderr, "Failed to open %s library: %s", libnvshmem_path, dlerror()); + abort(); + } + + // Initialize supported NVSHMEM host API + NVSHMEM_SET_FN(nvshmemx_cumodule_init) + NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) + } + + // Dlopened NVSHMEM API + int (*nvshmemx_cumodule_init)(CUmodule); + int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); + + std::mutex mutex_; +}; + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_COMM_H_ diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 4f804c9e2116..decdbaef28e1 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -22,11 +22,11 @@ limitations under the License. 
#include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" #include "absl/cleanup/cleanup.h" #include "absl/strings/str_cat.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" @@ -98,19 +98,21 @@ static const auto* kEventElapsed = .Ret>() // elapsed_ms .To([](gpuStream_t stream, auto start, auto end, auto out) { gpuStreamSynchronize(stream); - auto start_event = std::make_unique(); - auto end_event = std::make_unique(); - absl::MakeCleanup([&]() { - gpuEventDestroy(*start_event); - gpuEventDestroy(*end_event); - }); - gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), + gpuEvent_t start_event = nullptr; + gpuEvent_t end_event = nullptr; + + absl::Cleanup cleanup = [&]() { + gpuEventDestroy(start_event); + gpuEventDestroy(end_event); + }; + + gpuMemcpy(&start_event, start.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); - gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpy(&end_event, end.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); + float elapsed; - if (auto res = - gpuEventElapsedTime(&elapsed, *start_event, *end_event); + if (auto res = gpuEventElapsedTime(&elapsed, start_event, end_event); res) { return ffi::Error::Internal(absl::StrCat( "Failed to get elapsed time between events: ", ToString(res))); @@ -193,6 +195,12 @@ void callback_complete(CUcontext context, uint32_t streamId, THROW_IF_CUPTI_ERROR(status); } } + + size_t num_dropped; + THROW_IF_CUPTI_ERROR( + cuptiActivityGetNumDroppedRecords(context, streamId, &num_dropped), + "failed to get number of dropped activity records"); + THROW_IF(num_dropped > 0, "activity records were dropped"); } NB_MODULE(_mosaic_gpu_ext, m) { @@ -237,15 +245,23 @@ NB_MODULE(_mosaic_gpu_ext, m) { cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), "failed to enable tracking of kernel activity by CUPTI"); }); - m.def("_cupti_get_timings", []() { - THROW_IF_CUPTI_ERROR( - cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), - "failed to flush CUPTI activity buffers"); - THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); - THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), - "failed to unsubscribe from CUPTI"); - return profiler_state.timings; - }); + m.def( + "_cupti_get_timings", + [](bool finalize) { + THROW_IF_CUPTI_ERROR( + cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), + "failed to disable tracking of kernel activity by CUPTI"); + THROW_IF_CUPTI_ERROR( + cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), + "failed to flush CUPTI activity buffers"); + if (finalize) { + THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); + } + THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), + "failed to unsubscribe from CUPTI"); + return profiler_state.timings; + }, + nb::arg("finalize") = true); } } // namespace diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index b8c3fbb74c81..9fa6f8df78a8 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -14,24 +14,28 @@ limitations under the License. 
==============================================================================*/ #include "jaxlib/mosaic/gpu/passes.h" + #include #include #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Transforms/DialectConversion.h" -#include "jaxlib/pass_boilerplate.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic { namespace gpu { diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index ad3cd0e19644..d03aa51b3124 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -17,12 +17,14 @@ limitations under the License. #include #include +#include "jaxlib/mosaic/gpu/mosaic_gpu_comm.h" #include "third_party/gpus/cuda/include/cuda.h" + extern "C" { void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, - int64_t elem_bitwidth, int64_t rank, + int64_t elem_type, int64_t rank, int64_t *sizes, int64_t *strides, int64_t swizzle_bytes, int64_t *window_shape) { if (((uintptr_t)tma_desc) % 64 != 0) { @@ -32,6 +34,39 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, abort(); } + CUtensorMapDataType data_type; + int64_t elem_bitwidth; + // types are defined in: LaunchContext._get_tma_desc() + if (elem_type == 0){ + // this is for int4s + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bitwidth = 4; + } else if (elem_type == 1){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bitwidth = 8; + } else if (elem_type == 2){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + elem_bitwidth = 16; + } else if (elem_type == 3){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + elem_bitwidth = 32; + } else if (elem_type == 4){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; + elem_bitwidth = 64; + } else if (elem_type == 5){ + data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + elem_bitwidth = 16; + } else if (elem_type == 6){ + data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + elem_bitwidth = 32; + } else if (elem_type == 7){ + data_type = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + elem_bitwidth = 16; + } else{ + fprintf(stderr, "Unsupported element type: %ld \n", elem_type); + abort(); + } + // Pack 4 bit types in 8 bit pairs. 
int64_t elem_bytewidth; if (elem_bitwidth < 8) { @@ -54,19 +89,6 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, elem_bytewidth = elem_bitwidth / 8; } - CUtensorMapDataType data_type; - if (elem_bytewidth == 1) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if (elem_bytewidth == 2) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - } else if (elem_bytewidth == 4) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - } else if (elem_bytewidth == 8) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; - } else { - fprintf(stderr, "Unsupported element size: %ld\n", elem_bytewidth); - abort(); - } if (rank < 1 || rank > 5) { fprintf(stderr, "Rank must be in [1, 5], but got %ld\n", rank); abort(); @@ -154,6 +176,17 @@ void* mosaic_gpu_module_load(void *data) { fprintf(stderr, "cuModuleLoadData failed: %s\n", ptr); abort(); } + + CUdeviceptr ptr = 0; + size_t size = 0; + // Check if module contains NVSHMEM globals implying NVSHMEM state needs to set + if (cuModuleGetGlobal(&ptr, &size, module, "nvshmemi_device_lib_version_d") == CUDA_SUCCESS) { + if (mosaic::gpu::NvshmemApi::Default().cumodule_int(module) != NVSHMEM_SUCCESS) { + fprintf(stderr, "nvshmemx_cumodule_init failed.\n"); + abort(); + } + } + return module; } diff --git a/jaxlib/mosaic/gpu/serde.cc b/jaxlib/mosaic/gpu/serde.cc index f4cf846acc11..5fca1d445774 100644 --- a/jaxlib/mosaic/gpu/serde.cc +++ b/jaxlib/mosaic/gpu/serde.cc @@ -15,10 +15,10 @@ limitations under the License. #include "jaxlib/mosaic/gpu/serde.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/serde.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/gpu/serde.h b/jaxlib/mosaic/gpu/serde.h index 6187d72b4cd5..29dda33d0c5a 100644 --- a/jaxlib/mosaic/gpu/serde.h +++ b/jaxlib/mosaic/gpu/serde.h @@ -19,13 +19,13 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/gpu/target.cc b/jaxlib/mosaic/gpu/target.cc index a1a66a709cbe..a259b3dead7b 100644 --- a/jaxlib/mosaic/gpu/target.cc +++ b/jaxlib/mosaic/gpu/target.cc @@ -23,8 +23,8 @@ limitations under the License. 
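The mosaic_gpu_init_tma_desc change above replaces the old byte-width switch with an element-type code supplied by the caller (the diff points at LaunchContext._get_tma_desc for the encoding). The sketch below restates that mapping as a standalone helper; the numeric codes and bit widths come from the diff, while TmaElemInfo and the string type names are illustrative stand-ins for CUtensorMapDataType so the snippet builds without cuda.h.

#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Element-type code -> (TMA data type, element bit width). Code 0 is int4,
// which TMA has no native type for, so it is packed two-per-byte as UINT8.
struct TmaElemInfo {
  const char* data_type;
  int64_t bitwidth;
};

TmaElemInfo GetTmaElemInfo(int64_t elem_type) {
  switch (elem_type) {
    case 0: return {"UINT8", 4};   // int4, packed in 8-bit pairs
    case 1: return {"UINT8", 8};
    case 2: return {"UINT16", 16};
    case 3: return {"UINT32", 32};
    case 4: return {"UINT64", 64};
    case 5: return {"FLOAT16", 16};
    case 6: return {"FLOAT32", 32};
    case 7: return {"BFLOAT16", 16};
    default:
      std::fprintf(stderr, "Unsupported element type: %ld\n",
                   static_cast<long>(elem_type));
      std::abort();
  }
}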
#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "llvm/include/llvm/MC/MCSubtargetInfo.h" -#include "llvm/include/llvm/MC/TargetRegistry.h" +#include "llvm/MC/MCSubtargetInfo.h" +#include "llvm/MC/TargetRegistry.h" namespace mosaic::gpu { diff --git a/jaxlib/pass_boilerplate.h b/jaxlib/mosaic/pass_boilerplate.h similarity index 87% rename from jaxlib/pass_boilerplate.h rename to jaxlib/mosaic/pass_boilerplate.h index b9754a8738ee..96d9e85a1d2d 100644 --- a/jaxlib/pass_boilerplate.h +++ b/jaxlib/mosaic/pass_boilerplate.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_PASS_BOILERPLATE_H_ -#define JAXLIB_PASS_BOILERPLATE_H_ +#ifndef JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ +#define JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ #include -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" namespace jaxlib { namespace mlir { @@ -64,4 +64,4 @@ class Pass : public ::mlir::OperationPass { } // namespace mlir } // namespace jaxlib -#endif // JAXLIB_PASS_BOILERPLATE_H_ +#endif // JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ diff --git a/jaxlib/mosaic/serde.cc b/jaxlib/mosaic/serde.cc index 88bca44bf181..307164d91dd9 100644 --- a/jaxlib/mosaic/serde.cc +++ b/jaxlib/mosaic/serde.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" namespace jaxlib::mosaic { diff --git a/jaxlib/mosaic/serde.h b/jaxlib/mosaic/serde.h index 762d9e5dad73..fdcaf58d4a8e 100644 --- a/jaxlib/mosaic/serde.h +++ b/jaxlib/mosaic/serde.h @@ -18,11 +18,11 @@ limitations under the License. #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" namespace jaxlib::mosaic { diff --git a/jaxlib/pyinit_stub.c b/jaxlib/pyinit_stub.c new file mode 100644 index 000000000000..7fc873d9ae0e --- /dev/null +++ b/jaxlib/pyinit_stub.c @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Stub that reexports Wrapped_PyInit_module as PyInit_module. + +extern void* Wrapped_PyInit_@MODULE_NAME@(); + +#if defined(WIN32) || defined(_WIN32) +#define EXPORT_SYMBOL __declspec(dllexport) +#else +#define EXPORT_SYMBOL __attribute__ ((visibility("default"))) +#endif + +EXPORT_SYMBOL void* PyInit_@MODULE_NAME@() { + return Wrapped_PyInit_@MODULE_NAME@(); +} diff --git a/jaxlib/pywrap.bzl b/jaxlib/pywrap.bzl new file mode 100644 index 000000000000..e63bb0de9fd4 --- /dev/null +++ b/jaxlib/pywrap.bzl @@ -0,0 +1,83 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrappers around pywrap rules for JAX.""" + +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load( + "@xla//third_party/py/rules_pywrap:pywrap.impl.bzl", + "pybind_extension", + _pywrap_binaries = "pywrap_binaries", + _pywrap_library = "pywrap_library", +) + +pywrap_library = _pywrap_library +pywrap_binaries = _pywrap_binaries + +def nanobind_pywrap_extension( + name, + srcs = [], + deps = [], + pytype_srcs = [], + pytype_deps = [], + copts = [], + linkopts = [], + visibility = None): + # buildifier: disable=function-docstring-args + "Python extension rule using nanobind and the pywrap rules." + module_name = name + lib_name = name + "_pywrap_library" + src_cc_name = name + "_pywrap_stub.c" + + # We put the entire contents of the extension in a single cc_library, which will become part of + # the common pywrap library. All the contents of all extensions will end up in the common + # library. + native.cc_library( + name = lib_name, + srcs = srcs, + copts = copts, + deps = deps, + local_defines = [ + "PyInit_{}=Wrapped_PyInit_{}".format(module_name, module_name), + ], + visibility = ["//visibility:private"], + ) + + # We build a small stub library as the extension that forwards to the PyInit_... symbol from the + # common pywrap library. + expand_template( + name = name + "_pywrap_stub", + testonly = True, + out = src_cc_name, + substitutions = { + "@MODULE_NAME@": module_name, + }, + template = "//jaxlib:pyinit_stub.c", + visibility = ["//visibility:private"], + ) + + # Despite its name "pybind_extension" has nothing to do with pybind. It is the Python extension + # rule from the pywrap rules. 
+ pybind_extension( + name = name, + srcs = [src_cc_name], + deps = [":" + lib_name], + data = pytype_srcs, + linkopts = linkopts, + visibility = visibility, + default_deps = [], + common_lib_packages = [ + "jaxlib", + ], + ) diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 9a25a795fd14..94d75d9c19ae 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -79,7 +79,7 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipblas", @@ -87,54 +87,6 @@ cc_library( ], ) -cc_library( - name = "hipblas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":hip_blas_handle_pool", - ":hip_gpu_kernel_helpers", - ":hip_make_batch_pointers", - ":hip_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/service:custom_call_status", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_blas", - deps = [ - ":hip_vendor", - ":hipblas_kernels", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "miopen_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -143,11 +95,12 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:miopen", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", @@ -182,7 +135,7 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsolver", @@ -291,8 +244,8 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -318,6 +271,7 @@ nanobind_extension( ":hip_vendor", ":hipsparse_kernels", "//jaxlib:absl_status_casters", + "//jaxlib:kernel_helpers", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -325,12 +279,14 @@ nanobind_extension( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", 
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", "@nanobind", + "@xla//xla/service:custom_call_status", "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -496,9 +452,11 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/util:env_var", ], ) @@ -536,7 +494,9 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@nanobind", ], ) @@ -544,7 +504,6 @@ nanobind_extension( py_library( name = "rocm_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", @@ -555,11 +514,49 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":hip_vendor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:platform_util", + ], +) + nanobind_extension( name = "rocm_plugin_extension", srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 1dd1f1943fc8..d893c7fb7fe2 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -16,16 +16,19 @@ limitations under the License. 
#include #include -#include "nanobind/nanobind.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" namespace nb = nanobind; namespace xla { namespace { + std::string ToString(hipError_t result) { #define OSTREAM_ROCM_ERROR(__name) \ case hipError##__name: \ @@ -62,10 +65,25 @@ std::string ToString(hipError_t result) { return absl::StrCat("hipError_t(", static_cast(result), ")"); } } + +nb::dict FfiRegistrations() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + jax::EncapsulateFfiHandler(jax::hip::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + jax::EncapsulateFfiHandler(jax::hip::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; + return dict; +} + } // namespace NB_MODULE(rocm_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("ffi_registrations", &FfiRegistrations); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/setup.py b/jaxlib/setup.py index b3a37a25f1b2..5bd010525c96 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -76,6 +76,8 @@ def has_ext_modules(self): package_data={ 'jaxlib': [ '*.so', + '*.dylib', + '*.dll', '*.pyd*', 'py.typed', 'cpu/*', diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index afa5866e286d..3dffe556d821 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -29,16 +29,18 @@ load( load( "//jaxlib:jax.bzl", "PLATFORM_TAGS_DICT", - "if_windows", "jax_py_test", "jax_wheel", "pytype_strict_library", + "pytype_test", ) licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) +exports_files(["wheel_size_test.py"]) + genrule( name = "platform_tags_py", srcs = [], @@ -61,15 +63,14 @@ py_binary( "LICENSE.txt", "//jaxlib", "//jaxlib:README.md", + "//jaxlib:jaxlib_binaries", "//jaxlib:setup.py", + "//jaxlib/xla:xla_client.py", + "//jaxlib/xla:xla_extension", "@xla//xla/ffi/api:api.h", "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", - "@xla//xla/python:xla_client.py", - "@xla//xla/python:xla_extension", - ] + if_windows([ - "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", - ]), + ], deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", @@ -389,3 +390,48 @@ verify_manylinux_compliance_test( wheel = ":jax_cuda_pjrt_wheel", x86_64_compliance_tag = X86_64_MANYLINUX_TAG, ) + +pytype_test( + name = "jaxlib_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jaxlib_wheel)", + "--max-size-mib=110", + ], + data = [":jaxlib_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +pytype_test( + name = "jax_cuda_plugin_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_cuda_plugin_wheel)", + "--max-size-mib=20", + ], + data = [":jax_cuda_plugin_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +pytype_test( + name = "jax_cuda_pjrt_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_cuda_pjrt_wheel)", + "--max-size-mib=120", + ], + data = [":jax_cuda_pjrt_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py 
b/jaxlib/tools/build_gpu_kernels_wheel.py index 2f81eacbdde4..e9684108caf0 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -102,7 +102,6 @@ def prepare_wheel_cuda( dst_dir=plugin_dir, src_files=[ f"__main__/jaxlib/cuda/_solver.{pyext}", - f"__main__/jaxlib/cuda/_blas.{pyext}", f"__main__/jaxlib/cuda/_linalg.{pyext}", f"__main__/jaxlib/cuda/_prng.{pyext}", f"__main__/jaxlib/cuda/_rnn.{pyext}", @@ -140,7 +139,6 @@ def prepare_wheel_rocm( copy_runfiles( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/rocm/_blas.{pyext}", f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_solver.{pyext}", diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 667807b51197..d52cc7da36e8 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -81,7 +81,7 @@ def write_setup_cfg(sources_path, cpu): [bdist_wheel] plat_name={tag} -python-tag=py3 +python_tag=py3 """ ) diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 4c50cff16743..582a0c9f1d6f 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -65,6 +65,7 @@ def build_wheel( package_name: str, git_hash: str = "", build_wheel_only: bool = True, + build_source_package_only: bool = False, ) -> None: """Builds a wheel in `output_path` using the source tree in `sources_path`.""" env = dict(os.environ) @@ -78,7 +79,8 @@ def build_wheel( env["USERPROFILE"] = env.get("SYSTEMDRIVE", "C:") subprocess.run( [sys.executable, "-m", "build", "-n"] - + (["-w"] if build_wheel_only else []), + + (["-w"] if build_wheel_only else []) + + (["-s"] if build_source_package_only else []), check=True, cwd=sources_path, env=env, @@ -97,10 +99,10 @@ def build_wheel( sys.stderr.write(" bazel run //build:requirements.update" + f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n") shutil.copy(wheel, output_path) - if not build_wheel_only: + if build_source_package_only: for dist in glob.glob(os.path.join(sources_path, "dist", "*.tar.gz")): output_file = os.path.join(output_path, os.path.basename(dist)) - sys.stderr.write(f"Output source distribution: {output_file}\n\n") + sys.stderr.write(f"Output source package: {output_file}\n\n") shutil.copy(dist, output_path) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 8632468acb97..ba0eedfd393e 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -60,11 +60,11 @@ r = runfiles.Create() - def _is_mac(): return platform.system() == "Darwin" +soext = "dll" if build_utils.is_windows() else ("dylib" if _is_mac() else "so") pyext = "pyd" if build_utils.is_windows() else "so" @@ -103,17 +103,14 @@ def patch_copy_mlir_import(src_file, dst_dir): "pytree.pyi", "transfer_guard_lib.pyi", ] -_OPTIONAL_XLA_EXTENSION_STUBS = [] def patch_copy_xla_extension_stubs(dst_dir): xla_extension_dir = os.path.join(dst_dir, "xla_extension") os.makedirs(xla_extension_dir) for stub_name in _XLA_EXTENSION_STUBS: - stub_path = r.Rlocation("xla/xla/python/xla_extension/" + stub_name) + stub_path = r.Rlocation("__main__/jaxlib/xla/xla_extension/" + stub_name) stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path). 
- if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path): - continue with open(stub_path) as f: src = f.read() src = src.replace( @@ -135,7 +132,7 @@ def verify_mac_libraries_dont_reference_chkstack(): if not _is_mac(): return nm = subprocess.run( - ["nm", "-g", r.Rlocation("xla/xla/python/xla_extension.so")], + ["nm", "-g", r.Rlocation(f"__main__/jaxlib/xla_extension.{pyext}")], capture_output=True, text=True, check=False, @@ -186,6 +183,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): src_files=[ f"__main__/jaxlib/cpu_feature_guard.{pyext}", f"__main__/jaxlib/utils.{pyext}", + "__main__/jaxlib/jax_common.dll" if build_utils.is_windows() else f"__main__/jaxlib/libjax_common.{soext}", "__main__/jaxlib/lapack.py", "__main__/jaxlib/hlo_helpers.py", "__main__/jaxlib/gpu_prng.py", @@ -197,8 +195,10 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): "__main__/jaxlib/gpu_sparse.py", "__main__/jaxlib/plugin_support.py", "__main__/jaxlib/version.py", - "__main__/jaxlib/xla_client.py", - f"xla/xla/python/xla_extension.{pyext}", + "__main__/jaxlib/xla/xla_client.py", + f"__main__/jaxlib/weakref_lru_cache.{pyext}", + "__main__/jaxlib/weakref_lru_cache.pyi", + f"__main__/jaxlib/xla_extension.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing @@ -260,6 +260,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): "__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", "__main__/jaxlib/mlir/dialects/_ods_common.py", "__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_sdy_enums_gen.py", "__main__/jaxlib/mlir/dialects/_sdy_ops_gen.py", "__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", "__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", @@ -311,38 +312,31 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): ) - if build_utils.is_windows(): - capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll" - else: - so_ext = "dylib" if _is_mac() else "so" - capi_so = f"__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.{so_ext}" - mlir_libs_dir = jaxlib_dir / "mlir" / "_mlir_libs" copy_runfiles( dst_dir=mlir_libs_dir, src_files=[ - capi_so, "__main__/jaxlib/mlir/_mlir_libs/__init__.py", - f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsNVGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirGPUPasses.{pyext}", + f"__main__/jaxlib/_mlir.{pyext}", + f"__main__/jaxlib/_chlo.{pyext}", + f"__main__/jaxlib/_mlirHlo.{pyext}", + f"__main__/jaxlib/_mlirDialectsSparseTensor.{pyext}", + f"__main__/jaxlib/_mlirSparseTensorPasses.{pyext}", + f"__main__/jaxlib/_mosaic_gpu_ext.{pyext}", + f"__main__/jaxlib/_tpu_ext.{pyext}", + f"__main__/jaxlib/_sdy.{pyext}", + f"__main__/jaxlib/_stablehlo.{pyext}", + f"__main__/jaxlib/register_jax_dialects.{pyext}", + 
f"__main__/jaxlib/_mlirDialectsGPU.{pyext}", + f"__main__/jaxlib/_mlirDialectsLLVM.{pyext}", + f"__main__/jaxlib/_mlirDialectsNVGPU.{pyext}", + f"__main__/jaxlib/_mlirGPUPasses.{pyext}", ] + ( [] if build_utils.is_windows() else [ - f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", + f"__main__/jaxlib/_triton_ext.{pyext}", "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", ] ), diff --git a/jaxlib/tools/wheel_size_test.py b/jaxlib/tools/wheel_size_test.py new file mode 100644 index 000000000000..7e9c08ff9797 --- /dev/null +++ b/jaxlib/tools/wheel_size_test.py @@ -0,0 +1,56 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os + + +def parse_args(): + """Arguments parser.""" + parser = argparse.ArgumentParser( + description="Helper for the wheel size verification", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--wheel-path", required=True, help="Path of the wheel, mandatory" + ) + parser.add_argument( + "--max-size-mib", + required=True, + help="Maximum size of the wheel in MiB", + ) + return parser.parse_args() + + +def verify_wheel_size(args): + wheel_size_mib = os.path.getsize(args.wheel_path) >> 20 + wheel_name = os.path.basename(args.wheel_path) + if wheel_size_mib > int(args.max_size_mib): + raise RuntimeError( + "The {name} size is {size} MiB, which is larger than the maximum size" + " {max_size} MiB".format( + name=wheel_name, + size=wheel_size_mib, + max_size=args.max_size_mb, + ) + ) + else: + logging.info( + "The %s size is %s MiB, which is less than the maximum size" + " %s MB", wheel_name, wheel_size_mib, args.max_size_mib) + + +if __name__ == "__main__": + verify_wheel_size(parse_args()) diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 99cddd9e6381..478ce31140a6 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -35,7 +35,9 @@ pytype_strict_library( "//jaxlib/mlir:ir", ] + if_windows( [], - ["//jaxlib/mlir/_mlir_libs:_triton_ext"], + [ + "//jaxlib/mlir/_mlir_libs:_triton_ext", + ], ), ) diff --git a/jaxlib/triton/triton_dialect_capi.cc b/jaxlib/triton/triton_dialect_capi.cc index 6a46d2914f57..8781fd16d76a 100644 --- a/jaxlib/triton/triton_dialect_capi.cc +++ b/jaxlib/triton/triton_dialect_capi.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "jaxlib/triton/triton_dialect_capi.h" -#include "llvm/include/llvm/Support/Casting.h" -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir/CAPI/IR.h" -#include "mlir/include/mlir/CAPI/Registration.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Dialect.h" +#include "llvm/Support/Casting.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" diff --git a/jaxlib/triton/triton_dialect_capi.h b/jaxlib/triton/triton_dialect_capi.h index 8c27b5b82500..7d2a2f10404a 100644 --- a/jaxlib/triton/triton_dialect_capi.h +++ b/jaxlib/triton/triton_dialect_capi.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ #define JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir-c/Support.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index bf50b3a5254d..e5bb45e999da 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" namespace nb = nanobind; diff --git a/jaxlib/weakref_lru_cache.cc b/jaxlib/weakref_lru_cache.cc new file mode 100644 index 000000000000..0e3b9b831b82 --- /dev/null +++ b/jaxlib/weakref_lru_cache.cc @@ -0,0 +1,416 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/lru_cache.h" +#include "xla/tsl/platform/logging.h" + +namespace nb = nanobind; + +namespace jax { +namespace { + +// Minimal wrapper to expose a nb::dict_iterator's value as something +// hashable with Abseil. 
+class HashablePyDictEntry { + public: + explicit HashablePyDictEntry(std::pair entry) + : entry_(entry) {} + + template + friend H AbslHashValue(H h, const HashablePyDictEntry& v) { + return H::combine(std::move(h), nb::hash(v.entry_.first), + nb::hash(v.entry_.second)); + } + + std::pair entry_; +}; + +// Similarly, a minimalist adaptor around the nb::detail::dict_iterator +// itself. Note that the iterator "is" also a Value. Does not meet the full +// standard iterator requirements, only enough to support H::combine_unordered. +class HashablePyDictIter { + public: + using iterator_category = std::input_iterator_tag; + + explicit HashablePyDictIter(nb::detail::dict_iterator& iter) : iter_(iter) {} + + // Minimal set of iterator operations. + HashablePyDictEntry operator*() const { return HashablePyDictEntry(*iter_); } + bool operator!=(const HashablePyDictIter& rhs) const { + return iter_ != rhs.iter_; + } + void operator++() { ++iter_; } + + private: + nb::detail::dict_iterator& iter_; +}; + +struct HashableKey { + nb::object context; + nb::args args; + nb::kwargs kwargs; + + template + friend H AbslHashValue(H h, const HashableKey& key) { + // Note: Despite the fact this is an ABSL hash function, it's safe to call + // functions that may throw exceptions such as nb::hash(), because it is + // used by an LRUCache, which uses a std::unordered_map, which is + // exception-safe. + h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); + nb::detail::dict_iterator begin = key.kwargs.begin(); + nb::detail::dict_iterator end = key.kwargs.end(); + h = H::combine_unordered(std::move(h), HashablePyDictIter(begin), + HashablePyDictIter(end)); + h = H::combine(std::move(h), key.kwargs.size()); + return h; + } +}; + +} // namespace + +class WeakrefLRUCache : public std::enable_shared_from_this { + public: + WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, + int64_t maxsize) + : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {} + + nb::object Call(nb::object weakref_key, nb::args args, nb::kwargs kwargs); + + std::vector GetKeys(); + + struct CacheInfo { + int64_t hits; + int64_t misses; + int64_t maxsize; + int64_t currsize; + }; + CacheInfo GetCacheInfo() const; + + void Clear(); + + static PyType_Slot slots_[]; + + private: + class Key { + public: + Key(nb::object context, nb::args args, nb::kwargs kwargs) + : context_(std::move(context)), + args_(std::move(args)), + kwargs_(std::move(kwargs)), + cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {} + + bool operator==(const Key& other) const { + return context_.equal(other.context_) && args_.equal(other.args_) && + kwargs_.equal(other.kwargs_); + } + + template + friend H AbslHashValue(H h, const Key& key) { + return H::combine(std::move(h), key.cached_hash_); + } + + nb::object context() const { return context_; } + nb::args args() const { return args_; } + nb::kwargs kwargs() const { return kwargs_; } + + int tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(context_.ptr()); + Py_VISIT(args_.ptr()); + Py_VISIT(kwargs_.ptr()); + return 0; + } + + private: + nb::object context_; + nb::args args_; + nb::kwargs kwargs_; + size_t cached_hash_; + }; + + struct CacheEntry { + bool has_result = false; + nb::object result; + absl::Notification completed; + std::thread::id thread_id = std::this_thread::get_id(); + + int tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(result.ptr()); + return 0; + } + }; + + struct WeakrefCacheKey { + nb::weakref ref; + size_t 
cached_hash; + }; + + using Cache = xla::LRUCache>; + + struct WeakrefCacheValue { + std::shared_ptr cache; + }; + + struct WeakrefKeyHash { + size_t operator()(const WeakrefCacheKey& v) const { return v.cached_hash; } + }; + + struct WeakrefKeyEq { + bool operator()(const WeakrefCacheKey& lhs, + const WeakrefCacheKey& rhs) const { + return lhs.ref.equal(rhs.ref); + } + }; + + std::shared_ptr GetCache(WeakrefCacheKey key) { + WeakrefCacheValue& value = entries_[key]; + if (!value.cache) { + value.cache = std::make_shared(&lru_list_); + } + return value.cache; + } + + nb::callable cache_context_fn_; + nb::callable fn_; + Cache::LRUList lru_list_; + std::unordered_map + entries_; + int64_t misses_ = 0; + int64_t total_queries_ = 0; + absl::Mutex mu_; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +nb::object WeakrefLRUCache::Call(nb::object weakref_key, nb::args args, + nb::kwargs kwargs) + ABSL_NO_THREAD_SAFETY_ANALYSIS { + nb::object context = cache_context_fn_(); + + // We precompute all of the hash values needed by the various maps rather + // than computing them during the std::unordered_map insertions. At the very + // least, MSVC's std::unordered_map has undefined behavior if the hash + // function throws an exception + // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). + Key key(context, args, kwargs); + size_t wrcache_hash = static_cast(nb::hash(weakref_key)); + + // No hash computations after this point. + + auto weakref_gc_callback = nb::cpp_function( + [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) { + auto cache = this_weak.lock(); + if (cache == nullptr) { + return; + } + // Set up PyCriticalSection for cache python associated object; + auto py_cache = nb::find(cache); + // This should never happen as python cache should always be found + CHECK(py_cache.ptr() != nullptr); + nb::ft_object_guard lock(py_cache); + + // The object the reference referred to is now in the process of being + // destroyed, so we cannot refer to its contents. Python weakref + // objects compare based on identity if the object they refer to is + // gone, so the hash lookup will work fine. + auto it = cache->entries_.find( + WeakrefCacheKey{nb::borrow(weakref), wrcache_hash}); + if (it == cache->entries_.end()) { + return; + } + // Create temp-var to avoid re-entrant erase. + auto tmp = std::move(it->second); + cache->entries_.erase(it); + }); + nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback); + WeakrefCacheKey wrcache_key{weakref, wrcache_hash}; + std::shared_ptr cache_ptr = GetCache(wrcache_key); + Cache& cache = *cache_ptr; + ++total_queries_; + + bool inserted = false; + std::shared_ptr entry; + { + // Because the gil can be released during cache insertion, this forces + // the lock order to be mu_ then gil so we must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + mu_.Lock(); + } + { + // GetOrCreateIfAbsent calls into Python hash and equality functions, + // which may throw exceptions. The use of absl::Cleanup ensures mu_ is + // released if that happens. 
+ absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; + entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) { + inserted = true; + return std::make_shared(); + }); + } + if (!entry->completed.HasBeenNotified()) { + if (inserted) { + ++misses_; + absl::Cleanup notify = [&] { entry->completed.Notify(); }; + entry->result = fn_(weakref_key, *args, **kwargs); + entry->has_result = true; + } else { + if (entry->thread_id == std::this_thread::get_id()) { + auto error_string = + absl::StrCat("Recursively calling ", + nb::cast(nb::repr(weakref_key)), + nb::cast(nb::repr(args))); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + nb::gil_scoped_release release; + entry->completed.WaitForNotification(); + } + } + + if (entry->has_result) { + return entry->result; + } else { + ++misses_; + return fn_(weakref_key, *args, **kwargs); + } +} + +std::vector WeakrefLRUCache::GetKeys() { + std::vector results; + mu_.Lock(); + for (const auto& wr_entry : entries_) { + for (const auto& rest : *wr_entry.second.cache) { + nb::tuple result = + nb::make_tuple(*wr_entry.first.ref, rest.first.context(), + rest.first.args(), rest.first.kwargs()); + results.push_back(std::move(result)); + } + } + mu_.Unlock(); + return results; +} + +WeakrefLRUCache::CacheInfo WeakrefLRUCache::GetCacheInfo() const { + CacheInfo result; + result.hits = total_queries_ - misses_; + result.misses = misses_; + result.maxsize = lru_list_.Capacity(); + result.currsize = lru_list_.Size(); + return result; +} + +void WeakrefLRUCache::Clear() { + total_queries_ = misses_ = 0; + std::vector> deferred_deletes; + deferred_deletes.reserve(entries_.size()); + for (auto& entry : entries_) { + deferred_deletes.emplace_back(entry.first, std::move(entry.second)); + } + entries_.clear(); + deferred_deletes.clear(); +} + +/*static*/ int WeakrefLRUCache::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + WeakrefLRUCache* cache = nb::inst_ptr(self); + Py_VISIT(cache->cache_context_fn_.ptr()); + Py_VISIT(cache->fn_.ptr()); + for (const auto& [wr_key, wr_value] : cache->entries_) { + Py_VISIT(wr_key.ref.ptr()); + for (const auto& [key, cache_value] : *wr_value.cache) { + int rval = key.tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + if (cache_value.value.has_value()) { + cache_value.value->get()->tp_traverse(visit, arg); + } + } + } + return 0; +} + +/*static*/ int WeakrefLRUCache::tp_clear(PyObject* self) { + WeakrefLRUCache* cache = nb::inst_ptr(self); + cache->Clear(); + cache->cache_context_fn_.reset(); + cache->fn_.reset(); + return 0; +} + +/* static */ PyType_Slot WeakrefLRUCache::slots_[] = { + {Py_tp_traverse, (void*)WeakrefLRUCache::tp_traverse}, + {Py_tp_clear, (void*)WeakrefLRUCache::tp_clear}, + {0, nullptr}, +}; + +NB_MODULE(weakref_lru_cache, m) { + auto weakref_lru_cache = + nb::class_(m, "WeakrefLRUCache", + nb::is_weak_referenceable(), + nb::type_slots(WeakrefLRUCache::slots_)) + .def("__call__", &WeakrefLRUCache::Call, nb::lock_self()) + .def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self()) + .def("cache_info", &WeakrefLRUCache::GetCacheInfo, nb::lock_self()) + .def("cache_clear", &WeakrefLRUCache::Clear, nb::lock_self()); + nb::class_(weakref_lru_cache, + "WeakrefLRUCacheInfo") + .def_ro("hits", &WeakrefLRUCache::CacheInfo::hits) + .def_ro("misses", &WeakrefLRUCache::CacheInfo::misses) + .def_ro("maxsize", &WeakrefLRUCache::CacheInfo::maxsize) + 
.def_ro("currsize", &WeakrefLRUCache::CacheInfo::currsize) + .def("__repr__", [](WeakrefLRUCache::CacheInfo& info) { + return absl::StrCat( + "WeakrefLRUCache(hits=", info.hits, ", misses=", info.misses, + ", maxsize=", info.maxsize, ", currsize=", info.currsize, ")"); + }); + m.def( + "weakref_lru_cache", + [](nb::callable cache_context_fn, nb::callable fn, int64_t maxsize) { + return std::make_shared(cache_context_fn, fn, maxsize); + }, + nb::arg("cache_context_fn"), nb::arg("fn"), nb::arg("maxsize") = 2048); +} + +} // namespace jax diff --git a/jaxlib/weakref_lru_cache.pyi b/jaxlib/weakref_lru_cache.pyi new file mode 100644 index 000000000000..5b91ba1f5fc5 --- /dev/null +++ b/jaxlib/weakref_lru_cache.pyi @@ -0,0 +1,37 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable + +class WeakrefLRUCache: + def __call__(self, arg0: object, /, *args, **kwargs) -> object: ... + def cache_keys(self) -> list[object]: ... + def cache_info(self) -> WeakrefLRUCache.WeakrefLRUCacheInfo: ... + def cache_clear(self) -> None: ... + + class WeakrefLRUCacheInfo: + @property + def hits(self) -> int: ... + @property + def misses(self) -> int: ... + @property + def maxsize(self) -> int: ... + @property + def currsize(self) -> int: ... + def __repr__(self) -> str: ... + +def weakref_lru_cache( + cache_context_fn: Callable, fn: Callable, maxsize: int = 2048 +) -> WeakrefLRUCache: ... diff --git a/jaxlib/weakref_lru_cache_test.py b/jaxlib/weakref_lru_cache_test.py new file mode 100644 index 000000000000..a1016f397389 --- /dev/null +++ b/jaxlib/weakref_lru_cache_test.py @@ -0,0 +1,264 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import gc +import threading +import time +import weakref + +from absl.testing import absltest +from jax.jaxlib import weakref_lru_cache + + +class WeakrefLRUCacheTest(absltest.TestCase): + + def testMultiThreaded(self): + insert_evs = [threading.Event() for _ in range(2)] + insert_evs_i = 0 + + class WRKey: + pass + + class ClashingKey: + + def __eq__(self, other): + return False + + def __hash__(self): + return 333 # induce maximal caching problems. 
+ + class GilReleasingCacheKey: + + def __eq__(self, other): + nonlocal insert_evs_i + if isinstance(other, GilReleasingCacheKey) and insert_evs_i < len( + insert_evs + ): + insert_evs[insert_evs_i].set() + insert_evs_i += 1 + time.sleep(0.01) + return False + + def __hash__(self): + return 333 # induce maximal caching problems. + + def CacheFn(obj, gil_releasing_cache_key): + del obj + del gil_releasing_cache_key + return None + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 2048) + + wrkey = WRKey() + + def Body(): + for insert_ev in insert_evs: + insert_ev.wait() + for _ in range(20): + cache(wrkey, ClashingKey()) + + t = threading.Thread(target=Body) + t.start() + for _ in range(3): + cache(wrkey, GilReleasingCacheKey()) + t.join() + + def testAnotherMultiThreaded(self): + num_workers = 5 + barrier = threading.Barrier(num_workers) + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + + class WRKey: + pass + + def WorkerAddToCache(): + barrier.wait() + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + + def WorkerCleanCache(): + barrier.wait() + for _ in range(10): + cache.cache_clear() + + workers = [ + threading.Thread(target=WorkerAddToCache) + for _ in range(num_workers - 1) + ] + [threading.Thread(target=WorkerCleanCache)] + + for t in workers: + t.start() + + for t in workers: + t.join() + + def testKwargsDictOrder(self): + miss_id = 0 + + class WRKey: + pass + + def CacheFn(obj, kwkey1, kwkey2): + del obj, kwkey1, kwkey2 + nonlocal miss_id + miss_id += 1 + return miss_id + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 4) + + wrkey = WRKey() + + self.assertEqual(cache(wrkey, kwkey1="a", kwkey2="b"), 1) + self.assertEqual(cache(wrkey, kwkey1="b", kwkey2="a"), 2) + self.assertEqual(cache(wrkey, kwkey2="b", kwkey1="a"), 1) + + def testGetKeys(self): + def CacheFn(obj, arg): + del obj + return arg + "extra" + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 4) + + class WRKey: + pass + + wrkey = WRKey() + + self.assertEmpty(cache.cache_keys()) + cache(wrkey, "arg1") + cache(wrkey, "arg2") + self.assertLen(cache.cache_keys(), 2) + + def testNonWeakreferenceableKey(self): + class NonWRKey: + __slots__ = () + + non_wr_key = NonWRKey() + with self.assertRaises(TypeError): + weakref.ref(non_wr_key) + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, lambda x: 2048) + for _ in range(100): + with self.assertRaises(TypeError): + cache(non_wr_key) + + def testCrashingKey(self): + class WRKey: + pass + + class CrashingKey: + # A key that raises exceptions if eq or hash is called. 
+ + def __eq__(self, other): + raise ValueError("eq") + + def __hash__(self): + raise ValueError("hash") + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + wrkey = WRKey() + with self.assertRaises(ValueError): + for _ in range(100): + cache(wrkey, CrashingKey()) + + def testPrintingStats(self): + class WRKey: + pass + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + for i in range(5): + cache(wrkey, i) + + self.assertEqual( + repr(cache.cache_info()), + "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)", + ) + + def testGCKeys(self): + class WRKey: + + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return hash(self.x) + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + keys = [WRKey(i) for i in range(10)] + for i in range(10): + cache(keys[i], i) + + # Delete some keys, to exercise the weakref callback behavior. + del keys[::2] + + for key in keys: + cache(key, 7) + + def testTpTraverse(self): + class WRKey: + pass + + def CacheContextFn(): + return None + + def CallFn(x, y, *args, **kwargs): + del x, args, kwargs + return y + + cache = weakref_lru_cache.weakref_lru_cache(CacheContextFn, CallFn, 2048) + + keys = [WRKey() for _ in range(10)] + values = [str(i) for i in range(10)] + args = [str(i) for i in range(10)] + kwargs = {"a": "b"} + + for key, value in zip(keys, values): + cache(key, value, *args, **kwargs) + + expected_refs = ( + [ + CacheContextFn, + CallFn, + weakref_lru_cache.WeakrefLRUCache, + kwargs, + ] + + [weakref.getweakrefs(key)[0] for key in keys] + + values + + args + ) + + # Can't use assertContainsSubset because it doesn't support kwargs since + # dicts aren't hashable. + for ref in expected_refs: + self.assertIn(ref, gc.get_referents(cache)) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD new file mode 100644 index 000000000000..a6a4cf660408 --- /dev/null +++ b/jaxlib/xla/BUILD @@ -0,0 +1,1062 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
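+# BUILD targets for the jaxlib xla_extension Python extension module and the
+# C++ libraries it is assembled from.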
+ +load( + "//jaxlib:jax.bzl", + "cc_proto_library", + "if_oss", + "jax_visibility", + "nanobind_extension", + "proto_library", + "py_deps", + "py_strict_library", + "py_strict_test", + "pytype_strict_library", +) +load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//jax:internal"], +) + +package_group( + name = "xla_python", + includes = [ + "//jax:internal", + ], +) + +nanobind_pywrap_extension( + name = "xla_extension", + srcs = ["xla.cc"], + pytype_deps = py_deps(["numpy"]), + pytype_srcs = glob(["xla_extension/*.pyi"]), + visibility = ["//visibility:public"], + deps = [ + ":config", + ":custom_call_sharding", + ":dlpack", + ":guard_lib", + ":ifrt_proxy", + ":jax_jit", + ":mlir", + ":nb_class_ptr", + ":pjit", + ":pmap_lib", + ":py_client", + ":python_ref_manager", + ":pytree", + ":sdy", + ":traceback", + ":util", + ":xla_compiler", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:initialize", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla/backends/cpu/collectives:cpu_collectives", + "@xla//xla/ffi:ffi_api", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_api", + "@xla//xla/pjrt:pjrt_c_api_client", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/distributed", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/pjrt/distributed:key_value_store_interface", + "@xla//xla/pjrt/distributed:protocol_proto_cc", + "@xla//xla/pjrt/distributed:service", + "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@xla//xla/python:logging", + "@xla//xla/python:nb_absl_flat_hash_map", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:ops", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:profiler", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/python:types", + "@xla//xla/python:version", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/platform/cloud:gcs_file_system", + "@xla//xla/tsl/python/lib/core:numpy", + ] + select({ + # gloo tcp transport only builds on linux + "@xla//xla/tsl:macos": [ + "@gloo//:transport_uv", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + 
"@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + "@xla//xla/tsl:windows": [], + "//conditions:default": [ + ":py_socket_transfer", + "@gloo//:transport_tcp", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + }) + select({ + # mpitrampoline does not build on windows + "@xla//xla/tsl:windows": [], + # we support MPI collectives only in OSS builds + "//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]), + }), +) + +cc_library( + name = "callback", + srcs = [ + "callback.cc", + ], + hdrs = [ + "callback.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":python_ref_manager", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":python_ref_manager", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "custom_call_sharding", + srcs = ["custom_call_sharding.cc"], + hdrs = ["custom_call_sharding.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/utils:hlo_sharding_util", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/python:custom_call_batch_partitioner", + "@xla//xla/python:custom_partition_callback", + "@xla//xla/python:debug_callback_partitioner", + "@xla//xla/python:inspect_sharding", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "dlpack", + srcs = ["dlpack.cc"], + hdrs = ["dlpack.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + ":python_ref_manager", + ":traceback", + ":util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@dlpack", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + 
"@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "guard_lib", + srcs = ["guard_lib.cc"], + hdrs = ["guard_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@nanobind", + "@xla//xla:util", + ], +) + +cc_library( + name = "ifrt_proxy", + srcs = ["ifrt_proxy.cc"], + hdrs = ["ifrt_proxy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@nanobind", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt_proxy/client:grpc_client", + "@xla//xla/python/ifrt_proxy/client:registry", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "jax_jit", + srcs = ["jax_jit.cc"], + hdrs = ["jax_jit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":py_client", + ":python_ref_manager", + ":pytree", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # build_cleaner: keep + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_inlined_vector", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:types", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "mlir", + srcs = ["mlir.cc"], + hdrs = ["mlir.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:Support", + "@nanobind", + "@stablehlo//:stablehlo_serialization", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/mlir_hlo:mhlo_passes", + "@xla//xla/pjrt:mlir_to_hlo", + 
"@xla//xla/pjrt:status_casters", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "nb_class_ptr", + hdrs = ["nb_class_ptr.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/nb_class_ptr"), + deps = ["@nanobind"], +) + +cc_library( + name = "pjit", + srcs = ["pjit.cc"], + hdrs = ["pjit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + ":guard_lib", + ":jax_jit", + ":nb_class_ptr", + ":py_client", + ":python_ref_manager", + ":pytree", + ":traceback", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "pmap_lib", + srcs = ["pmap_lib.cc"], + hdrs = ["pmap_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + ":jax_jit", + ":nb_class_ptr", + ":py_client", + ":python_ref_manager", + ":pytree", + ":traceback", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "py_client", + srcs = [ + "py_array.cc", + "py_client.cc", + "py_compile_only_client.cc", + "py_device.cc", + "py_device_list.cc", + "py_executable.cc", + "py_memory_space.cc", + "py_program.cc", + "py_values.cc", + "sharding.cc", + "to_ifrt_sharding.cc", + ], + hdrs = [ + "py_array.h", + "py_client.h", + "py_compile_only_client.h", + "py_device.h", + "py_device_list.h", + "py_executable.h", + "py_memory_space.h", + "py_program.h", + "py_values.h", + 
"sharded_device_array.h", + "sharding.h", + "to_ifrt_sharding.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/py_client"), + deps = [ + ":guard_lib", + ":nb_class_ptr", + ":py_client_cpu", + ":py_host_callback", + ":python_ref_manager", + ":traceback", + ":util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@nanobind", + "@tsl//tsl/platform:fingerprint", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_future", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:types", + "@xla//xla/python/compile_only_ifrt:client", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt:custom_call_program", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/python/ifrt/hlo:hlo_program", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/service:platform_util", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/framework:allocator", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "py_client_cpu", + srcs = ["py_client_cpu.cc"], + hdrs = ["py_client_cpu.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + 
"@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + ], + alwayslink = 1, +) + +cc_library( + name = "py_host_callback", + srcs = ["py_host_callback.cc"], + hdrs = ["py_host_callback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":callback", + ":py_host_callback_cc_proto", + ":python_ref_manager", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +proto_library( + name = "py_host_callback_proto", + srcs = ["py_host_callback.proto"], +) + +cc_proto_library( + name = "py_host_callback_cc_proto", + visibility = jax_visibility("jaxlib/xla/py_host_callback_cc_proto"), + deps = [":py_host_callback_proto"], +) + +cc_library( + name = "py_socket_transfer", + srcs = ["py_socket_transfer.cc"], + hdrs = ["py_socket_transfer.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + ":traceback", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Support", + "@nanobind", + "@tsl//tsl/platform:casts", + "@xla//xla:util", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/transfer:event_loop", + "@xla//xla/python/transfer:socket-server", + "@xla//xla/python/transfer:socket_bulk_transport", + "@xla//xla/python/transfer:streaming", + "@xla//xla/python/transfer:streaming_ifrt", + "@xla//xla/python/transfer:transfer_socket_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "python_ref_manager", + srcs = ["python_ref_manager.cc"], + hdrs = ["python_ref_manager.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/python_ref_manager"), + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + ], +) + +proto_library( + name = "pytree_proto", + srcs = ["pytree.proto"], +) + +cc_proto_library( + name = "pytree_cc_proto", + deps = [":pytree_proto"], +) + +cc_library( + name = "pytree", + srcs = ["pytree.cc"], + hdrs = ["pytree.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/pytree"), + deps = 
[ + ":nb_class_ptr", + ":pytree_cc_proto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "sdy", + srcs = ["sdy.cc"], + hdrs = ["sdy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@xla//xla/mlir_hlo:all_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/service/spmd/shardy:constants", + "@xla//xla/service/spmd/shardy:utils", + "@xla//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs", + "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", + "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", + ], +) + +cc_library( + name = "traceback", + srcs = ["traceback.cc"], + hdrs = ["traceback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/traceback"), + deps = [ + ":nb_class_ptr", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//xla:util", + "@xla//xla/pjrt:pjrt_future", + "@xla//xla/python:version", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:async_value", + "@xla//xla/tsl/concurrency:ref_count", + ], +) + +cc_library( + name = "xla_compiler", + srcs = ["xla_compiler.cc"], + hdrs = ["xla_compiler.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":dlpack", + ":py_client", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//xla:array", + "@xla//xla:debug_options_flags", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:util", + 
"@xla//xla:xla_data_proto_cc", + "@xla//xla:xla_proto_cc", + "@xla//xla/client:executable_build_options", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:c_api", + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/ir:hlo_module_group", + "@xla//xla/hlo/parser:hlo_parser", + "@xla//xla/hlo/pass:hlo_pass", + "@xla//xla/hlo/transforms/simplifiers:flatten_call_graph", + "@xla//xla/hlo/transforms/simplifiers:hlo_dce", + "@xla//xla/hlo/transforms/simplifiers:tuple_simplifier", + "@xla//xla/pjrt:compile_options_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:call_inliner", + "@xla//xla/service:computation_placer", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:hlo_graph_dumper", + "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/service:name_uniquer", + "@xla//xla/tsl/lib/strings:proto_serialization", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +pytype_strict_library( + name = "xla_client", + srcs = ["xla_client.py"], + pytype_srcs = ["xla_client.pyi"], + visibility = [":xla_python"], + deps = py_deps([ + "numpy", + "ml_dtypes", + ]) + [":xla_extension"], +) + +py_strict_test( + name = "xla_client_backend_independent_test", + srcs = ["xla_client_backend_independent_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/testing", + "numpy", + "portpicker", + ]), +) + +py_strict_library( + name = "xla_client_test", + testonly = 1, + srcs = ["xla_client_test.py"], + visibility = [":xla_python"], + deps = [ + ":xla_client", + "//jax", + "//jax:test_util", + "//jaxlib", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "ml_dtypes", + "numpy", + ]), +) + +nanobind_extension( + name = "custom_calls_testlib", + testonly = 1, + srcs = ["custom_calls_testlib.cc"], + deps = [ + "@com_google_absl//absl/status", + "@nanobind", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + ], +) + +py_strict_test( + name = "xla_client_test_cpu", + srcs = ["xla_client_test.py"], + args = ["--backend=cpu"], + env = { + "XLA_FLAGS": "--xla_force_host_platform_device_count=4", + }, + main = "xla_client_test.py", + deps = [ + ":custom_calls_testlib", + ":xla_client", + "//jax", + "//jax:test_util", + "//jaxlib", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "ml_dtypes", + "numpy", + ]), +) + +py_strict_test( + name = "pytree_test", + srcs = ["pytree_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "config_test", + srcs = ["config_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "jax_jit_test", + srcs = ["jax_jit_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "numpy", + ]), +) diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc new file mode 100644 index 000000000000..b5519ed3bee3 --- /dev/null +++ b/jaxlib/xla/callback.cc @@ -0,0 +1,173 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you 
may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/callback.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/python_ref_manager.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +CpuCallback::~CpuCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + for (auto& arg : args_) { + objects.push_back(std::move(arg.dtype)); + } + + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::Status CpuCallback::PrepareAndCall(void* result, void** arg_ptrs) { + absl::Span inputs(arg_ptrs, args_.size()); + absl::Span outputs(reinterpret_cast(result), + results_.size()); + + nb::gil_scoped_acquire gil; + nb::tuple args = nb::steal(PyTuple_New(inputs.size())); + for (size_t i = 0; i < inputs.size(); ++i) { + if (args_[i].type == xla::TOKEN) { + PyTuple_SET_ITEM(args.ptr(), i, nb::none().release().ptr()); + } else { + nb_numpy_ndarray array = + nb_numpy_ndarray(args_[i].dtype, args_[i].dims, args_[i].strides, + const_cast(inputs[i])); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(args.ptr(), i, array.release().ptr()); + } + } + + EnterHostCallback(); + absl::StatusOr maybe_result_tuple = Call(std::move(args)); + LeaveHostCallback(); + TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); + + for (size_t i = 0; i < results_.size(); ++i) { + if (results_[i].type == xla::TOKEN) { + continue; + } + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + if (strides == results_[i].expected_strides) { + std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes); + } else { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = + xla::primitive_util::ByteWidth(results_[i].type); + options.dims = dims; + options.permutation = results_[i].reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + absl::StatusOr> plan = + transpose_cache_.GetOrCreate(options); + if (!plan.ok()) { + return std::move(plan).status(); + } + plan.value()->Execute(array.data(), outputs[i]); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr 
CpuCallback::Call(nb::tuple args) { + auto py_error_to_status = [](nb::python_error& e) { + std::string error_message = e.what(); + return absl::InternalError( + absl::StrFormat("CpuCallback error: %s", error_message)); + }; + nb::object result_object; + try { + result_object = callable_(*nb::borrow(args)); + } catch (nb::python_error& e) { + return py_error_to_status(e); + } + if (!PyTuple_Check(result_object.ptr())) { + return absl::InternalError( + absl::StrFormat("CPU callback expected a tuple result, got %s", + nb::cast(nb::repr(result_object)))); + } + if (PyTuple_Size(result_object.ptr()) != results_.size()) { + return absl::InternalError( + absl::StrFormat("CPU callback expected a tuple with %d results, got %d", + results_.size(), PyTuple_Size(result_object.ptr()))); + } + nb::tuple result_tuple = nb::cast(result_object); + for (size_t i = 0; i < results_.size(); ++i) { + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + if (results_[i].type == xla::TOKEN) { + if (!output.is_none()) { + return absl::InternalError(absl::StrFormat( + "Token output from Python callback should be None, got %s", + nb::cast(nb::repr(output)))); + } + continue; + } + nb_numpy_ndarray array; + try { + array = nb_numpy_ndarray::from_any(output, NPY_ARRAY_ENSUREARRAY); + } catch (nb::python_error& e) { + return py_error_to_status(e); + } + static_assert(sizeof(ssize_t) == sizeof(int64_t), + "Expected ssize_t to be of equal size to int64_t"); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + if (dims != results_[i].expected_dims) { + return absl::InternalError(absl::StrFormat( + "Mismatched result shape for %d-th return value from CPU callback; " + "expected array with dimensions %s, got %s", + i, absl::StrJoin(results_[i].expected_dims, ","), + absl::StrJoin(dims, ","))); + } + } + return result_tuple; +} + +} // namespace xla diff --git a/jaxlib/xla/callback.h b/jaxlib/xla/callback.h new file mode 100644 index 000000000000..ee1f35ce34a3 --- /dev/null +++ b/jaxlib/xla/callback.h @@ -0,0 +1,87 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_CALLBACK_H_ +#define JAXLIB_XLA_CALLBACK_H_ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/transpose.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class CpuCallback { + public: + struct Arg { + xla::PrimitiveType type; // XLA type + nb_dtype dtype; // NumPy type, for array types. + absl::InlinedVector dims; // Dimensions, for array types. + std::vector strides; // Byte strides, for array types. + size_t size_in_bytes; // Size of the array in bytes. 
+ }; + struct Result { + xla::PrimitiveType type; // XLA type + // Expected output shape, for array types + absl::InlinedVector expected_dims; + // Expected output byte strides, for array types. If the strides do not + // match, the output will be transposed into the expected layout. + std::vector<int64_t> expected_strides; + // The desired order of output dimensions in major-to-minor order. + absl::InlinedVector reversed_layout; + // Size of the array in bytes. + size_t size_in_bytes; + }; + + explicit CpuCallback(nanobind::callable callable, std::vector<Arg> args, + std::vector<Result> results) + : callable_(std::move(callable)), + args_(std::move(args)), + results_(std::move(results)), + transpose_cache_(/*capacity=*/16) {} + + ~CpuCallback(); + + const std::vector<Arg>& args() const { return args_; } + size_t num_args() const { return args_.size(); } + + const std::vector<Result>& results() const { return results_; } + size_t num_results() const { return results_.size(); } + void* callback() const { return callable_.ptr(); } + + xla::TransposePlanCache& transpose_cache() { return transpose_cache_; } + + absl::Status PrepareAndCall(void* result, void** arg_ptrs); + + absl::StatusOr<nanobind::tuple> Call(nanobind::tuple args); + + private: + nanobind::callable callable_; + std::vector<Arg> args_; + std::vector<Result> results_; + xla::TransposePlanCache transpose_cache_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_CALLBACK_H_ diff --git a/jaxlib/xla/config.cc b/jaxlib/xla/config.cc new file mode 100644 index 000000000000..8804b783eb72 --- /dev/null +++ b/jaxlib/xla/config.cc @@ -0,0 +1,348 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/config.h" + +#include + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "jaxlib/xla/python_ref_manager.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +namespace nb = nanobind; + +// Singleton object used to represent "value not set" in thread-local configs. +nb::object UnsetObject() { + return nb::steal(PyObject_CallObject( + reinterpret_cast<PyObject*>(&PyBaseObject_Type), nullptr)); +} + +// Each configuration object has: +// * a global value, and +// * a thread-local value. +// When querying the state of a config, the thread-local value is used if it is +// set. Otherwise, the global value is used. + +// This class represents all of the thread-local configuration state for a +// thread. +class ThreadLocalConfigState { + public: + ThreadLocalConfigState(); + ~ThreadLocalConfigState(); + + static ThreadLocalConfigState& Instance() { + thread_local auto state = std::make_unique<ThreadLocalConfigState>(); + return *state; + } + + nb::object Get(int key) { + DCHECK_GE(key, 0); + return key >= entries_.size() ? 
nb::object() : entries_[key]; + } + + void Set(int key, nb::object value); + + private: + friend class GlobalConfigState; + + // These values are accessed in one of two ways: + // * The owning thread reads or writes them, while holding the GIL, or, under + // free-threading, while the owning thread is in ATTACHED gc state. + // * Other threads may read or clear values while performing a garbage + // collection. + // No locking is needed because a GC thread cannot run concurrently with other + // Python threads; even under free-threading Python uses a stop-the-world GC. + std::vector<nb::object> entries_; +}; + +// This class represents all of the global configuration state. +// TODO(phawkins): to support free-threading, we will need to add locking to +// this class. +class GlobalConfigState { + public: + static GlobalConfigState& Instance() { + static auto state = new GlobalConfigState(); + return *state; + } + + nb::object Get(int key) const; + void Set(int key, nb::object value); + + // Adds or removes a thread-local state from the set of thread-local states. + void AddThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(&mu_); + thread_local_states_.insert(state); + } + void RemoveThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(&mu_); + thread_local_states_.erase(state); + } + + // Python GC helpers. These are called from the tp_traverse and tp_clear + // methods of the Config class. + int tp_traverse(int key, PyObject* self, visitproc visit, void* arg); + int tp_clear(int key, PyObject* self); + + // Returns the singleton object representing "value not set". + const nb::object& unset() const { return unset_; } + + // Returns the set of keys that should be included in the jit key. + absl::Span<const int> include_in_jit_key() const { + return include_in_jit_key_; + } + + private: + friend class Config; + + // The set of thread-local states. This is used during garbage collection to + // visit thread-local values. + absl::Mutex mu_; + absl::flat_hash_set<ThreadLocalConfigState*> thread_local_states_ + ABSL_GUARDED_BY(mu_); + std::vector<nb::object> entries_; + std::vector<int> include_in_jit_key_; + nb::object unset_ = UnsetObject(); +}; + +ThreadLocalConfigState::ThreadLocalConfigState() { + GlobalConfigState::Instance().AddThreadLocalState(this); +} + +ThreadLocalConfigState::~ThreadLocalConfigState() { + // It's important that we remove the thread-local state before we access + // entries_. This ensures that accesses to entries_ are ordered with respect + // to any garbage collection. + // We do not hold the GIL, so we must use deferred destruction. 
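+ // GlobalPyRefManager only records the references here; the Python objects are decref'd later, from a thread that does hold the GIL.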
+ xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(entries_)); +} + +void ThreadLocalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + if (key >= entries_.size()) { + entries_.resize(key + 1); + } + std::swap(entries_[key], value); +} + +nb::object GlobalConfigState::Get(int key) const { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + return entries_[key]; +} + +void GlobalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + std::swap(entries_[key], value); +} + +int GlobalConfigState::tp_traverse(int key, PyObject* self, visitproc visit, + void* arg) { + DCHECK_GE(key, 0); + if (key < entries_.size()) { + PyObject* value = entries_[key].ptr(); + Py_VISIT(value); + } + absl::MutexLock lock(&mu_); + for (const auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + PyObject* value = state->entries_[key].ptr(); + Py_VISIT(value); + } + } + return 0; +} + +int GlobalConfigState::tp_clear(int key, PyObject* self) { + if (key < entries_.size()) { + nb::object tmp; + std::swap(entries_[key], tmp); + } + // We destroy the python objects outside of the lock out of an abundance of + // caution. + std::vector to_destroy; + absl::MutexLock lock(&mu_); + to_destroy.reserve(thread_local_states_.size()); + for (auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + nb::object tmp; + std::swap(state->entries_[key], tmp); + to_destroy.push_back(std::move(tmp)); + } + } + return 0; +} + +// A Config object represents a configurable object with both global and +// thread-local state. This class is wrapped using nanobind and exposed to +// Python. +class Config { + public: + Config(nb::object value, bool include_in_jit_key); + + // Returns the thread-local value if it is set, otherwise the global value. + nb::object Get(); + + // Returns the global value. + nb::object GetGlobal(); + + // Sets the global value. + void SetGlobal(nb::object value); + + // Returns the thread-local value. + nb::object GetLocal(); + + // Sets the thread-local value. May be `unset`. + void SetLocal(nb::object value); + + // Swaps the thread-local value with `value`. Returns the previous value. + // Either may be `unset`. + nb::object SwapLocal(nb::object value); + + // This class doesn't actually hold any data, but it's the only type + // known to Python. We pretend that this object owns both the global and any + // thread-local values corresponding to this key. 
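+ // Claiming that ownership lets CPython's cyclic garbage collector discover reference cycles that pass through config values, via the tp_traverse/tp_clear hooks below.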
+ static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + private: + int key_; +}; + +Config::Config(nb::object value, bool include_in_jit_key) { + auto& instance = GlobalConfigState::Instance(); + key_ = instance.entries_.size(); + instance.entries_.push_back(std::move(value)); + if (include_in_jit_key) { + instance.include_in_jit_key_.push_back(key_); + } +} + +nb::object Config::GetLocal() { + nb::object result = ThreadLocalConfigState::Instance().Get(key_); + if (!result.is_valid()) { + return GlobalConfigState::Instance().unset(); + } + return result; +} + +nb::object Config::GetGlobal() { + return GlobalConfigState::Instance().Get(key_); +} + +nb::object Config::Get() { + nb::object local = ThreadLocalConfigState::Instance().Get(key_); + if (local.is_valid()) { + return local; + } + return GetGlobal(); +} + +void Config::SetLocal(nb::object value) { + const auto& instance = GlobalConfigState::Instance(); + if (value.ptr() == instance.unset().ptr()) { + value = nb::object(); + } + ThreadLocalConfigState::Instance().Set(key_, std::move(value)); +} + +nb::object Config::SwapLocal(nb::object value) { + const auto& global_instance = GlobalConfigState::Instance(); + auto& instance = ThreadLocalConfigState::Instance(); + auto result = instance.Get(key_); + if (value.ptr() == global_instance.unset().ptr()) { + value = nb::object(); + } + instance.Set(key_, std::move(value)); + if (!result.is_valid()) { + return global_instance.unset(); + } + return result; +} + +void Config::SetGlobal(nb::object value) { + GlobalConfigState::Instance().Set(key_, value); +} + +/* static */ int Config::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + Config* c = nb::inst_ptr(self); + // For the purposes of GC, we pretend that this object owns both the global + // and any thread-local values corresponding to this key. 
+ return GlobalConfigState::Instance().tp_traverse(c->key_, self, visit, arg); +} + +/* static */ int Config::tp_clear(PyObject* self) { + Config* c = nb::inst_ptr(self); + return GlobalConfigState::Instance().tp_clear(c->key_, self); +} + +PyType_Slot Config::slots_[] = { + {Py_tp_traverse, reinterpret_cast(Config::tp_traverse)}, + {Py_tp_clear, reinterpret_cast(Config::tp_clear)}, + {0, nullptr}, +}; + +void BuildConfigSubmodule(nanobind::module_& m) { + nb::module_ config_module = m.def_submodule("config", "Config library"); + + config_module.attr("unset") = GlobalConfigState::Instance().unset(); + + nb::class_ config(config_module, "Config", + nb::type_slots(Config::slots_), nb::is_generic()); + config.def(nb::init(), nb::arg("value").none(), + nb::arg("include_in_jit_key") = false); + config.def_prop_ro("value", &Config::Get); + config.def("get_local", &Config::GetLocal); + config.def("get_global", &Config::GetGlobal); + config.def("set_local", &Config::SetLocal, nb::arg("value").none()); + config.def("swap_local", &Config::SwapLocal, nb::arg("value").none()); + config.def("set_global", &Config::SetGlobal, nb::arg("value").none()); +} + +std::vector JitConfigs() { + auto& instance = GlobalConfigState::Instance(); + auto& thread_local_instance = ThreadLocalConfigState::Instance(); + std::vector result; + result.reserve(instance.include_in_jit_key().size()); + for (int i : instance.include_in_jit_key()) { + nb::object local = thread_local_instance.Get(i); + if (local.is_valid()) { + result.push_back(std::move(local)); + } else { + result.push_back(instance.Get(i)); + } + } + return result; +} + +} // namespace jax diff --git a/jaxlib/xla/config.h b/jaxlib/xla/config.h new file mode 100644 index 000000000000..2a9281f498b4 --- /dev/null +++ b/jaxlib/xla/config.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_CONFIG_H_ +#define JAXLIB_XLA_CONFIG_H_ + +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +// Returns the set of configuration values that should be included in the JIT +// cache key. +std::vector JitConfigs(); + +void BuildConfigSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_CONFIG_H_ diff --git a/jaxlib/xla/config_test.py b/jaxlib/xla/config_test.py new file mode 100644 index 000000000000..8701a37acd1d --- /dev/null +++ b/jaxlib/xla/config_test.py @@ -0,0 +1,71 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import threading + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + +config = xla_client._xla.config + + +class ConfigTest(absltest.TestCase): + + def testBasic(self): + c = config.Config(1) + self.assertEqual(c.value, 1) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.get_local(), config.unset) + + c.set_global(2) + self.assertEqual(c.value, 2) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), config.unset) + + c.set_local(3) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), 3) + + c.set_global(4) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), 3) + + c.set_local(config.unset) + self.assertEqual(c.value, 4) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), config.unset) + + def testThreading(self): + c = config.Config(1) + + def Body(): + for i in range(100): + c.set_local(i) + self.assertEqual(c.get_local(), i) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.value, i) + + threads = [threading.Thread(target=Body) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/custom_call_sharding.cc b/jaxlib/xla/custom_call_sharding.cc new file mode 100644 index 000000000000..00accd85aefd --- /dev/null +++ b/jaxlib/xla/custom_call_sharding.cc @@ -0,0 +1,346 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/custom_call_sharding.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/custom_call_batch_partitioner.h" +#include "xla/python/custom_partition_callback.h" +#include "xla/python/inspect_sharding.h" +#include "xla/shape.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla { + +namespace nb = ::nanobind; + +class PyCustomCallPartitionerCallbacks { + public: + PyCustomCallPartitionerCallbacks(nb::object prop_user_sharding, + nb::object partition, + nb::object infer_sharding_from_operands) + : prop_user_sharding_(prop_user_sharding), + partition_(partition), + infer_sharding_from_operands_(infer_sharding_from_operands) { + callbacks_.version = 0; + callbacks_.private_data = this; + callbacks_.dtor = +[](JAX_CustomCallPartitioner_Callbacks* self) { + delete GetSelfPtr(self); + }; + callbacks_.partition = +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_Partition_Args* args) { + jax::PopulateResults(GetSelfPtr(self)->CallPartition(args), args); + }; + callbacks_.infer_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallInferShardingFromOperands(args), args); + }; + callbacks_.propagate_user_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallPropagateUserSharding(args), args); + }; + } + + absl::StatusOr< + std::tuple, xla::HloSharding>> + CallPartition(JAX_CustomCallPartitioner_Partition_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector shapes = std::move(std::get<0>(args_tuple)); + std::vector> shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + std::optional result_sharding = + std::move(std::get<3>(args_tuple)); + absl::string_view backend_config = std::move(std::get<4>(args_tuple)); + + { + nb::gil_scoped_acquire gil; + try { + auto py_result = + partition_(shapes, shardings, result_shape, result_sharding, + nb::bytes(backend_config.data(), backend_config.size())); + try { + auto [ir, arg_shardings, result_sharding] = nb::cast< + std::tuple, HloSharding>>( + py_result); + if (arg_shardings.size() != args->num_args) { + return xla::Internal( + "Shardings returned from partitioning: lengths must match: %d " + "vs %d", + arg_shardings.size(), args->num_args); + } + return std::make_tuple(std::string(ir.c_str(), 
ir.size()), + std::move(arg_shardings), + std::move(result_sharding)); + } catch (const nb::cast_error& e) { + return xla::Internal( + "Shardings returned from partitioning: expected " + "Tuple[bytes, List[HloSharding], HloSharding] got: %s", + nb::cast(nb::repr(py_result))); + } + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + } + + absl::StatusOr> CallInferShardingFromOperands( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector arg_shapes = std::move(std::get<0>(args_tuple)); + std::vector> arg_shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + absl::string_view backend_config = std::move(std::get<3>(args_tuple)); + + std::optional result; + nb::gil_scoped_acquire gil; + try { + auto py_result = infer_sharding_from_operands_( + arg_shapes, arg_shardings, result_shape, + nb::bytes(backend_config.data(), backend_config.size())); + if (py_result.is_none()) { + return std::nullopt; + } + return nb::cast(py_result); + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + absl::StatusOr CallPropagateUserSharding( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple)); + xla::Shape result_shape = std::move(std::get<1>(args_tuple)); + absl::string_view backend_config = std::move(std::get<2>(args_tuple)); + + nb::gil_scoped_acquire gil; + try { + // TODO(parkers): expand this API to handle the `user` sharding. + // The user is used when the custom call returns a Tuple and + // the user is a get-tuple-element. In this case we must update only + // part of the sharding spec. 
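+ // The Python callback receives the result sharding, result shape, and backend_config, and returns the sharding to use for the custom-call instruction itself.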
+ auto result = nb::cast(prop_user_sharding_( + result_sharding, result_shape, + nb::bytes(backend_config.data(), backend_config.size()))); + return result; + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + JAX_CustomCallPartitioner_Callbacks* callbacks() { return &callbacks_; } + + private: + static PyCustomCallPartitionerCallbacks* GetSelfPtr( + JAX_CustomCallPartitioner_Callbacks* callbacks) { + return reinterpret_cast( + callbacks->private_data); + } + + JAX_CustomCallPartitioner_Callbacks callbacks_; + nb::object prop_user_sharding_; + nb::object partition_; + nb::object infer_sharding_from_operands_; +}; + +namespace { + +void CallInspectSharding(void* obj, JAX_InspectSharding_Callback_Args* args) { + std::optional arg = jax::InspectShardingReadArgs(args); + if (!arg.has_value()) { + return; + } + try { + nb::gil_scoped_acquire gil; + nb::handle(reinterpret_cast(obj))(*std::move(arg)); + } catch (const nb::python_error& e) { + jax::InspectShardingSetError(args, std::string(e.what())); + } +} + +} // namespace + +void BuildCustomCallShardingPybindAPI(nb::module_& m) { + m.def( + "register_custom_call_partitioner", + [](std::string name, nb::object prop_user_sharding, nb::object partition, + nb::object infer_sharding_from_operands, + bool can_side_effecting_have_replicated_sharding, + std::optional c_api) { + auto* c_fns = + (new PyCustomCallPartitionerCallbacks(prop_user_sharding, partition, + infer_sharding_from_operands)) + ->callbacks(); + c_fns->can_side_effecting_have_replicated_sharding = + can_side_effecting_have_replicated_sharding; + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + name, jax::CreateCApiCustomCallPartitioner(c_fns)); + return; + } + + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Custom_Partitioner_Args args; + args.struct_size = PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE; + args.name = name.c_str(); + args.name_size = name.size(); + args.callbacks = c_fns; + PJRT_Error* error = + reinterpret_cast( + extension) + ->register_custom_partitioner(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a partitioner for a custom-call operation. + +Args: + name: custom_call_target to match. + prop_user_sharding: Custom backwards sharding propagation rule. + Takes result sharding and returns the instruction sharding. + partition: Lowering rule. Takes operand and result shardings and returns + a generated HLO and sharding specs. The spmd lowerer first reshards + to match the returned sharding specs and then inserts the generated hlo. + infer_sharding_from_operands: Custom forwards sharding propagation rule. + Takes operand sharding and returns the instruction sharding. + can_side_effecting_have_replicated_sharding: Side effecting ops are not + allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. 
This is safe to + call on plugins that do not implement the custom partitioner extension +)", + nb::arg("name"), nb::arg("prop_user_sharding"), nb::arg("partition"), + nb::arg("infer_sharding_from_operands"), + nb::arg("can_side_effecting_have_replicated_sharding") = false, + nb::arg("c_api").none() = std::nullopt); + m.def("encode_inspect_sharding_callback", + [](nb::object handler) -> nb::bytes { + JAX_InspectSharding_Callback cb; + cb.call = &CallInspectSharding; + cb.data = handler.ptr(); + char bytes[sizeof(JAX_InspectSharding_Callback)]; + std::memcpy(&bytes, &cb, sizeof(JAX_InspectSharding_Callback)); + return nb::bytes(bytes, sizeof(JAX_InspectSharding_Callback)); + }); + + nb::module_ hlo_sharding_util_m = m.def_submodule( + "hlo_sharding_util", "Utilities for manipulating HloSharding."); + hlo_sharding_util_m.def( + "PartiallyReplicateTiledShardingOnDims", + [](const HloSharding& sharding, std::vector dims) { + return hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding, dims); + }); + + m.def( + "register_custom_call_as_batch_partitionable", + [](std::string target_name, std::optional c_api) { + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + target_name, std::make_unique()); + return; + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Batch_Partitionable_Args args; + args.struct_size = PJRT_Register_Batch_Partitionable_Args_STRUCT_SIZE; + args.name = target_name.c_str(); + args.name_size = target_name.size(); + PJRT_Error* error = extension->register_batch_partitionable(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a custom call as batch partitionable. + +If a custom call is "batch partitionable", it means that it can be trivially +partitioned on some number of (leading) dimensions, with the same call being +executed independently on each shard of data. If the data are sharded on +non-batch dimensions, partitioning will re-shard the data to be replicated on +the non-batch dimensions. + +Args: + target_name: the target name of the batch partitionable custom call. + c_api: optional `PJRT_Api*` to support registration via a PJRT plugin. +)", + nb::arg("target_name"), nb::arg("c_api").none() = std::nullopt); +} + +} // namespace xla diff --git a/jaxlib/xla/custom_call_sharding.h b/jaxlib/xla/custom_call_sharding.h new file mode 100644 index 000000000000..5a5f3776cc30 --- /dev/null +++ b/jaxlib/xla/custom_call_sharding.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ +#define JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildCustomCallShardingPybindAPI(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ diff --git a/jaxlib/xla/custom_calls_testlib.cc b/jaxlib/xla/custom_calls_testlib.cc new file mode 100644 index 000000000000..58f4818a431e --- /dev/null +++ b/jaxlib/xla/custom_calls_testlib.cc @@ -0,0 +1,128 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace xla::ffi { +namespace nb = ::nanobind; + +// Implement custom calls as static functions with XLA FFI types in the function +// signature that gives access to the arguments and results buffers together +// with their types and dimensions. See `ffi/api/ffi_test.cc` for more XLA FFI +// examples and features (e.g. binding attributes, custom user-defined structs +// and arbitrary execution context). + +static Error AlwaysFail(Result) { + return Error(XLA_FFI_Error_Code_INTERNAL, "Failed intentionally"); +} + +static Error AlwaysSucceed(Result) { return Error::Success(); } + +static Error Subtract(BufferR0 a, BufferR0 b, + Result> out) { + *out->typed_data() = *a.typed_data() - *b.typed_data(); + return Error::Success(); +} + +static Error SubtractCst(BufferR0 a, + Result> out, float cst) { + *out->typed_data() = *a.typed_data() - cst; + return Error::Success(); +} + +// Define XLA FFI handlers from the implementations defined above using explicit +// XLA FFI binding API to describe type signatures of custom calls. + +XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, Ffi::Bind().Ret()); + +XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed, + Ffi::Bind().Ret()); + +XLA_FFI_DEFINE_HANDLER(kSubtract, Subtract, + Ffi::Bind() + .Arg>() + .Arg>() + .Ret>()); + +XLA_FFI_DEFINE_HANDLER(kSubtractCst, SubtractCst, + Ffi::Bind() + .Arg>() + .Ret>() + .Attr("cst")); + +// XLA FFI calls can also be stateful. 
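+// The instantiate handler below creates the per-instance state object once; the execute handler then receives that same object back through its context argument, so the value written to the output buffer is the 42 stored at instantiation time.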
+struct TestFfiState { + static TypeId id; + explicit TestFfiState(int32_t value) : value(value) {} + int32_t value; +}; +TypeId TestFfiState::id = {}; + +static ErrorOr> StateInstantiate() { + return std::make_unique(42); +} + +static Error StateExecute(TestFfiState* state, + Result> out) { + *out->typed_data() = state->value; + return Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate, + Ffi::BindInstantiate()); +XLA_FFI_DEFINE_HANDLER( + kStateExecute, StateExecute, + Ffi::Bind().Ctx>().Ret>()); + +template +static auto BindFunction(T* fn) { + return nb::capsule(reinterpret_cast(fn)); +} + +template +static auto BindTypeId(T* typeId) { + return nb::capsule(reinterpret_cast(typeId)); +} + +// Custom calls registration library that exports function pointers to XLA FFI +// handlers to the python users. +NB_MODULE(custom_calls_testlib, m) { + m.def("registrations", []() { + nb::dict dict; + dict["always_fail"] = BindFunction(kAlwaysFail); + dict["always_succeed"] = BindFunction(kAlwaysSucceed); + dict["subtract_f32"] = BindFunction(kSubtract); + dict["subtract_f32_cst"] = BindFunction(kSubtractCst); + + nb::dict bundle; + bundle["instantiate"] = BindFunction(kStateInstantiate); + bundle["execute"] = BindFunction(kStateExecute); + dict["stateful"] = bundle; + + return dict; + }); + m.def("type_ids", []() { + nb::dict type_ids; + type_ids["test_ffi_state"] = BindTypeId(&TestFfiState::id); + return type_ids; + }); +} + +} // namespace xla::ffi diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc new file mode 100644 index 000000000000..c8d02e679036 --- /dev/null +++ b/jaxlib/xla/dlpack.cc @@ -0,0 +1,700 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/dlpack.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/traceback.h" +#include "jaxlib/xla/util.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +const char* const kDlTensorCapsuleName = "dltensor"; + +struct DLPackTensor { + ~DLPackTensor(); + + // `buffer_reference` is populated if we have shared (read-only) access. + nb::object buffer_reference; + + // `external_reference` is always populated. + std::unique_ptr external_reference; + + std::vector shape; + std::vector strides; + DLManagedTensor tensor; +}; + +DLPackTensor::~DLPackTensor() { + if (buffer_reference) { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(&buffer_reference, /*size=*/1)); + } +} + +void DLPackTensorDeleter(DLManagedTensor* t) { + if (t) { + delete static_cast(t->manager_ctx); + } +} + +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { + switch (type) { + case S8: + return DLDataType{kDLInt, 8, 1}; + case S16: + return DLDataType{kDLInt, 16, 1}; + case S32: + return DLDataType{kDLInt, 32, 1}; + case S64: + return DLDataType{kDLInt, 64, 1}; + case U8: + return DLDataType{kDLUInt, 8, 1}; + case U16: + return DLDataType{kDLUInt, 16, 1}; + case U32: + return DLDataType{kDLUInt, 32, 1}; + case U64: + return DLDataType{kDLUInt, 64, 1}; + case F4E2M1FN: + return DLDataType{kDLFloat4_e2m1fn, 4, 1}; + case F8E3M4: + return DLDataType{kDLFloat8_e3m4, 8, 1}; + case F8E4M3: + return DLDataType{kDLFloat8_e4m3, 8, 1}; + case F8E4M3B11FNUZ: + return DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}; + case F8E4M3FN: + return DLDataType{kDLFloat8_e4m3fn, 8, 1}; + case F8E4M3FNUZ: + return DLDataType{kDLFloat8_e4m3fnuz, 8, 1}; + case F8E5M2: + return DLDataType{kDLFloat8_e5m2, 8, 1}; + case F8E5M2FNUZ: + return DLDataType{kDLFloat8_e5m2fnuz, 8, 1}; + case F8E8M0FNU: + return DLDataType{kDLFloat8_e8m0fnu, 8, 1}; + case BF16: + return DLDataType{kDLBfloat, 16, 1}; + case F16: + return DLDataType{kDLFloat, 16, 1}; + case F32: + return DLDataType{kDLFloat, 32, 1}; + case F64: + return DLDataType{kDLFloat, 64, 1}; + case PRED: + return DLDataType{kDLBool, 8, 1}; + case C64: + return DLDataType{kDLComplex, 64, 1}; + case C128: + return DLDataType{kDLComplex, 128, 1}; + default: + return Unimplemented("XLA type %s has no DLPack 
equivalent", + PrimitiveType_Name(type)); + } +} + +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { + if (type.lanes != 1) { + return Unimplemented("DLPack types with lanes != 1 not implemented, got %d", + type.lanes); + } + switch (type.code) { + case kDLBool: + switch (type.bits) { + case 8: + return PRED; + default: + return Unimplemented( + "Only 8-bit DLPack booleans are supported, got %d bits", + type.bits); + } + case kDLInt: + switch (type.bits) { + case 8: + return S8; + case 16: + return S16; + case 32: + return S32; + case 64: + return S64; + default: + return Unimplemented( + "Invalid or unsupported DLPack integer width: %d bits", + type.bits); + } + case kDLUInt: + switch (type.bits) { + case 8: + return U8; + case 16: + return U16; + case 32: + return U32; + case 64: + return U64; + default: + return Unimplemented( + "Invalid or unsupported DLPack unsigned integer width: %d bits", + type.bits); + } + case kDLFloat4_e2m1fn: + if (type.bits == 4) { + return F4E2M1FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float4_e2m1fn width: %d bits", + type.bits); + case kDLFloat8_e3m4: + if (type.bits == 8) { + return F8E3M4; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e3m4 width: %d bits", + type.bits); + case kDLFloat8_e4m3: + if (type.bits == 8) { + return F8E4M3; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3 width: %d bits", + type.bits); + case kDLFloat8_e4m3b11fnuz: + if (type.bits == 8) { + return F8E4M3B11FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3b11fnuz width: %d bits", + type.bits); + case kDLFloat8_e4m3fn: + if (type.bits == 8) { + return F8E4M3FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fn width: %d bits", + type.bits); + case kDLFloat8_e4m3fnuz: + if (type.bits == 8) { + return F8E4M3FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fnuz width: %d bits", + type.bits); + case kDLFloat8_e5m2: + if (type.bits == 8) { + return F8E5M2; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2 width: %d bits", + type.bits); + case kDLFloat8_e5m2fnuz: + if (type.bits == 8) { + return F8E5M2FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2fnuz width: %d bits", + type.bits); + case kDLFloat8_e8m0fnu: + if (type.bits == 8) { + return F8E8M0FNU; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e8m0fnu width: %d bits", + type.bits); + case kDLBfloat: + if (type.bits == 16) { + return BF16; + } + return Unimplemented( + "Invalid or unsupported DLPack bfloat width: %d bits", type.bits); + case kDLFloat: + switch (type.bits) { + case 16: + return F16; + case 32: + return F32; + case 64: + return F64; + default: + return Unimplemented( + "Invalid or unsupported DLPack float width: %d bits", type.bits); + } + case kDLComplex: + switch (type.bits) { + case 64: + return C64; + case 128: + return C128; + default: + return Unimplemented( + "Invalid or unsupported DLPack complex width: %d bits", + type.bits); + } + default: + return Unimplemented("Unknown or invalid DLPack type code %d", type.code); + } +} + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + CHECK_EQ(dims.size(), strides.size()); + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides[a] < strides[b]) { + return true; + } + if (strides[a] > strides[b]) { + 
return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +absl::StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) { + if (device.client()->platform_id() == CpuId()) { + return kDLCPU; + } else if (device.client()->platform_id() == CudaId()) { + return kDLCUDA; + } else if (device.client()->platform_id() == RocmId()) { + return kDLROCM; + } + return InvalidArgument("Device %s cannot be used as a DLPack device.", + device.DebugString()); +} + +absl::StatusOr<DLDevice> DLDeviceForDevice(const PjRtDevice& device) { + DLDevice context; + TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); + context.device_id = device.local_hardware_id().value(); + return context; +} + +absl::StatusOr<PjRtDevice*> DeviceForDLDevice(const PjRtClient* cpu_client, + const PjRtClient* gpu_client, + const DLDevice& context) { + switch (context.device_type) { + case kDLCPU: + if (cpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on CPU, but no CPU backend was provided."); + } + TF_RET_CHECK(cpu_client->platform_id() == CpuId()); + return cpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLCUDA: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == CudaId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLROCM: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == RocmId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + default: + return InvalidArgument("Unknown/unsupported DLPack device type %d", + context.device_type); + } +} + +absl::Status VerifyDType(const DLTensor& dl_tensor) { + if (dl_tensor.dtype.bits % 8 != 0) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: bits should be a multiple of 8, got " + "%d", + dl_tensor.dtype.bits); + } + + if (dl_tensor.dtype.lanes != 1) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: lanes should be equal to 1, got %d", + dl_tensor.dtype.lanes); + } + + return absl::OkStatus(); +} + +absl::StatusOr<std::vector<int64_t>> GetByteStrides(const DLTensor& dl_tensor) { + TF_RETURN_IF_ERROR(VerifyDType(dl_tensor)); + + // Convert element strides from the number of elements to the number of bytes.
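As a reading aid for the stride handling in StridesToLayout and GetByteStrides above, here is a minimal Python sketch of the same mapping: element strides are sorted into a minor-to-major permutation (ties prefer the major-to-minor reading), non-compact striding is rejected, and byte strides are element strides times the element size. NumPy is used purely for illustration and the helper name strides_to_minor_to_major is invented for this sketch.

import numpy as np

def strides_to_minor_to_major(dims, strides):
    # Sort dimensions by increasing element stride; on equal strides prefer
    # the major-to-minor ordering, mirroring the comparator above.
    order = sorted(range(len(dims)), key=lambda d: (strides[d], -d))
    # Reject non-compact striding, as the loop above does.
    expected = 1
    for d in order:
        if dims[d] > 1 and strides[d] != expected:
            raise ValueError("only compact (transposed) striding is supported")
        expected *= dims[d]
    return order

x = np.zeros((2, 3, 4), dtype=np.float32)                       # C-contiguous
elem_strides = [s // x.itemsize for s in x.strides]             # [12, 4, 1]
print(strides_to_minor_to_major(list(x.shape), elem_strides))   # [2, 1, 0]
byte_strides = [s * x.itemsize for s in elem_strides]           # what GetByteStrides computes (bits / 8)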
+ std::vector strides; + strides.reserve(dl_tensor.ndim); + for (int i = 0; i < dl_tensor.ndim; ++i) { + strides.push_back(dl_tensor.strides[i] * dl_tensor.dtype.bits / 8); + } + return strides; +} + +absl::StatusOr> MakePjrtBuffer( + PjRtDevice& device, ::DLManagedTensor* dlmt, const Shape& shape, + PrimitiveType element_type, absl::Span dimensions, + std::optional stream = std::nullopt) { + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + // First try to create a view. + void* data = + static_cast(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset; + auto result = device.client()->CreateViewOfDeviceBuffer( + data, shape, *device.default_memory_space(), on_delete_callback, stream); + + // If that fails with invalid argument, it's possibly because of the incorrect + // alignment. If we're on CPU, we can create a copy of buffer. + if (result.status().code() == absl::StatusCode::kInvalidArgument && + dlmt->dl_tensor.device.device_type == kDLCPU) { + LOG(WARNING) << "DLPack buffer is not aligned (data at: " << data + << "). Creating a copy."; + + // Convert tensor strides (expressed in number of elements) to byte strides. + std::optional> byte_strides; + if (dlmt->dl_tensor.strides) { + TF_ASSIGN_OR_RETURN(byte_strides, GetByteStrides(dlmt->dl_tensor)); + } + + TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space()); + + // Create a copy. + result = device.client()->BufferFromHostBuffer( + data, element_type, dimensions, byte_strides, + PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback, + memory_space, /*device_layout=*/nullptr); + } + return result; +} + +} // namespace + +absl::StatusOr BufferToDLPackManagedTensor( + nb::handle py_buffer, std::optional stream) { + ifrt::Array* ifrt_array = nb::cast(py_buffer).ifrt_array(); + if (ifrt_array == nullptr) { + return Unimplemented( + "BufferToDLPackManagedTensor called on deleted array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + PjRtBuffer* pjrt_buffer = arr->pjrt_buffers().front().get(); + + if (pjrt_buffer->IsTuple()) { + return Unimplemented( + "BufferToDLPackManagedTensor is not implemented for tuple " + "buffers."); + } + if (pjrt_buffer->has_dynamic_dimensions()) { + return Unimplemented("DynamicShape is not implemented in DLPack."); + } + + auto pack = std::make_unique(); + DLTensor& dt = pack->tensor.dl_tensor; + { + // AcquireExternalReference may block; there are no API guarantees. 
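For orientation, the export/import pair implemented here is what the Python-level DLPack helpers drive. A rough usage sketch follows; it assumes the jax.dlpack.to_dlpack/from_dlpack helpers are available in the installed jax version (to_dlpack has been deprecated in favour of the __dlpack__ protocol), and the exact exception type raised on double consumption is backend-dependent.

import jax.numpy as jnp
import jax.dlpack

x = jnp.arange(4.0)
capsule = jax.dlpack.to_dlpack(x)       # exports a "dltensor" capsule
y = jax.dlpack.from_dlpack(capsule)     # imports it; the capsule is renamed "used_dltensor"
try:
    jax.dlpack.from_dlpack(capsule)     # a DLPack tensor may be consumed at most once
except Exception as e:
    print(e)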
+ GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(pack->external_reference, + pjrt_buffer->AcquireExternalReference()); + if (stream) { + TF_RETURN_IF_ERROR( + pack->external_reference->WaitUntilBufferReadyOnStream(*stream)); + } else { + TF_RETURN_IF_ERROR( + AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1))); + } + } + pack->buffer_reference = nb::borrow(py_buffer); + + dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + TF_ASSIGN_OR_RETURN(dt.device, DLDeviceForDevice(*pjrt_buffer->device())); + dt.device.device_id = pjrt_buffer->device()->local_hardware_id().value(); + dt.ndim = pjrt_buffer->dimensions().size(); + TF_ASSIGN_OR_RETURN(dt.dtype, + PrimitiveTypeToDLDataType(pjrt_buffer->element_type())); + + pack->shape = std::vector(pjrt_buffer->dimensions().begin(), + pjrt_buffer->dimensions().end()); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + pack->strides = StridesForShape(pjrt_buffer->element_type(), + pjrt_buffer->dimensions(), xla_layout); + + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = reinterpret_cast(pack->strides.data()); + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensor* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + return capsule; +} + +absl::StatusOr DLPackManagedTensorToBuffer( + const nb::capsule& tensor, std::optional> cpu_client, + std::optional> gpu_client) { + // TODO(hyeontaek): This is a potential target for an IFRT client to multiplex + // multiple PjRt clients. Devices from these PjRt clients could be expressed + // as a unified set of IFRT devices. + auto* cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr; + auto* gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr; + + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { + return InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(tensor.name())); + } + DLManagedTensor* dlmt = static_cast(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + TF_ASSIGN_OR_RETURN(PjRtDevice * device, + DeviceForDLDevice(cpu_client ? cpu_pjrt_client : nullptr, + gpu_client ? 
gpu_pjrt_client : nullptr, + dlmt->dl_tensor.device)); + absl::Span dimensions( + reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + reinterpret_cast(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + // Raise an error if the resulting PjRtBuffer would have a non-default layout. + // TODO(skyewm): we do this because JAX doesn't currently have good support + // for non-default layouts, and will return wrong results if a non-default + // layout is passed to a computation expecting default layouts. Remove this + // special case when non-default layouts are better supported by JAX. + TF_ASSIGN_OR_RETURN(Layout default_layout, device->client()->GetDefaultLayout( + element_type, dimensions)); + if (shape.layout() != default_layout) { + return Unimplemented( + "from_dlpack got array with non-default layout with minor-to-major " + "dimensions (%s), expected (%s)", + absl::StrJoin(shape.layout().minor_to_major(), ","), + absl::StrJoin(default_layout.minor_to_major(), ",")); + } + + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + MakePjrtBuffer(*device, dlmt, shape, element_type, dimensions)); + + // We have taken ownership of the array inside the capsule; make sure the + // capsule it cannot be used again. + PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + // TODO(phawkins): simplify the expression below once we know cpu_client is + // always non-null. + auto client = (cpu_client && device->client() == cpu_pjrt_client) + ? std::move(*cpu_client) + : std::move(*gpu_client); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::StatusOr DLPackManagedTensorToBuffer( + const nb::capsule& tensor, ifrt::Device* ifrt_device, + nb_class_ptr client, std::optional stream) { + ifrt::PjRtDevice* device = + llvm::dyn_cast_or_null(ifrt_device); + if (device == nullptr) { + throw XlaRuntimeError( + "DLPack is supported for PjRt-compatible backends only."); + } + if (!device->IsAddressable()) { + throw XlaRuntimeError( + "DLPack is only supported for devices addressable by the current " + "process."); + } + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { + return InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". 
" + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(tensor.name())); + } + DLManagedTensor* dlmt = static_cast(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + absl::Span dimensions( + reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + reinterpret_cast(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + TF_ASSIGN_OR_RETURN(auto pjrt_buffer, + MakePjrtBuffer(*device->pjrt_device(), dlmt, shape, + element_type, dimensions, stream)); + + // We have taken ownership of the array inside the capsule; make sure the + // capsule it cannot be used again. + PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type) { + TF_ASSIGN_OR_RETURN(DLDataType dl_type, PrimitiveTypeToDLDataType(type)); + + nanobind::dlpack::dtype nb_type; + nb_type.lanes = dl_type.lanes; + nb_type.bits = dl_type.bits; + nb_type.code = dl_type.code; + + return nb_type; +} + +} // namespace xla diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h new file mode 100644 index 000000000000..7fffdc345d79 --- /dev/null +++ b/jaxlib/xla/dlpack.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_DLPACK_H_ +#define JAXLIB_XLA_DLPACK_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/ifrt/device.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// If take_ownership is true, ownership of the buffer is handed to DLPack, and +// the receiver may mutate the buffer as they see fit. Otherwise PjRt retains +// ownership of the buffer and it should be immutable. +// +// stream, if set, is a GPU stream, e.g. 
cudaStream_t for CUDA GPUs, that should +// be synchronized to the buffer as per +// https://dmlc.github.io/dlpack/latest/python_spec.html#python-specification-for-dlpack. +absl::StatusOr BufferToDLPackManagedTensor( + nanobind::handle buffer, std::optional stream); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, ifrt::Device* device, + nb_class_ptr client, std::optional stream); + +// Converts a PrimitiveType to the nanobind specific implementation of +// DLDataType. +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type); + +} // namespace xla + +#endif // JAXLIB_XLA_DLPACK_H_ diff --git a/jaxlib/xla/guard_lib.cc b/jaxlib/xla/guard_lib.cc new file mode 100644 index 000000000000..77866741819c --- /dev/null +++ b/jaxlib/xla/guard_lib.cc @@ -0,0 +1,197 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This files implements the configuration management for different types of +// guards. +// C++ backends are responsible for enforcing transfer guard levels. + +#include "jaxlib/xla/guard_lib.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +namespace { + +// Protected by the GIL. +GuardState& global_state = *new GuardState(); + +ABSL_CONST_INIT thread_local GuardState thread_local_state; + +// The default transfer guard level. +constexpr TransferGuardLevel kDefaultGuardLevel = TransferGuardLevel::kAllow; + +// The default garbage collection guard level. +constexpr GarbageCollectionGuardLevel kDefaultGarbageCollectionGuardLevel = + GarbageCollectionGuardLevel::kAllow; + +// Returns the transfer guard action for a transfer. +TransferGuardAction GetTransferGuardAction(TransferGuardLevel guard_level, + bool explicit_transfer) { + switch (guard_level) { + case TransferGuardLevel::kAllow: + return TransferGuardAction::kAllow; + case TransferGuardLevel::kLog: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kLog; + } + case TransferGuardLevel::kDisallow: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kDisallow; + } + case TransferGuardLevel::kLogExplicit: + return TransferGuardAction::kLog; + case TransferGuardLevel::kDisallowExplicit: + return TransferGuardAction::kDisallow; + default: + // Unreachable; gracefully handle the unexpected guard level and prevent a + // compiler warning. + return TransferGuardAction::kDisallow; + } +} + +// Returns the transfer guard action for a host-to-device transfer. +// REQUIRES: Python GIL. 
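The level-to-action mapping implemented by GetTransferGuardAction above is compact enough to summarise in Python; the sketch below mirrors the switch statement and uses illustrative names only.

def transfer_guard_action(level: str, explicit_transfer: bool) -> str:
    # Explicit transfers are exempt from the "log"/"disallow" levels, while
    # the *_explicit levels apply to explicit and implicit transfers alike.
    if level == "allow":
        return "allow"
    if level == "log":
        return "allow" if explicit_transfer else "log"
    if level == "disallow":
        return "allow" if explicit_transfer else "disallow"
    if level == "log_explicit":
        return "log"
    if level == "disallow_explicit":
        return "disallow"
    return "disallow"  # mirrors the defensive default branch

assert transfer_guard_action("disallow", explicit_transfer=True) == "allow"
assert transfer_guard_action("disallow_explicit", explicit_transfer=True) == "disallow"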
+TransferGuardAction GetTransferGuardActionForHostToDevice() { + return GetTransferGuardAction( + thread_local_state.host_to_device.value_or( + global_state.host_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-device transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForDeviceToDevice() { + return GetTransferGuardAction( + thread_local_state.device_to_device.value_or( + global_state.device_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-host transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForDeviceToHost() { + return GetTransferGuardAction( + thread_local_state.device_to_host.value_or( + global_state.device_to_host.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_get); +} + +} // namespace + +absl::Status ApplyTransferGuardToHostToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForHostToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "host-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed host-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToHost( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToHost()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-host transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-host transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard() { + return thread_local_state.garbage_collect_array.value_or( + global_state.garbage_collect_array.value_or( + kDefaultGarbageCollectionGuardLevel)); +} + +void BuildGuardSubmodule(nb::module_& m) { + nb::module_ glib = + m.def_submodule("guard_lib", "Jax support library for guards"); + + nb::enum_ tglevel(glib, "TransferGuardLevel"); + tglevel.value("ALLOW", TransferGuardLevel::kAllow); + tglevel.value("LOG", TransferGuardLevel::kLog); + tglevel.value("DISALLOW", TransferGuardLevel::kDisallow); + tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit); + tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit); + + nb::enum_ gcglevel( + glib, "GarbageCollectionGuardLevel"); + gcglevel.value("ALLOW", GarbageCollectionGuardLevel::kAllow); + gcglevel.value("LOG", GarbageCollectionGuardLevel::kLog); + gcglevel.value("FATAL", GarbageCollectionGuardLevel::kFatal); + + nb::class_ tgstate(glib, "GuardState"); + tgstate.def_rw("host_to_device", &GuardState::host_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_device", &GuardState::device_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_host", 
&GuardState::device_to_host, + nb::arg().none()); + tgstate.def_rw("explicit_device_put", &GuardState::explicit_device_put); + tgstate.def_rw("explicit_device_get", &GuardState::explicit_device_get); + tgstate.def_rw("garbage_collect_array", &GuardState::garbage_collect_array, + nb::arg().none()); + + glib.def( + "global_state", [&]() { return &global_state; }, + nb::rv_policy::reference); + glib.def( + "thread_local_state", [&]() { return &thread_local_state; }, + nb::rv_policy::reference); +} + +} // namespace jax diff --git a/jaxlib/xla/guard_lib.h b/jaxlib/xla/guard_lib.h new file mode 100644 index 000000000000..8ddf6e8e892e --- /dev/null +++ b/jaxlib/xla/guard_lib.h @@ -0,0 +1,115 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_GUARD_LIB_H_ +#define JAXLIB_XLA_GUARD_LIB_H_ + +#include +#include + +// placeholder for index annotation headers +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" + +namespace jax { + +// Transfer guard level chosen by the user code. +enum class TransferGuardLevel { + // Explicit transfers: allow + // Implicit transfers: allow + kAllow, + // Explicit transfers: allow + // Implicit transfers: log + kLog, + // Explicit transfers: allow + // Implicit transfers: disallow + kDisallow, + // Explicit transfers: log + // Implicit transfers: log + kLogExplicit, + // Explicit transfers: disallow + // Implicit transfers: disallow + kDisallowExplicit, +}; + +// Garbage collection guard level chose by the user code. +enum class GarbageCollectionGuardLevel { + // Silently allow the object to be garbage collected. + kAllow, + // Log and allow the object to be garbage collected. + kLog, + // Fatal crash on object garbage collection. + kFatal, +}; + +// Flags for guard levels are controlled by: +// - a global flag value, +// e.g., associated to --jax_transfer_guard_device_to_host +// which defaults to TransferGuardLevel::kAllow. +// - possibly a thread-local value, which initially is std::nullopt and +// overrides the global value if set. The thread-local state is used to +// implement context managers that locally override the global state. +// +// Explicit device_put/device_get contexts are tracked by context managers. +struct GuardState { + std::optional host_to_device; + std::optional device_to_device; + std::optional device_to_host; + bool explicit_device_put = false; + bool explicit_device_get = false; + + std::optional garbage_collect_array; +}; + +// Resulting action for a transfer given the transfer guard level and the +// transfer type. +enum class TransferGuardAction { + // Silently allow the transfer. + kAllow, + // Log and allow the transfer. + kLog, + // Disallow the transfer. + kDisallow, +}; + +// Guards a host-to-device transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. 
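These levels are what the user-facing transfer-guard configuration ultimately selects between. A rough usage sketch, assuming the public jax.transfer_guard context manager (which sets the thread-local state bound above):

import jax
import numpy as np

# The global default stays "allow"; locally, log implicit transfers only.
with jax.transfer_guard("log"):
    x = jax.device_put(np.arange(4))   # explicit transfer: always allowed
    y = x + np.ones(4)                 # implicit host-to-device transfer: logged, then allowed

# With "disallow", implicit transfers become errors while explicit
# device_put/device_get remain permitted.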
+absl::Status ApplyTransferGuardToHostToDevice( + absl::FunctionRef formatter); + +// Guards a device-to-device transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToDeviceToDevice( + absl::FunctionRef formatter); + +// Guards a device-to-host transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToDeviceToHost( + absl::FunctionRef formatter); + +// Returns the garbage collection guard level for "jax.Array" objects. +// REQUIRES: Python GIL. +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard(); + +// The function to call in `xla.cc` to add the bindings for this module. +void BuildGuardSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_GUARD_LIB_H_ diff --git a/jaxlib/xla/ifrt_proxy.cc b/jaxlib/xla/ifrt_proxy.cc new file mode 100644 index 000000000000..a89941f8581c --- /dev/null +++ b/jaxlib/xla/ifrt_proxy.cc @@ -0,0 +1,162 @@ +// Copyright 2023 The JAX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "jaxlib/xla/ifrt_proxy.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_entry.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unordered_map.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = ::nanobind; + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +struct PyClientConnectionOptions { + std::optional> on_disconnect; + std::optional> on_connection_update; + std::optional connection_timeout_in_seconds; + std::optional< + std::unordered_map>> + initialization_data; +}; + +absl::StatusOr> GetClient( + std::string proxy_server_address, + const PyClientConnectionOptions& py_options) { + DCHECK(PyGILState_Check()); + std::unique_ptr client; + + ClientConnectionOptions options; + if (py_options.on_disconnect) { + // While it is possible to pass around `py_options.on_disconnect` without + // wrapping it via a shared_ptr, copying the `py_options.on_disconnect` + // object can internally attempt to acquire the GIL [1], and can thus block + // or even deadlock. 
A unique_ptr or `absl::AnyInvocable` is not sufficient + // because downstream code can make copies. Reference: + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + auto py_on_disconnect = std::make_shared>( + std::move(*py_options.on_disconnect)); + + options.on_disconnect = + [on_disconnect = std::move(py_on_disconnect)](absl::Status s) mutable { + LOG(WARNING) << "Connection to server failed, calling supplied " + << "`on_disconnect` function: " << s; + tsl::Env::Default()->SchedClosure([s, on_disconnect]() mutable { + nb::gil_scoped_acquire gil_acquire; + (*on_disconnect)(s.ToString()); + on_disconnect = nullptr; + }); + }; + } + + if (py_options.on_connection_update) { + auto fn = std::make_shared>( + std::move(*py_options.on_connection_update)); + options.on_connection_update = [fn](absl::string_view log_line) -> void { + tsl::Env::Default()->SchedClosure([fn, str = std::string(log_line)] { + nb::gil_scoped_acquire gil_acquire; + (*fn)(std::string(str)); + }); + }; + } + + if (py_options.connection_timeout_in_seconds.has_value()) { + options.connection_timeout = + absl::Seconds(*py_options.connection_timeout_in_seconds); + } + + if (py_options.initialization_data.has_value()) { + AttributeMap::Map attribute_map; + for (const auto& [key, py_value] : *py_options.initialization_data) { + if (std::holds_alternative(py_value)) { + nb::bytes value = std::get(py_value); + attribute_map.insert({key, AttributeMap::StringValue(std::string( + value.c_str(), value.size()))}); + } else if (std::holds_alternative(py_value)) { + attribute_map.insert( + {key, AttributeMap::BoolValue(std::get(py_value))}); + } else { + CHECK(std::holds_alternative(py_value)); + attribute_map.insert( + {key, AttributeMap::Int64Value(std::get(py_value))}); + } + } + options.initialization_data = AttributeMap(std::move(attribute_map)); + } + + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options)); + } + + // Constructing `xla::PyClient` requires GIL as it may dec-ref Python objects. + return xla::PyClient::Make(std::move(client)); +} + +} // namespace + +void BuildIfrtProxySubmodule(nb::module_& m) { + nb::module_ sub_module = m.def_submodule("ifrt_proxy", "IFRT proxy"); + + nb::class_(sub_module, "ClientConnectionOptions") + .def(nb::init<>()) + .def_rw("on_disconnect", &PyClientConnectionOptions::on_disconnect, + nb::arg().none()) + .def_rw("on_connection_update", + &PyClientConnectionOptions::on_connection_update, + nb::arg().none()) + .def_rw("connection_timeout_in_seconds", + &PyClientConnectionOptions::connection_timeout_in_seconds, + nb::arg().none()) + .def_rw("initialization_data", + &PyClientConnectionOptions::initialization_data, + nb::arg().none()); + + sub_module.def("get_client", xla::ValueOrThrowWrapper(GetClient), + nb::arg("proxy_server_address"), nb::arg("options")); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/jaxlib/xla/ifrt_proxy.h b/jaxlib/xla/ifrt_proxy.h new file mode 100644 index 000000000000..a8fcb9e676ff --- /dev/null +++ b/jaxlib/xla/ifrt_proxy.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ +#define JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ + +#include "nanobind/nanobind.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +void BuildIfrtProxySubmodule(nanobind::module_& m); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ diff --git a/jaxlib/xla/jax_jit.cc b/jaxlib/xla/jax_jit.cc new file mode 100644 index 000000000000..4645c59c7147 --- /dev/null +++ b/jaxlib/xla/jax_jit.cc @@ -0,0 +1,495 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This files implements the `jax.jit` dispatch and just-in-time feature. +// +// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward +// based on passed arguments dtypes/shapes/identity) the execution to a +// just-in-time compiled XLA Executable. All of that is done in C++ for +// performance reasons. +// +// This file contains the utilities to: +// (a) inspect arguments and describe their structure, dtype/shapes, etc. +// (b) keep a mapping from function signatures to compiled XLA Executables. 
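The dispatch behaviour described above is observable from Python: a jitted function is traced and compiled once per distinct call signature (argument shapes, dtypes, and static-argument identity) and the cached executable is reused otherwise. A minimal sketch:

import jax
import numpy as np

traces = []

@jax.jit
def f(x):
    traces.append(x.shape)   # runs only while tracing, not on cache hits
    return x * 2

f(np.zeros((2,), np.float32))
f(np.ones((2,), np.float32))    # same shape/dtype signature: cached executable
f(np.zeros((3,), np.float32))   # new shape: trace and compile again
print(traces)                   # [(2,), (3,)]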
+ +#include "jaxlib/xla/jax_jit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/types.h" +#include "xla/tsl/platform/logging.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +// TODO(phawkins): Add support for Tracers. +// TODO(jblespiau): Use absl absl::Status. + +namespace { + +// `thread_local_state.extra_jit_context` is set from Python. It's done when +// loading the Python jax modules on the main-thread. For other threads, we +// need to initialize the field the first time we access `thread_local_state`. +nb::object& initialize_local_state = *new nb::object(); + +} // namespace + +JitState& GlobalJitState() { + // Protected by the GIL. + static JitState& global_state = *new JitState(); + return global_state; +} + +JitState& ThreadLocalJitState() { + // TODO(phawkins): Google style guide forbids thread-local values with + // non-trivial destructors. + ABSL_CONST_INIT thread_local JitState thread_local_state; // NOLINT + DCHECK(PyGILState_Check()); + if (thread_local_state.extra_jit_context == std::nullopt) { + CHECK(initialize_local_state.ptr() != nullptr); + // Avoids reentrant calls to the initialization function. + thread_local_state.extra_jit_context = nb::none(); + initialize_local_state(); + } + return thread_local_state; +} + +bool GetDisableJit() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + CHECK(global_state.disable_jit.has_value()); + return thread_local_state.disable_jit.value_or(*global_state.disable_jit); +} + +bool GetEnableX64() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + CHECK(global_state.enable_x64.has_value()); + return thread_local_state.enable_x64.value_or(*global_state.enable_x64); +} + +std::optional GetDefaultDevice() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + return thread_local_state.default_device.has_value() + ? thread_local_state.default_device + : global_state.default_device; +} + +std::optional GetPostHook() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + return thread_local_state.post_hook.has_value() ? 
thread_local_state.post_hook + : global_state.post_hook; +} + +static std::string OptionalDebugString( + const std::optional optional) { + if (optional.has_value()) { + return nb::cast(nb::str(optional.value())); + } else { + return "None"; + } +} + +std::string ArgumentSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; + auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) { + out->append(d.ToString()); + }; + return absl::StrFormat( + "static args (positional + keyword): [%s], " + "static arg keyword names: [%s], " + "dynamic arg signatures (positional + keyword): [%s]" + "dynamic arg shardings: [%s]", + absl::StrJoin(static_args, ",", py_object_formatter), + absl::StrJoin(static_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter)); +} + +bool ArgumentSignature::operator==(const ArgumentSignature& other) const { + if (dynamic_arg_treedefs != other.dynamic_arg_treedefs) { + return false; + } + auto object_ptr_equality = [](nb::handle a, nb::handle b) { + return a.ptr() == b.ptr(); + }; + if (!absl::c_equal(dynamic_arg_names, other.dynamic_arg_names, + object_ptr_equality)) { + return false; + } + if (!absl::c_equal(static_arg_names, other.static_arg_names, + object_ptr_equality)) { + return false; + } + return absl::c_equal( + static_args, other.static_args, + [](const nb::object& a, const nb::object& b) { + try { + return a.type().ptr() == b.type().ptr() && a.equal(b); + } catch (const nb::python_error& e) { + throw std::invalid_argument(absl::StrCat( + "static arguments should be comparable using __eq__." + "The following error was raised when comparing two objects of " + "types ", + nb::cast(nb::str(a.type())), " and ", + nb::cast(nb::str(b.type())), + ". The error was:\n", e.what())); + } + }); +} + +std::string CallSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; + auto signature_formatter = [](std::string* out, + const xla::PyArgSignature& s) { + out->append(s.DebugString()); + }; + auto layout_formatter = [](std::string* out, + const std::shared_ptr& l) { + if (l != nullptr) { + out->append(l->ToString()); + } else { + out->append("None"); + } + }; + auto bool_formatter = [](std::string* out, bool o) { + out->append(o ? "true" : "false"); + }; + return absl::StrFormat( + "arg signature: %s\n" + "dynamic arg signatures (positional + keyword): %s\n" + "dynamic arg shardings: %s\n" + "dynamic arg layouts: %s\n" + "committed args: %s\n" + "device: %s\n" + "default_device: %s\n" + "jax_enable_x64: %d\n" + "global_extra_jit_context: %s\n" + "thread_local_extra_jit_context: %s\n" + "configs: %s\n", + arg_signature.DebugString(), + absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter), + absl::StrJoin(dynamic_arg_shardings, ", ", py_object_formatter), + absl::StrJoin(dynamic_arg_layouts, ", ", layout_formatter), + absl::StrJoin(committed_args, ",", bool_formatter), + device != nullptr ? 
device->DebugString() : "nullptr", + OptionalDebugString(default_device), jax_enable_x64, + OptionalDebugString(global_extra_jit_context), + OptionalDebugString(thread_local_extra_jit_context), + absl::StrJoin(configs, ", ", py_object_formatter)); +} + +bool CallSignature::operator==(const CallSignature& other) const { + if (arg_signature != other.arg_signature) { + return false; + } + if (dynamic_arg_signatures != other.dynamic_arg_signatures) { + return false; + } + if (device != other.device) { + return false; + } + if (jax_enable_x64 != other.jax_enable_x64) { + return false; + } + if (committed_args != other.committed_args) { + return false; + } + return + // `==` on py:objects is the Python `is`. We need equal. + absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, + ShardingEqual) && + absl::c_equal(dynamic_arg_layouts, other.dynamic_arg_layouts, + [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return (a && b) ? *a == *b : a == b; + }) && + (global_extra_jit_context.has_value() == + other.global_extra_jit_context.has_value()) && + (!global_extra_jit_context.has_value() || + global_extra_jit_context->equal(*other.global_extra_jit_context)) && + (default_device.has_value() == other.default_device.has_value()) && + (!default_device.has_value() || + default_device->equal(*other.default_device)) && + (thread_local_extra_jit_context.has_value() == + other.thread_local_extra_jit_context.has_value()) && + (!thread_local_extra_jit_context.has_value() || + thread_local_extra_jit_context->equal( + *other.thread_local_extra_jit_context)) && + configs.size() == other.configs.size() && + absl::c_equal( + configs, other.configs, + [](const nb::object& a, const nb::object& b) { return a.equal(b); }); +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nb::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args) { + tsl::profiler::TraceMe traceme("ParseArguments"); + + DCHECK(absl::c_all_of(static_argnames, [](const nb::str& name) { + return PyUnicode_CHECK_INTERNED(name.ptr()); + })); + + flat_dynamic_args.reserve(positional_args.size() + keyword_args.size()); + if (static_argnums.empty()) { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + for (int i = 0; i < positional_args.size(); ++i) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(positional_args[i]), flat_dynamic_args); + } + } else { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + int num_positional_args = positional_args.size(); + for (int i = 0; i < positional_args.size(); ++i) { + if (std::find_if(static_argnums.begin(), static_argnums.end(), + [i, num_positional_args](int t) { + return t >= 0 ? i == t : i == t + num_positional_args; + }) == static_argnums.end()) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(positional_args[i], flat_dynamic_args); + } else { + signature.static_args.emplace_back( + nb::borrow(positional_args[i])); + } + } + } + + // Keyword arguments. 
+ if (!keyword_args.empty()) { + std::vector> kwargs(keyword_args.size()); + // We first intern the keys, then sort them (by name, as in the Python path) + // (see also xla::PyTreeDef::Flatten) and then create the signatures. + // TODO(jblespiau): We should be able to sort the keys by interned-key + // pointers, but this requires the Python compilation to do the same. + for (int i = 0; i < keyword_args.size(); ++i) { + // Intern the key if not already interned. + PyObject* key = PyTuple_GET_ITEM(kwnames.ptr(), i); + Py_INCREF(key); + if (!PyUnicode_CHECK_INTERNED(key)) { + PyUnicode_InternInPlace(&key); + } + kwargs[i].first = key; + kwargs[i].second = keyword_args[i]; + } + + std::sort(kwargs.begin(), kwargs.end(), + [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }); + auto kwarg_is_static = [&](nb::handle name) { + for (const auto& kw : static_argnames) { + if (kw.ptr() == name.ptr()) return true; + } + return false; + }; + + signature.dynamic_arg_names.reserve(keyword_args.size()); + for (int i = 0; i < keyword_args.size(); ++i) { + if (kwarg_is_static(kwargs[i].first)) { + signature.static_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.static_args.push_back( + nb::borrow(kwargs[i].second)); + } else { + signature.dynamic_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(kwargs[i].second.ptr()), + flat_dynamic_args); + } + } + } + return absl::OkStatus(); +} + +void BuildJaxjitSubmodule(nb::module_& m) { + nb::module_ jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); + + nb::class_ jit_state_(jitlib, "JitState"); + jit_state_.def_rw("disable_jit", &JitState::disable_jit, nb::arg().none()); + jit_state_.def_rw("enable_x64", &JitState::enable_x64, nb::arg().none()); + jit_state_.def_rw("default_device", &JitState::default_device, + nb::arg().none()); + jit_state_.def_rw("extra_jit_context", &JitState::extra_jit_context, + nb::arg().none()); + jit_state_.def_rw("post_hook", &JitState::post_hook, nb::arg().none()); + + jitlib.def( + "global_state", [&]() { return &GlobalJitState(); }, + nb::rv_policy::reference); + jitlib.def( + "thread_local_state", [&]() { return &ThreadLocalJitState(); }, + nb::rv_policy::reference); + + jitlib.def( + "swap_thread_local_state_disable_jit", + [&](std::optional value) -> std::optional { + auto tls = &ThreadLocalJitState(); + auto result = tls->disable_jit; + tls->disable_jit = value; + return result; + }, + nb::arg("value").none(), nb::rv_policy::reference); + + jitlib.def("get_enable_x64", &GetEnableX64); + jitlib.def("set_thread_local_state_initialization_callback", + [](nb::object f) { initialize_local_state = f; }); + + nb::class_ arg_signature(jitlib, "PyArgSignature"); + arg_signature + .def_prop_ro( + "dtype", + [](const xla::PyArgSignature& sig) { + return xla::ValueOrThrow(xla::PrimitiveTypeToNbDtype(sig.dtype)); + }) + .def_prop_ro("shape", + [](const xla::PyArgSignature& sig) { + return xla::SpanToNbTuple(absl::MakeConstSpan(sig.shape)); + }) + .def_ro("weak_type", &xla::PyArgSignature::weak_type); + jitlib.def("_ArgSignatureOfValue", + xla::ValueOrThrowWrapper(xla::PyArgSignatureOfValue)); + + jitlib.def("_is_float0", &xla::IsFloat0); + + nb::class_ argument_signature(jitlib, "ArgumentSignature"); + argument_signature.def_ro("static_args", &ArgumentSignature::static_args) + .def_ro("static_arg_names", 
&ArgumentSignature::static_arg_names) + .def_ro("dynamic_arg_names", &ArgumentSignature::dynamic_arg_names) + .def_ro("dynamic_arg_treedefs", &ArgumentSignature::dynamic_arg_treedefs) + .def("__repr__", &ArgumentSignature::DebugString) + .def("__str__", &ArgumentSignature::DebugString) + .def("__hash__", + [](const ArgumentSignature& s) { return absl::HashOf(s); }) + .def("__eq__", [](const ArgumentSignature& a, + const ArgumentSignature& b) { return a == b; }) + .def("__ne__", [](const ArgumentSignature& a, + const ArgumentSignature& b) { return a != b; }); + + jitlib.def( + "parse_arguments", + [](nb::sequence positional_args, nb::sequence keyword_args, + nb::tuple kwnames, absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry) { + ArgumentSignature signature; + absl::InlinedVector flat_dynamic_args; + nb::object positional_args_seq = nb::steal(PySequence_Fast( + positional_args.ptr(), "positional_args must be a list or tuple")); + if (!positional_args_seq.ptr()) { + throw nb::python_error(); + } + nb::object keyword_args_seq = nb::steal(PySequence_Fast( + keyword_args.ptr(), "keyword_args must be a list or tuple")); + if (!keyword_args_seq.ptr()) { + throw nb::python_error(); + } + absl::Span positional_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(positional_args_seq.ptr()), + PySequence_Fast_GET_SIZE(positional_args_seq.ptr())); + absl::Span keyword_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(keyword_args_seq.ptr()), + PySequence_Fast_GET_SIZE(keyword_args_seq.ptr())); + + // Intern the static argument names. + std::vector static_argnames_interned; + static_argnames_interned.reserve(static_argnames.size()); + for (const nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_interned.push_back(nb::steal(s)); + } + + xla::ThrowIfError( + ParseArguments(positional_args_span, keyword_args_span, kwnames, + static_argnums, static_argnames_interned, + pytree_registry, signature, flat_dynamic_args)); + return std::make_pair(std::move(signature), + std::move(flat_dynamic_args)); + }, + nb::arg("positional_args"), nb::arg("keyword_args"), nb::arg("kwnames"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("pytree_registry"), + R"doc(Parses the arguments to a function as jax.jit would. + +Returns a ArgumentSignature and the flattened dynamic arguments. + +Args: + positional_args: The positional arguments. + keyword_args: The keyword arguments. + kwnames: The keyword names. + static_argnums: The static argument numbers. + static_argnames: The static argument names. + pytree_registry: The pytree registry. +)doc"); +} + +} // namespace jax diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h new file mode 100644 index 000000000000..9eba2e9d3228 --- /dev/null +++ b/jaxlib/xla/jax_jit.h @@ -0,0 +1,266 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_JAX_JIT_H_ +#define JAXLIB_XLA_JAX_JIT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +// Flags, such as JIT disable and the x64 mode, are controlled by: +// - a global flag value, e.g., associated to --jax_enable_x64 +// - possibly a thread-local value, which initially is std::nullopt and +// overrides the global value if set. The thread-local state is +// used to implement context managers that locally override the global state. +struct JitState { + ~JitState() { + if (extra_jit_context) { + // We likely do not hold the GIL if this JitState is thread-local, so we + // hand the Python object to the global reference manager to destroy. + nanobind::object o = std::move(*extra_jit_context); + xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1)); + extra_jit_context = std::nullopt; + } + } + + std::optional disable_jit; + std::optional enable_x64; + + // Used to manually set the default device jax should use. May be unset even + // in global state, indicating there is no manual override. + // TODO(skyewm): make this a C++ type when all JAX backends support a single + // C++ device interface + std::optional default_device; + + // Extra context that should be included in the JIT cache key. Must be + // hashable and have an equality defined. + std::optional extra_jit_context; + + // A callback that, if present, is called when a JITted function is executed + // from cache. May be unset even in global state. + std::optional post_hook; +}; + +JitState& GlobalJitState(); + +// Requires the GIL. +JitState& ThreadLocalJitState(); + +// Getters for JitState fields that first look in thread-local state, then +// fallback to global state. +bool GetDisableJit(); +bool GetEnableX64(); + +// TODO(skyewm): return a C++ type when all JAX backends support a single C++ +// device interface +std::optional GetDefaultDevice(); +std::optional GetPostHook(); + +// An ArgumentSignature describes the static arguments to a function call, and +// how the dynamic arguments are related to the arguments. Together with the +// values of the dynamic arguments, this fully describes the arguments. +struct ArgumentSignature { + // A PyTreeDef for each dynamic argument, positional arguments first + // followed by keyword arguments. Keyword arguments are in the order given + // by dynamic_arg_names. + absl::InlinedVector dynamic_arg_treedefs; + + // Dynamic keyword argument names. Interned, and sorted by the keyword + // name. Interned values are safe to compare by pointer. + std::vector dynamic_arg_names; + + // Static arguments. Contains the positional arguments sorted in argument + // order, followed by static keyword arguments in the order given by + // `static_arg_names`. + std::vector static_args; + + // Static keyword argument names. 
Interned, and sorted by keyword name. + std::vector static_arg_names; + + bool operator==(const ArgumentSignature& other) const; + bool operator!=(const ArgumentSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const ArgumentSignature& s) { + h = H::combine(std::move(h), s.dynamic_arg_treedefs, + s.dynamic_arg_names.size(), s.static_args.size(), + s.static_arg_names.size()); + + for (const auto& name : s.dynamic_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + for (size_t i = 0; i < s.static_args.size(); ++i) { + const auto& static_arg = s.static_args[i]; + Py_hash_t hash; + try { + hash = nanobind::hash(static_arg); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Non-hashable static arguments are not supported. An error occurred " + "while trying to hash an object of type ", + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), + ". The error was:\n", e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + } + for (const auto& name : s.static_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + return h; +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +// Args: +// positional_args: positional arguments +// keyword_args: the values of the keyword arguments +// kwnames: either None or a tuple containing the keyword argument names +// static_argnums: the indices of the static arguments in the positional +// arguments +// static_argnames: the names of the static arguments, which must be interned. +// pytree_registry: the registry to use to convert the arguments to pytrees +// signature: output; describes the static arguments and the identities of the +// dynamic arguments. +// flat_dynamic_args: output; the concatenation of the dynamic positional +// arguments and sorted keyword arguments. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nanobind::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args); + +// The signature of Python jitted function call, partitioned into: +// - dynamic positional arguments (i.e. positional args which are not static) +// - static positional arguments (i.e. the args associated to static_argnums) +// - keyword arguments +// The CallSignature should unambiguously identify a function call, thus, +// equality is based on: +// (a) Same PyTree for all dynamic positional arguments and keyword arguments +// (a) equality of the arguments and keyword arguments ArgSignature +// (a) equality (delegated to Python) of the static arguments. +struct CallSignature { + // Not part of the signature, but we need it for error messages. + absl::string_view function_name; + + ArgumentSignature arg_signature; + + // Shape and dtype for both the dynamic positional arguments and the keyword + // arguments (sorted by keyword name). + absl::InlinedVector dynamic_arg_signatures; + + // The sharding of the jax.Array arguments. + std::vector dynamic_arg_shardings; + + // The layout of the jax.Array arguments. 
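The hashing requirement above is the source of a familiar user-facing error: static arguments must be hashable (and comparable with __eq__). A short reproduction, assuming the public jax.jit API:

from functools import partial
import jax

@partial(jax.jit, static_argnums=1)
def scale(x, opts):
    return x * opts["factor"]

scale(1.0, {"factor": 2.0})
# ValueError: Non-hashable static arguments are not supported. ...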
+
+// The signature of Python jitted function call, partitioned into:
+// - dynamic positional arguments (i.e. positional args which are not static)
+// - static positional arguments (i.e. the args associated to static_argnums)
+// - keyword arguments
+// The CallSignature should unambiguously identify a function call, thus,
+// equality is based on:
+// (a) Same PyTree for all dynamic positional arguments and keyword arguments
+// (b) equality of the arguments and keyword arguments ArgSignature
+// (c) equality (delegated to Python) of the static arguments.
+struct CallSignature {
+  // Not part of the signature, but we need it for error messages.
+  absl::string_view function_name;
+
+  ArgumentSignature arg_signature;
+
+  // Shape and dtype for both the dynamic positional arguments and the keyword
+  // arguments (sorted by keyword name).
+  absl::InlinedVector<xla::PyArgSignature, 2> dynamic_arg_signatures;
+
+  // The sharding of the jax.Array arguments.
+  std::vector<nanobind::object> dynamic_arg_shardings;
+
+  // The layout of the jax.Array arguments.
+  std::vector<std::shared_ptr<const xla::PjRtLayout>> dynamic_arg_layouts;
+
+  absl::InlinedVector<bool, 2> committed_args;
+
+  // For JIT, we need this in the key because computation follows the data, so
+  // we may have multiple executables depending on the devices the data is on.
+  // This is not the case for PMAP, and is set to `nullptr`.
+  xla::PjRtDevice* device = nullptr;
+  bool jax_enable_x64;
+
+  // For JIT on PJIT, we need to fall back to python whenever default_device
+  // changes.
+  std::optional<nanobind::object> default_device;
+
+  // Opaque additional context that should be included as part of the cache
+  // key.
+  std::optional<nanobind::object> global_extra_jit_context;
+  std::optional<nanobind::object> thread_local_extra_jit_context;
+
+  std::vector<nanobind::object> configs;
+
+  bool operator==(const CallSignature& other) const;
+  bool operator!=(const CallSignature& other) const {
+    return !(*this == other);
+  }
+
+  std::string DebugString() const;
+};
+
+template <typename H>
+H AbslHashValue(H h, const CallSignature& s) {
+  h = H::combine(std::move(h), s.arg_signature, s.dynamic_arg_signatures);
+
+  DCHECK(s.dynamic_arg_shardings.empty() ||
+         s.dynamic_arg_shardings.size() == s.dynamic_arg_signatures.size());
+
+  DCHECK(s.dynamic_arg_layouts.empty() ||
+         s.dynamic_arg_layouts.size() == s.dynamic_arg_signatures.size());
+
+  // TODO(chky): For now, we are only hashing the pointer of shardings to avoid
+  // slow python hashing function. Consider implementing hashing function and
+  // equality checks in C++ in jax::Sharding and use those here.
+  for (const auto& sharding : s.dynamic_arg_shardings) {
+    h = H::combine(std::move(h), ShardingHash(sharding));
+  }
+
+  for (const auto& layout : s.dynamic_arg_layouts) {
+    if (layout != nullptr) {
+      h = H::combine(std::move(h), *layout);
+    }
+  }
+
+  h = H::combine(std::move(h), s.committed_args, s.device, s.jax_enable_x64);
+
+  // We do not hash the extra_jit_context fields since calling Python hash
+  // functions is expensive (~300ns) and we don't expect a large number of
+  // different contexts.
+  return h;
+}
+
+// The function to call in `xla.cc` to add the bindings for this module.
+void BuildJaxjitSubmodule(nanobind::module_& m);
+
+}  // namespace jax
+
+#endif  // JAXLIB_XLA_JAX_JIT_H_
diff --git a/jaxlib/xla/jax_jit_test.py b/jaxlib/xla/jax_jit_test.py
new file mode 100644
index 000000000000..a090bc8dfadd
--- /dev/null
+++ b/jaxlib/xla/jax_jit_test.py
@@ -0,0 +1,47 @@
+# Copyright 2024 The JAX Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Tests for jax_jit helper functions.""" + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + +jax_jit = xla_client._xla.jax_jit +pytree = xla_client._xla.pytree + +pytree_registry = pytree.default_registry() + + +class JaxJitTest(absltest.TestCase): + + def testParseArguments(self): + sig, args = jax_jit.parse_arguments( + positional_args=[1, 2, 3], + keyword_args=[4, 5], + kwnames=("a", "b"), + static_argnums=[0, 2], + static_argnames=["a"], + pytree_registry=pytree_registry, + ) + self.assertEqual(args, [2, 5]) + self.assertEqual(sig.static_args, [1, 3, 4]) + self.assertEqual(sig.static_arg_names, ["a"]) + _, leaf = pytree_registry.flatten(0) + self.assertEqual(sig.dynamic_arg_names, ["b"]) + self.assertEqual(sig.dynamic_arg_treedefs, [leaf, leaf]) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/mlir.cc b/jaxlib/xla/mlir.cc new file mode 100644 index 000000000000..29ef86d50df6 --- /dev/null +++ b/jaxlib/xla/mlir.cc @@ -0,0 +1,235 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/mlir.h" + +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "stablehlo/dialect/Serialization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/translate/stablehlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/refine_polymorphic_shapes.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +std::string PrintModule(mlir::ModuleOp module) { + std::string s; + llvm::raw_string_ostream os(s); + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(); + module->print(os, flags); + return s; +} + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +void EnablePrintBeforeAndAfter(mlir::PassManager& 
pm) { + auto print_before = [](mlir::Pass*, mlir::Operation*) { return true; }; + auto print_after = [](mlir::Pass*, mlir::Operation*) { return true; }; + pm.enableIRPrinting(print_before, print_after); +} + +absl::StatusOr HloToStableHlo(const nb::bytes& hlo_module_proto) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + HloModuleProto proto; + proto.ParseFromArray(hlo_module_proto.c_str(), hlo_module_proto.size()); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &proto)); + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +// Converts an XlaComputation to a StableHLO mlir::Module string. +// Exists for backwards compatibility. +// TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules +// instead and delete this function. +absl::StatusOr PyXlaComputationToMlirModule( + const XlaComputation& computation) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &computation.proto())); + return PrintModule(*module); +} + +absl::StatusOr PyMlirModuleToXlaComputation( + absl::string_view mlir_module, bool use_tuple_args, bool return_tuple) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + XlaComputation computation; + // SDY dialect may be part of the module which XLA doesn't know about. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation, use_tuple_args, + return_tuple, + /*use_shardy=*/false)); + return computation; +} + +absl::StatusOr PyMhloToStablehlo(absl::string_view mlir_module) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + // JAX can be customized in a way that involves operations from custom + // dialects showing up in JAX IR. + // `ParseMlirModuleString` won't know about these dialects, but that's fine + // since we just want to convert MHLO ops to StableHLO ops here and leave + // everything else unchanged. + // In order to achieve that, we're allowing unregistered dialects here. + context.allowUnregisteredDialects(true); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + mlir::PassManager pm(&context); + if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (!mlir::succeeded(pm.run(*module))) { + return tsl::errors::InvalidArgument("MHLO => StableHLO failed"); + } + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. 
+ TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PySerializePortableArtifact( + absl::string_view mlir_module, absl::string_view target) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + + // Serialize portable artifact + TF_ASSIGN_OR_RETURN( + std::string bytecode, + SerializeUsingVersionedStablehlo(*module, target, /*inplace=*/true)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PyDeserializePortableArtifact( + const nb::bytes& bytecode_str) { + mlir::MLIRContext context; + mlir::OwningOpRef module = + mlir::stablehlo::deserializePortableArtifact( + absl::string_view(bytecode_str.c_str(), bytecode_str.size()), + &context); + if (!module) + return tsl::errors::InvalidArgument("Failed to deserialize StableHLO"); + return PrintModule(*module); +} + +} // namespace + +void BuildMlirSubmodule(nb::module_& m) { + nb::module_ mlir_module = m.def_submodule("mlir", "MLIR/XLA integration"); + + mlir_module.def("hlo_to_stablehlo", xla::ValueOrThrowWrapper(HloToStableHlo), + nb::arg("computation")); + + mlir_module.def("xla_computation_to_mlir_module", + xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule), + nb::arg("computation")); + mlir_module.def( + "mlir_module_to_xla_computation", + [](const nb::bytes& bytecode, bool use_tuple_args, bool return_tuple) { + return xla::ValueOrThrow(PyMlirModuleToXlaComputation( + absl::string_view(bytecode.c_str(), bytecode.size()), + use_tuple_args, return_tuple)); + }, + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def("mlir_module_to_xla_computation", + xla::ValueOrThrowWrapper(PyMlirModuleToXlaComputation), + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def( + "mhlo_to_stablehlo", + [](const nb::bytes& bytecode) { + return xla::ValueOrThrow(PyMhloToStablehlo( + absl::string_view(bytecode.c_str(), bytecode.size()))); + }, + nb::arg("mlir_module")); + mlir_module.def("mhlo_to_stablehlo", + xla::ValueOrThrowWrapper(PyMhloToStablehlo), + nb::arg("mlir_module")); + mlir_module.def( + "serialize_portable_artifact", + [](const nb::bytes& bytecode, absl::string_view target) { + return xla::ValueOrThrow(PySerializePortableArtifact( + absl::string_view(bytecode.c_str(), bytecode.size()), target)); + }, + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("serialize_portable_artifact", + xla::ValueOrThrowWrapper(PySerializePortableArtifact), + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("deserialize_portable_artifact", + xla::ValueOrThrowWrapper(PyDeserializePortableArtifact), + nb::arg("mlir_module")); + mlir_module.def( + "refine_polymorphic_shapes", + [](nb::bytes bytecode, bool enable_shape_assertions, + bool validate_static_shapes, bool enable_shardy) -> nb::bytes { + std::string buffer; + llvm::raw_string_ostream os(buffer); + xla::ThrowIfError(RefinePolymorphicShapes( + absl::string_view(bytecode.c_str(), bytecode.size()), os, + enable_shape_assertions, validate_static_shapes, enable_shardy)); + return nb::bytes(buffer.data(), buffer.size()); + }, + nb::arg("mlir_module"), nb::arg("enable_shape_assertions") = true, + nb::arg("validate_static_shapes") = true, + nb::arg("enable_shardy") = false, + R"(Refines the dynamic shapes for a module. 
+ The "main" function must have static shapes and all the + intermediate dynamic shapes depend only on the input static + shapes. Optionally, also validates that the resulting module has + only static shapes. + )"); +} + +} // namespace xla diff --git a/jaxlib/xla/mlir.h b/jaxlib/xla/mlir.h new file mode 100644 index 000000000000..ee95f5f95921 --- /dev/null +++ b/jaxlib/xla/mlir.h @@ -0,0 +1,28 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_MLIR_H_ +#define JAXLIB_XLA_MLIR_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildMlirSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_MLIR_H_ diff --git a/jaxlib/xla/nb_class_ptr.h b/jaxlib/xla/nb_class_ptr.h new file mode 100644 index 000000000000..e468860dc661 --- /dev/null +++ b/jaxlib/xla/nb_class_ptr.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_NB_CLASS_PTR_H_ +#define JAXLIB_XLA_NB_CLASS_PTR_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +// A reference-counting smart pointer to a nanobind-wrapped class on the Python +// heap. Type T must be a class known to nanobind via a nanobind::class_ +// declaration. nb_class_ptr is useful for managing C++ classes that may be +// allocated inline in Python objects on the Python heap. +template +class nb_class_ptr : public nanobind::object { + public: + inline nb_class_ptr() : nanobind::object() {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::borrow_t) + : nanobind::object(h, ::nanobind::detail::borrow_t{}) {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::steal_t) + : nanobind::object(h, ::nanobind::detail::steal_t{}) {} + inline static bool check_(nanobind::handle h) { + nanobind::handle type = nanobind::type(); + return h.type().is(type); + }; + + T* operator->() const { return nanobind::inst_ptr(ptr()); } + T& operator*() const { return *nanobind::inst_ptr(ptr()); } + T* get() const { return ptr() ? nanobind::inst_ptr(ptr()) : nullptr; } +}; + +// This function is analogous to std::make_unique(...), but instead it +// allocates the object on the Python heap +template +nb_class_ptr make_nb_class(Args&&... 
args) { + nanobind::handle type = nanobind::type(); + nanobind::object instance = nanobind::inst_alloc(type); + T* ptr = nanobind::inst_ptr(instance); + new (ptr) T(std::forward(args)...); + nanobind::inst_mark_ready(instance); + return nb_class_ptr(instance.release(), ::nanobind::detail::steal_t{}); +} + +} // namespace xla + +#endif // JAXLIB_XLA_NB_CLASS_PTR_H_ diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc new file mode 100644 index 000000000000..50bdc750d3a4 --- /dev/null +++ b/jaxlib/xla/pjit.cc @@ -0,0 +1,1408 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/pjit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/config.h" +#include "jaxlib/xla/guard_lib.h" +#include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/traceback.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { +namespace { + +namespace nb = nanobind; + +struct PjitCacheEntry { + explicit PjitCacheEntry(xla::PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + std::vector in_shardings; + std::vector out_avals; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector 
out_weak_types; + std::vector out_shardings; + std::vector out_committed; + xla::PyTreeDef out_pytree_def; + // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args` + // in PjitFunction::Call before calling into compiled computation. + std::vector kept_var_bitvec; + std::vector in_device_local_layouts; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + std::thread::id thread_id = std::this_thread::get_id(); + + bool fall_back_to_python = false; +}; + +// A PjitFunctionCache represents a cache of compiled functions that can be +// shared between one or more PjitFunction objects. It serves two goals: +// - reduce the number of lru caches (hash map) across multiple JITs. +// - make the cache global to increase cache hits (e.g. calling jit(f)(3) twice) +// keeping entries alive as long as the underlying function f is alive. +// Assume the cache is protected by the GIL. +class PjitFunctionCache { + public: + static constexpr int kDefaultCapacity = 4096; + explicit PjitFunctionCache(int capacity); + + // Cache entries are shared_ptr<>s because it's possible the cache entry + // might be evicted before we finish tracing/compiling. + typedef xla::LRUCache> Cache; + + // We include as part of the cache key `global_cache_key` (and any other + // fields that aren't subsumed by the CallSignature we compute for each call). + static std::shared_ptr Lookup( + xla::nb_class_ptr self, nb::handle function, + nb::object global_cache_key); + std::shared_ptr DefaultCache(); + + // These methods require the GIL or the object's lock in no-GIL mode. + int Size() const { return lru_list_.Size(); } + int Capacity() const { return lru_list_.Capacity(); } + void Clear() { + lru_list_.Clear(); + functions_.clear(); + } + + private: + struct Key { + nb::handle function; // Does not hold a reference. + + // Other fields that are part of the arguments to `jit`, but are not + // otherwise part of CallSignature. + nb::object global_cache_key; + + size_t cached_hash; + + bool operator==(const Key& other) const { + bool global_cache_eq; + try { + global_cache_eq = global_cache_key.equal(other.global_cache_key); + } catch (const nanobind::python_error& e) { + throw std::invalid_argument( + absl::StrCat("Equality of global cache key lead to an exception. " + "The error was:\n", + e.what(), "\n")); + } + return function.ptr() == other.function.ptr() && global_cache_eq; + } + + struct Hash { + size_t operator()(const Key& key) const { return key.cached_hash; } + }; + }; + + template + friend H AbslHashValue(H h, const Key& key) { + h = H::combine(std::move(h), key.function.ptr()); + Py_hash_t hash; + try { + hash = nb::hash(key.global_cache_key); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Hashing global cache key lead to an exception. The error was:\n", + e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + return h; + } + + struct Value { + explicit Value(std::shared_ptr cache) : cache(std::move(cache)) {} + std::shared_ptr cache; + + // A weak reference to the key function. We use the weak reference to + // register a callback that is triggered when the key function is destroyed. 
+ // We use a weak pointer because we want to allow caching across multiple + // calls to `pjit(f)` if `f` remains alive, but we do not want the cache + // to keep `f` alive if all other references are dropped. + std::optional weakref; + }; + + // lru_list_ and functions_ are protected by the GIL in GIL mode, and by the + // self object lock in freethreading mode. + Cache::LRUList lru_list_; + // We use std::unordered_map because ABSL containers are not exception safe: + std::unordered_map, Key::Hash> functions_; + // mu_ prevents concurrent insertions into functions_ if the gil or critical + // section lock is released during insertion. + absl::Mutex mu_; +}; + +PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {} + +std::shared_ptr PjitFunctionCache::DefaultCache() { + return std::make_shared(&lru_list_); +} + +/*static*/ std::shared_ptr PjitFunctionCache::Lookup( + xla::nb_class_ptr self, nb::handle function, + nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS { + // In no-GIL mode, a critical section on self plays the same role that + // the GIL plays in GIL mode. + nb::ft_object_guard lock(self); + { + // Because the gil (or the critical section lock) can be released during + // cache insertion, this forces the lock order to be mu_ then gil so we + // must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + self->mu_.Lock(); + } + absl::Cleanup unlock = [&self]() ABSL_UNLOCK_FUNCTION(self->mu_) { + self->mu_.Unlock(); + }; + Key key; + key.function = function; + key.global_cache_key = global_cache_key; + key.cached_hash = absl::HashOf(key); + auto insert = self->functions_.emplace(key, nullptr); + if (!insert.second) { + return insert.first->second->cache; + } + std::shared_ptr cache = std::make_shared(&self->lru_list_); + auto callback = + nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it == self->functions_.end()) { + return; + } + // Remove the value from the map before destroying it. Destroying + // the value may release `lock` since it may call arbitrary Python + // code. + std::unique_ptr value = std::move(it->second); + self->functions_.erase(it); + value.reset(); + }); + PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); + if (weakref) { + std::unique_ptr& entry = insert.first->second; + entry = std::make_unique(cache); + entry->weakref = nb::steal(weakref); + } else { + PyErr_Clear(); + // `function` is not weak-referenceable. Don't bother adding it to the + // shared cache in that case; the `jit` object will hold the only shared + // reference to the cache entry. + self->functions_.erase(insert.first); + } + return cache; +} + +class PjitFunction { + public: + PjitFunction(std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, + nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + xla::nb_class_ptr cache); + ~PjitFunction(); + + PjitFunction(const PjitFunction&) = delete; + PjitFunction& operator=(const PjitFunction&) = delete; + PjitFunction(PjitFunction&&) = default; + PjitFunction& operator=(PjitFunction&&) = default; + + // nb::object typed subclass for PjitFunction objects. 
+ class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PjitFunction", + PjitFunction::IsPjitFunction); + pyobject() = default; + PjitFunction* func() const { + return PjitFunction::AsPjitFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PjitFunction. + static bool IsPjitFunction(nb::handle handle); + // Converts `handle` to a PjitFunction*. Does not do any checking. + static PjitFunction* AsPjitFunctionUnchecked(nb::handle handle); + + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + void InitExecutables(); + + void ClearPythonReferences(); + + const std::string& function_name() const { return function_name_; } + const std::optional& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const xla::nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& shard_arg_fallback() const { return shard_arg_fallback_; } + + const std::vector& static_argnums() const { return static_argnums_; } + const std::vector& static_argnames() const { + return static_argnames_; + } + const nb::object& global_cache_key() const { return global_cache_key_; } + const xla::nb_class_ptr& cache() const { return cache_; } + + int cache_capacity() const { + nb::ft_object_guard lock(cache_); + return executables_->Size(); + } + + void ClearCache() { + nb::ft_object_guard lock(cache_); + executables_->Clear(); + } + + std::shared_ptr executables() { + nb::ft_object_guard lock(cache_); + return executables_; + } + + nb::object PythonSignature() { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat( + "Calling __signature__ on PjitFunction(%s) not supported.", + function_name_) + .c_str()); + } + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(*fun_); + } + + private: + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& call_signature); + + void PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + std::string function_name_; + std::optional fun_; + nb::callable cache_miss_; + std::vector static_argnums_; + std::vector static_argnames_; + nb::object global_cache_key_; + + xla::nb_class_ptr pytree_registry_; + nb::callable shard_arg_fallback_; + xla::nb_class_ptr cache_; + + // In no-GIL mode executables_ is protected by the object lock on cache_, + // because it shared an LRU list with cache_. 
+ std::shared_ptr executables_; +}; + +PjitFunction::PjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, xla::nb_class_ptr cache) + : function_name_(std::move(function_name)), + fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + global_cache_key_(std::move(global_cache_key)), + pytree_registry_(std::move(pytree_registry)), + shard_arg_fallback_(std::move(shard_arg_fallback)), + cache_(std::move(cache)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + static_argnames_.reserve(static_argnames.size()); + for (nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_.push_back(nb::steal(s)); + } +} + +void PjitFunction::InitExecutables() { + // Construction of the object hasn't completed yet, so we don't need to hold + // the cache lock to mutate executables_. + if (!fun_.has_value()) { + executables_ = cache_->DefaultCache(); + } else { + executables_ = cache_->Lookup(cache_, fun_.value(), global_cache_key_); + } +} + +PjitFunction::~PjitFunction() { + nb::ft_object_guard lock(cache_); + executables_ = nullptr; +} + +void CallShardArgFallback( + nb::handle arg, nb::handle sharding, nb::handle layout, + const nb::callable& fallback, + std::vector>& num_args_arrays, + std::vector& keep_alive_objects) { + tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); + auto py_array_or_bufs = fallback(arg, sharding, layout); + auto py_array = nb::cast(py_array_or_bufs); + num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + keep_alive_objects.push_back(std::move(py_array_or_bufs)); +} + +// Prepares the input PjRtBuffers from the python arguments. This is equivalent +// to shard_args() in pxla.py but for only a few supported cases. 
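+// Roughly, the fast path below is only taken when an argument is a jax.Array
+// that is committed (or uncommitted but on a single device), is not sharded
+// with a PmapSharding, matches any requested device-local layout, and spans
+// the same number of devices as the executable; everything else is delegated
+// to `shard_arg_fallback`, and non-Array inputs are device_put on the single
+// target device when possible.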
+absl::StatusOr>> +PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, + absl::Span flat_dynamic_args, + bool enable_x64, const std::vector& kept_args, + const std::vector& in_shardings, + const std::vector& in_device_local_layouts, + const nb::callable& shard_arg_fallback, + std::vector& keep_alive_objects) { + const auto& addressable_devices = + executable.ifrt_loaded_executable()->addressable_devices(); + const auto& num_global_devices = + executable.ifrt_loaded_executable()->num_devices(); + int num_args = flat_dynamic_args.size(); + + std::vector> num_args_arrays; + num_args_arrays.reserve(num_args); + + struct CopyGroup { + std::vector indices; + std::vector> arrays; + }; + absl::flat_hash_map, + CopyGroup> + copy_groups; + + xla::DevicePutOptions options; + options.squash_64bit_types = !enable_x64; + options.allow_zero_copy = true; + xla::ifrt::Device* data_device = nullptr; + if (executable.ifrt_loaded_executable()->num_devices() == 1) { + data_device = executable.ifrt_loaded_executable()->addressable_devices()[0]; + } + int dce_i = 0; + for (int i = 0; i < num_args; ++i) { + if (!kept_args[i]) { + continue; + } + int dce_index = dce_i; + ++dce_i; + + const nb::object& arg = flat_dynamic_args[i]; + const nb::object& in_device_local_layout = + in_device_local_layouts[dce_index]; + + auto transfer_guard_formatter = [] { return std::string(""); }; + + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + if (data_device != nullptr && in_device_local_layout.is_none()) { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + TF_ASSIGN_OR_RETURN( + auto on_device_fn, + DevicePut(arg, executable.ifrt_loaded_executable()->client(), + data_device, options, xla::ifrt::MemoryKind())); + TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device, [&]() { + // Must release the GIL before calling IFRT because backends may + // decide to block/sleep for device buffer allocation. + nb::gil_scoped_release gil_release; + return std::move(on_device_fn)(); + }()); + + num_args_arrays.push_back(std::move(on_device.ifrt_array)); + if (on_device.owning_pybuffer) { + keep_alive_objects.push_back(std::move(on_device.owning_pybuffer)); + } + continue; + } else { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + xla::PyArray py_array = nb::borrow(arg); + const auto& sharding = py_array.sharding(); + int sharding_num_devices = jax::Sharding::SafeNumDevices(sharding); + + // Currently only committed PyArray inputs or uncommitted PyArray on a + // single device inputs are allowed. This is checked previously in the entry + // point of PjitFunction::Call(). 
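+    // (For instance, an array placed explicitly with jax.device_put(x, device)
+    // is committed, while an array created without an explicit placement is
+    // typically uncommitted; the uncommitted case is accepted here only when
+    // it lives on a single device.)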
+ DCHECK(py_array.committed() || + (!py_array.committed() && sharding_num_devices == 1)); + + if (!in_device_local_layout.is_none()) { + TF_ASSIGN_OR_RETURN(auto arr_layout, py_array.ifrt_array()->layout()); + xla::Layout in_xc_layout = nb::cast( + in_device_local_layout.attr("_to_xla_layout")(py_array.dtype())); + if (in_xc_layout != arr_layout->xla_layout()) { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + if (sharding_num_devices != num_global_devices) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + xla::ifrt::Array* ifrt_array = py_array.ifrt_array(); + // PyArray inputs should have already been checked in + // `xla::PyArgSignatureOfValue()` called by + // `PjitFunction::ComputeCallSignature()`. + DCHECK(ifrt_array != nullptr) << "PyArray has been unexpectedly deleted."; + + const auto& ifrt_sharding = ifrt_array->sharding(); + if (sharding_num_devices == 1 && + ifrt_sharding.devices()->devices().front() != addressable_devices[0]) { + auto& copy_group = + copy_groups[std::make_pair(ifrt_sharding.devices()->devices().front(), + ifrt_sharding.memory_kind())]; + copy_group.indices.push_back(num_args_arrays.size()); + copy_group.arrays.push_back(tsl::FormRef(ifrt_array)); + num_args_arrays.push_back({}); + } else { + num_args_arrays.push_back(tsl::FormRef(ifrt_array)); + } + + keep_alive_objects.push_back(arg); + } + + if (!copy_groups.empty()) { + xla::ifrt::Client* const ifrt_client = + executable.ifrt_loaded_executable()->client(); + xla::ifrt::DeviceListRef ifrt_devices = + ifrt_client->MakeDeviceList({addressable_devices[0]}); + for (auto& [key, group] : copy_groups) { + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays(absl::MakeSpan(group.arrays), ifrt_devices, + /*memory_kind=*/std::nullopt, + xla::ifrt::ArrayCopySemantics::kReuseInput)); + for (int i = 0; i < copied_ifrt_arrays.size(); ++i) { + num_args_arrays[group.indices[i]] = std::move(copied_ifrt_arrays[i]); + } + } + } + + return num_args_arrays; +} + +absl::StatusOr PjitFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + tsl::profiler::TraceMe traceme( + [&] { return absl::StrCat("PjitFunction(", function_name_, ")"); }); + + // Make sure we trigger a garbage collection on JIT function calls. Otherwise + // code like + // f = jit(...) + // while True: + // f(x) + // may never free temporary buffers for copies of arguments. + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + if (GetDisableJit()) { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat("Disable jit is not supported in the AOT path since " + "the function is not available for (%s)", + function_name_) + .c_str()); + } + return nb::steal( + PyObject_Vectorcall(fun_.value().ptr(), args, nargs, kwnames)); + } + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. 
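+  // As used below, the returned tuple is expected to hold the Python result at
+  // index 0 and optional fastpath data (consumed by PopulateCacheEntry) at
+  // index 1; an optional third element requests removal of the cache entry.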
+ auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + + CallSignature call_signature; + std::vector keep_alive_objects; + absl::InlinedVector flat_dynamic_args; + auto status = ParseArguments( + positional_args, keyword_args, kwnames, static_argnums_, static_argnames_, + pytree_registry_.get(), call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + // Perform a few checks for the arguments. Currently we are only allowing + // committed PyArray inputs. For other cases, e.g. Tracers or ShapedArray, it + // will fallback to python. For jit, numpy arrays and scalars are also + // allowed, which we will check later. + for (const auto& arg : flat_dynamic_args) { + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + continue; + } + + xla::PyArray py_array = nb::borrow(arg); + + // Only allow committed PyArray in cpp pjit for now as the logic on handling + // sharding for uncommitted PyArray is complicated and still under + // development. + // + // TODO(chky): Consider support uncommitted PyArray in cpp when the python + // side stablizes. + if (!py_array.committed() && + jax::Sharding::SafeNumDevices(py_array.sharding()) > 1) { + VLOG(2) << "PyArray argument is not committed and number of global " + "devices is more than 1; fallback to python."; + return fallback_to_cache_miss(); + } + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + VLOG(2) << "ComputeCallSignature failed: " << status; + return fallback_to_cache_miss(); + } + + VLOG(2) << "CallSignature:\n" << call_signature.DebugString(); + bool inserted = false; + std::shared_ptr cache_entry; + { + nb::ft_object_guard lock(cache_); + cache_entry = executables_->GetOrCreateIfAbsent( + call_signature, [this, &inserted](const CallSignature& unused) { + inserted = true; + return std::make_shared(pytree_registry_.get()); + }); + } + + if (!cache_entry->compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + bool remove_cache = false; + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. 
+ out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(*cache_entry, out_tuple); + + if (out_tuple.size() > 2 && out_tuple[2].is_valid()) { + remove_cache = nb::cast(out_tuple[2]); + } + } catch (const std::exception& e) { + VLOG(2) << "cache miss fail: " << e.what(); + cache_entry->fall_back_to_python = true; + cache_entry->compilation_complete.Notify(); + throw; + } + cache_entry->compilation_complete.Notify(); + + if (remove_cache) { + nb::ft_object_guard lock(cache_); + executables_->Remove(call_signature); + } + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + if (cache_entry->thread_id == std::this_thread::get_id()) { + auto error_string = absl::StrCat("Recursively calling jit: ", + call_signature.DebugString()); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry->compilation_complete.WaitForNotification(); + } + } + + if (cache_entry->fall_back_to_python) { + VLOG(2) << "cpp pjit fallback to python."; + return fallback_to_cache_miss(); + } + + // A vector of [num_inputs]. + auto num_args_arrays = PrepareIfrtInputs( + *cache_entry->executable, flat_dynamic_args, + call_signature.jax_enable_x64, cache_entry->kept_var_bitvec, + cache_entry->in_shardings, cache_entry->in_device_local_layouts, + shard_arg_fallback_, keep_alive_objects); + + if (!num_args_arrays.ok()) { + VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); + return fallback_to_cache_miss(); + } + + xla::ifrt::ExecuteOptions execute_options = + cache_entry->executable->options(); + execute_options.launch_id = cache_entry->executable->GetNextLaunchId(); + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + + // A vector of [num_outputs]. + std::vector> output_arrays; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(auto result, + cache_entry->executable->ifrt_executable()->Execute( + absl::MakeSpan(*num_args_arrays), execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + auto traceback = xla::Traceback::Get(); + + // Convert the ifrt::Array objects to PyArray. + int num_outputs = output_arrays.size(); + absl::InlinedVector outputs; + outputs.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + // Creating the PyArray result. In addition to the IFRT arrays, the metadata + // like `aval` and `sharding` are retrieved from the cache for this + // function, which are produced by the python path in `cache_miss`. + xla::PyArray py_array( + cache_entry->out_avals[i], cache_entry->out_weak_types[i], + cache_entry->out_dtypes[i], cache_entry->out_shapes[i], + cache_entry->out_shardings[i], cache_entry->executable->client(), + traceback, std::move(output_arrays[i]), + /*committed=*/cache_entry->out_committed.at(i), /*skip_checks=*/true); + + outputs.push_back(std::move(py_array)); + } + + nb::object out = nb::steal( + cache_entry->out_pytree_def.Unflatten(outputs).release().ptr()); + + // If there is a post-hook function, call it with the inputs and the outputs. 
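+  // The hook is invoked as post_hook(jitted_fn, args_tuple, kwargs, output),
+  // i.e. it observes the original call plus its result.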
+ std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + (*post_hook)(nb::handle(callable.ptr()), args_tuple, kwargs, + nb::handle(out.ptr())); + } + + return out; +} + +absl::Status PjitFunction::ComputeCallSignature( + absl::Span flat_dynamic_args, CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState& global_state = jax::GlobalJitState(); + JitState& tls = jax::ThreadLocalJitState(); + bool jax_enable_x64 = GetEnableX64(); + + signature.default_device = GetDefaultDevice(); + signature.jax_enable_x64 = jax_enable_x64; + + auto& dynamic_arg_signatures = signature.dynamic_arg_signatures; + dynamic_arg_signatures.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_shardings = signature.dynamic_arg_shardings; + dynamic_arg_shardings.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_layouts = signature.dynamic_arg_layouts; + dynamic_arg_layouts.reserve(flat_dynamic_args.size()); + + for (nb::handle arg : flat_dynamic_args) { + TF_ASSIGN_OR_RETURN(auto arg_signature, + xla::PyArgSignatureOfValue(arg, jax_enable_x64)); + signature.dynamic_arg_signatures.push_back(std::move(arg_signature)); + + // It should be already checked previously in the entry point of + // PjitFunction::Call(). + if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + signature.dynamic_arg_shardings.push_back(py_array.sharding()); + auto layout = py_array.layout(); + if (absl::IsUnimplemented(layout.status())) { + signature.dynamic_arg_layouts.push_back(nullptr); + } else { + signature.dynamic_arg_layouts.push_back(*std::move(layout)); + } + signature.committed_args.push_back(py_array.committed()); + } else { + signature.dynamic_arg_shardings.push_back(nb::none()); + signature.dynamic_arg_layouts.push_back(nullptr); + signature.committed_args.push_back(false); + } + } + + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + + return absl::OkStatus(); +} + +void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + DCHECK_GE(out_and_fastpath_data.size(), 2); + + if (out_and_fastpath_data[1].is_none()) { + VLOG(2) << "fastpath_data is none"; + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple fastpath_data = nb::cast(out_and_fastpath_data[1]); + + cache_entry.executable = nb::cast>( + fastpath_data.attr("xla_executable")); + + nb::sequence in_shardings = fastpath_data.attr("in_shardings"); + cache_entry.in_shardings.reserve(nb::len(in_shardings)); + for (nb::handle sharding : in_shardings) { + cache_entry.in_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_shardings = fastpath_data.attr("out_shardings"); + cache_entry.out_shardings.reserve(nb::len(out_shardings)); + for (nb::handle sharding : out_shardings) { + cache_entry.out_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_committed = fastpath_data.attr("out_committed"); + cache_entry.out_committed.reserve(nb::len(out_committed)); + for 
(nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } + + nb::sequence out_avals = fastpath_data.attr("out_avals"); + cache_entry.out_avals.reserve(nb::len(out_avals)); + cache_entry.out_dtypes.reserve(nb::len(out_avals)); + cache_entry.out_shapes.reserve(nb::len(out_avals)); + cache_entry.out_weak_types.reserve(nb::len(out_avals)); + for (nb::handle aval : out_avals) { + cache_entry.out_avals.push_back(nb::borrow(aval)); + cache_entry.out_dtypes.push_back(aval.attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(aval.attr("shape"))); + cache_entry.out_weak_types.push_back( + nb::cast(aval.attr("weak_type"))); + } + + cache_entry.out_pytree_def = nb::cast( + nb::handle(fastpath_data.attr("out_pytree_def").ptr())); + + nb::sequence kept_var_bitvec = fastpath_data.attr("kept_var_bitvec"); + cache_entry.kept_var_bitvec.reserve(nb::len(kept_var_bitvec)); + for (nb::handle k : kept_var_bitvec) { + cache_entry.kept_var_bitvec.push_back(nb::cast(k)); + } + + nb::sequence in_device_local_layouts = + fastpath_data.attr("in_device_local_layouts"); + cache_entry.in_device_local_layouts.reserve(nb::len(in_device_local_layouts)); + for (nb::handle dll : in_device_local_layouts) { + cache_entry.in_device_local_layouts.push_back(nb::borrow(dll)); + } +} + +// Helper function used by the tp_clear GC method. +void PjitFunction::ClearPythonReferences() { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to clear + nb::callable cache_miss; + std::optional fun; + nb::callable shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(cache_miss_, cache_miss); + std::swap(fun_, fun); + std::swap(shard_arg_fallback_, shard_arg_fallback); +} + +struct PjitFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. +#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PjitFunction fun; + + // Doubly-linked list of PjitFunctionObjects, protected by + // PjitFunctionStore::mu_ or the GIL in GIL mode. + PjitFunctionObject* next; + PjitFunctionObject* prev; +}; + +// Contains a list of all PjitFunctionObjects. +// Thread-safe. +class PjitFunctionStore { + public: + void Insert(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + o->next = compiled_functions_; + o->prev = nullptr; + if (o->next) { + o->next->prev = o; + } + compiled_functions_ = o; + } + + void Remove(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + if (o->next) { + o->next->prev = o->prev; + } + if (o->prev) { + o->prev->next = o->next; + } else { + compiled_functions_ = o->next; + } + } + + void ClearCaches() { + std::vector< + std::pair>> + caches; + { + nb::ft_lock_guard lock(mu_); + for (PjitFunctionObject* fn = compiled_functions_; fn != nullptr; + fn = fn->next) { + caches.emplace_back(fn->fun.cache(), fn->fun.executables()); + } + } + for (auto& [cache, executables] : caches) { + nb::ft_object_guard lock(cache); + executables->Clear(); + } + }; + + private: + // Protected by the GIL in GIL mode, and by mu_ in freethreading mode. 
+ nb::ft_mutex mu_; + PjitFunctionObject* compiled_functions_; +}; + +PjitFunctionStore pjit_function_store; + +PyObject* PjitFunction_Type = nullptr; + +bool PjitFunction::IsPjitFunction(nb::handle handle) { + return handle.type().ptr() == PjitFunction_Type; +} + +PjitFunction* PjitFunction::AsPjitFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +PjitFunction* AsPjitFunction(nb::handle handle) { + if (!PjitFunction::IsPjitFunction(handle)) { + throw xla::XlaRuntimeError(xla::InvalidArgument("Expected a PjitFunction")); + } + return PjitFunction::AsPjitFunctionUnchecked(handle); +} + +extern "C" { + +PyObject* PjitFunction_tp_vectorcall(PyObject* callable, PyObject* const* args, + size_t nargs, PyObject* kwnames) { + PjitFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("PjitFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::runtime_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* PjitFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + PjitFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = PjitFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void PjitFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + PjitFunctionObject* o = reinterpret_cast(self); + pjit_function_store.Remove(o); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PjitFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int PjitFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to visit + PjitFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.cache_miss().ptr()); + Py_VISIT(o->fun.shard_arg_fallback().ptr()); + if (o->fun.fun()) { + Py_VISIT(o->fun.fun()->ptr()); + } + return 0; +} + +int PjitFunction_tp_clear(PyObject* self) { + PjitFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 
+ o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so JIT-compiled functions can be +// used as bound methods. See: +// https://docs.python.org/3/howto/descriptor.html#functions-and-methods +PyObject* PjitFunction_tp_descr_get(PyObject* self, PyObject* obj, + PyObject* type) { + if (obj == nullptr || obj == Py_None) { + Py_INCREF(self); + return self; + } + return PyMethod_New(self, obj); +} + +static PyGetSetDef PjitFunction_tp_getset[] = { + // Having a __dict__ seems necessary to allow !functool.wraps to override + // __doc__. + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +PyObject* PjitFunction_tp_repr(PyObject* self) { + try { + const std::string& repr = absl::StrFormat( + "", nb::cast(nb::repr( + nb::getattr(self, "__wrapped__")))); + return PyUnicode_FromString(repr.c_str()); + } catch (...) { + // Ignore all errors when accessing a repr. + return PyUnicode_FromString(""); + } +} + +} // extern "C" + +void InitializePjitFunction( + PjitFunctionObject* fn_obj, std::string function_name, + std::optional fun, nb::callable cache_miss, + std::vector static_argnums, std::vector static_argnames, + nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + xla::nb_class_ptr cache) { + fn_obj->next = fn_obj->prev = nullptr; + if (nb::isinstance(global_cache_key)) { + global_cache_key = nb::tuple(global_cache_key); + } + new (&fn_obj->fun) PjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + // Handled separately because it is not exception safe to call this + // in the constructor because it leaves the object improperly constructed. + fn_obj->fun.InitExecutables(); + + // Only add the executable to the store after executables_ has been + // initialized. We want only fully constructed executables in the store. + pjit_function_store.Insert(fn_obj); +} + +nb::object MakePjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + std::optional> cache) { + nb::object obj = nb::steal(PjitFunction_tp_new( + reinterpret_cast(PjitFunction_Type), nullptr, nullptr)); + PjitFunctionObject* fn_obj = reinterpret_cast(obj.ptr()); + if (!cache) { + cache = xla::make_nb_class( + PjitFunctionCache::kDefaultCapacity); + } + InitializePjitFunction( + fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(*cache)); + return obj; +} + +// Version numbers for the pickled representations of +// PjitFunction. Increment these if changing them. 
+const int kPjitFunctionPickleVersion = 1; + +PyMemberDef PjitFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PjitFunction_slots[] = { + {Py_tp_new, reinterpret_cast(PjitFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PjitFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(PjitFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PjitFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PjitFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(PjitFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_repr, reinterpret_cast(PjitFunction_tp_repr)}, + {Py_tp_members, reinterpret_cast(PjitFunction_members)}, + {0, nullptr}, +}; + +} // namespace + +void BuildPjitSubmodule(nb::module_& m) { + nb::class_ cache(m, "PjitFunctionCache"); + cache.def(nb::init(), + nb::arg("capacity") = PjitFunctionCache::kDefaultCapacity); + cache.def("size", &PjitFunctionCache::Size, nb::lock_self()); + cache.def("capacity", &PjitFunctionCache::Capacity, nb::lock_self()); + cache.def("clear", &PjitFunctionCache::Clear, nb::lock_self()); + cache.def_static("clear_all", []() { pjit_function_store.ClearCaches(); }); + cache.def( + "__getstate__", + // Pickles as an empty cache; the client can repopulate as needed. + [](const PjitFunctionCache& cache) { + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["capacity"] = cache.Capacity(); + return pickle; + }, + nb::lock_self()); + cache.def("__setstate__", + [](PjitFunctionCache* cache, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d", + version, kPjitFunctionPickleVersion)); + } + int capacity = nb::cast(pickle["capacity"]); + new (cache) PjitFunctionCache(capacity); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PjitFunction"); + PyType_Spec PjitFunction_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PjitFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX < 0x030C0000 + /*.slots=*/PjitFunction_slots, + }; + PjitFunction_Type = PyType_FromSpec(&PjitFunction_spec); + if (!PjitFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(PjitFunction_Type); + + // Add PjitFunction to the xla_extension module so it can be pickled. 
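
The registration just below, together with the __getstate__/__setstate__ hooks added afterwards, is what makes jitted functions picklable. A rough round-trip sketch (illustrative; it assumes a pickler such as cloudpickle that can serialize the captured Python callables, and that both ends run the same JAX/jaxlib version, which the version field enforces):

    import cloudpickle
    import jax

    @jax.jit
    def f(x):
        return x + 1.0

    f(1.0)                        # populate the executable cache
    blob = cloudpickle.dumps(f)   # uses the __getstate__ hook defined below
    g = cloudpickle.loads(blob)   # __setstate__ rebuilds the PjitFunction
    print(g(1.0))                 # the cache pickles as empty, so this recompiles
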
+ m.attr("PjitFunction") = cfun; + cfun.attr("__getstate__") = nb::cpp_function( + [](const PjitFunction::object& self) { + PjitFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["function_name"] = fn->function_name(); + if (fn->fun().has_value()) { + pickle["fun"] = *fn->fun(); + } + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["static_argnames"] = nb::cast(fn->static_argnames()); + pickle["global_cache_key"] = fn->global_cache_key(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + pickle["shard_arg_fallback"] = fn->shard_arg_fallback(); + pickle["cache"] = fn->cache(); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](nb::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d. " + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPjitFunctionPickleVersion)); + } + std::string function_name = + nb::cast(pickle["function_name"]); + std::optional fun; + if (pickle.contains("fun")) { + fun = nb::cast(pickle["fun"]); + } + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + std::vector static_argnames = + nb::cast>(pickle["static_argnames"]); + nb::object global_cache_key = pickle["global_cache_key"]; + xla::nb_class_ptr pytree_registry = + nb::cast>( + nb::handle(pickle["pytree_registry"].ptr())); + nb::callable shard_arg_fallback = + nb::cast(pickle["shard_arg_fallback"]); + xla::nb_class_ptr cache = + nb::cast>(pickle["cache"]); + InitializePjitFunction( + reinterpret_cast(self.ptr()), + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::is_method()); + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->PythonSignature(); + }); + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->cache_miss(); + }); + // All private members are only for testing/debugging purposes + cfun.attr("_cache_size") = nb::cpp_function( + [](nb::handle self) -> int { + return AsPjitFunction(self)->cache_capacity(); + }, + nb::is_method()); + cfun.attr("_clear_cache") = nb::cpp_function( + [](nb::handle self) { AsPjitFunction(self)->ClearCache(); }, + nb::is_method()); + + m.def( + "pjit", + [](std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + nb::object pytree_registry, nb::callable shard_arg_fallback, + std::optional> cache) { + xla::nb_class_ptr registry = + nb::cast>( + nb::handle(pytree_registry.ptr())); + return MakePjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + 
nb::arg("global_cache_key"), nb::arg("pytree_registry"), + nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none()); +} + +} // namespace jax diff --git a/jaxlib/xla/pjit.h b/jaxlib/xla/pjit.h new file mode 100644 index 000000000000..8d47347ab9a2 --- /dev/null +++ b/jaxlib/xla/pjit.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PJIT_H_ +#define JAXLIB_XLA_PJIT_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildPjitSubmodule(nanobind::module_& m); +} + +#endif // JAXLIB_XLA_PJIT_H_ diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc new file mode 100644 index 000000000000..295ac8bfccfb --- /dev/null +++ b/jaxlib/xla/pmap_lib.cc @@ -0,0 +1,1180 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/pmap_lib.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/config.h" +#include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharded_device_array.h" +#include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/traceback.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +namespace { + +// Specifies how to shard the inputs. Even though everything could be computed +// from `sharding_specs` and the argument shape, we cache derived computations +// for performance. +struct InputSpec { + InputSpec(nb::object indices, nb::object array_sharding) + : indices(std::move(indices)), + array_sharding(std::move(array_sharding)) {} + nb::object indices; + nb::object array_sharding; +}; + +// An object containing the arguments to create Array from the +// output buffers. +struct ResultSpec { + public: + explicit ResultSpec(nb::object aval) + : out_aval(std::move(aval)), + weak_type(nb::cast(out_aval.attr("weak_type"))) {} + nb::object out_aval; + bool weak_type; +}; + +// The result of `ShardArg`. +struct ShardArgResult { + // Points to the on-device array. + // ifrt_array->sharding().num_shards() == `num_devices`. + tsl::RCReference ifrt_array; + // The Python argument will be always be copied to `owning_sda`. + nb::object owning_sda; +}; + +// Shards a single argument over devices. +// +// We currently only support fully in C++, C++ Array. For all +// other usages, we call a Python function returning C++ Array +// that will be casted back to the C++ objects. +// +// This function is not usable for JAX extensions that do not comply with the +// PjRt interfaces. 
+// +// Arguments: +// `arg`: The object to shard across `devices`. If a `Array`, +// a fast-path will be executed if it's already correctly sharded. +// +// Returns a failure absl::Status when an unrecoverable error occurred, so we +// don't need to fallback to Python. +// +// Both `devices` and `sharding_spec` has the same length. +absl::StatusOr ShardArg( + nb::handle arg, absl::Span devices, + const InputSpec& input_spec, nb::handle py_devices, + const nb::callable& python_fallback) { + if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + if (py_array.sharding().type().ptr() == + input_spec.array_sharding.type().ptr()) { + auto* pmap_sharding = nb::cast(py_array.sharding()); + auto* cached_pmap_sharding = + nb::cast(input_spec.array_sharding); + + if (pmap_sharding->sharding_spec() == + cached_pmap_sharding->sharding_spec()) { + ShardArgResult result; + result.owning_sda = nb::borrow(arg); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + if (result.ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); + } + if (result.ifrt_array->sharding().devices()->devices() != devices) { + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(devices.size()); + ifrt_devices.insert(ifrt_devices.end(), devices.begin(), + devices.end()); + // pmap does not support memory_kind for now. + auto* ifrt_client = result.ifrt_array->client(); + TF_ASSIGN_OR_RETURN(auto copied_ifrt_arrays, + ifrt_client->CopyArrays( + absl::MakeSpan(&result.ifrt_array, 1), + ifrt_client->MakeDeviceList(ifrt_devices), + xla::ifrt::MemoryKind(), + xla::ifrt::ArrayCopySemantics::kReuseInput)); + result.ifrt_array = std::move(copied_ifrt_arrays.front()); + } + return result; + } + } + } + + auto ndarray = xla::nb_numpy_ndarray::ensure(arg); + if (ndarray && PyArray_CheckExact(arg.ptr()) && + xla::DtypeToPrimitiveType(ndarray.dtype()).status().ok()) { + tsl::profiler::TraceMe traceme("ndarray pmap ShardArg"); + nb::list indices = nb::list(input_spec.indices); + nb::list py_devices_list = nb::cast(py_devices); + auto n_devices = py_devices_list.size(); + if (indices.size() != n_devices) { + return xla::InvalidArgument("indices vs devices mismatch: %d vs %d", + indices.size(), n_devices); + } + + std::vector> per_device_arrays; + per_device_arrays.reserve(n_devices); + absl::InlinedVector devices; + devices.reserve(n_devices); + // TODO(hyeontaek): The created array will never be disassembled. We should + // omit collecting shapes and make the OpaqueSharding non-disassemblable? 
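
ShardArg above takes the fast path when the input is already a jax.Array carrying the expected PmapSharding, reusing (or copying) its per-device buffers without returning to Python; plain NumPy inputs go through the per-device transfer path that follows. An illustrative sketch of when each path triggers (not part of this diff):

    import numpy as np
    import jax

    ndev = jax.local_device_count()
    f = jax.pmap(lambda x: x * 2.0)

    x = np.arange(ndev * 3.0).reshape(ndev, 3)
    y = f(x)   # NumPy input: split and transferred shard-by-shard below
    z = f(y)   # y already has the expected PmapSharding: fast path reuses its buffers
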
+ std::vector shapes; + shapes.reserve(n_devices); + + nb::list owning_pylist; + ShardArgResult result; + result.owning_sda = owning_pylist; + const bool jax_enable_x64 = GetEnableX64(); + + std::vector device_put_fns; + device_put_fns.reserve(n_devices); + xla::DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = true; + for (size_t i = 0; i < n_devices; ++i) { + auto to_device = nb::cast(py_devices_list[i]); + if (to_device->client().get() == nullptr) { + return xla::InvalidArgument("Cannot copy to unattached devices."); + } + + TF_ASSIGN_OR_RETURN( + device_put_fns.emplace_back(), + DevicePut(arg[indices[i]], to_device->client()->ifrt_client(), + to_device->device(), options, xla::ifrt::MemoryKind())); + } + std::vector device_puts; + device_puts.reserve(n_devices); + { + nb::gil_scoped_release gil_release; + for (auto& device_put_fn : device_put_fns) { + TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)()); + device_puts.push_back(std::move(device_put)); + } + } + for (auto& device_put : device_puts) { + per_device_arrays.push_back(std::move(device_put.ifrt_array)); + devices.push_back( + per_device_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(per_device_arrays.back()->shape()); + if (device_put.owning_pybuffer) { + owning_pylist.append(device_put.owning_pybuffer); + } + } + + if (per_device_arrays.empty()) { + return xla::InvalidArgument("Per-device arrays must not be empty."); + } + // TODO(hyeontaek): The logical shape here is inaccurate. We + // may want to avoid creating a new Array or specialize Array + // to disallow access to the logical shape. + xla::ifrt::Shape shape = per_device_arrays.front()->shape(); + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + xla::GetIfrtConcreteSharding(input_spec.array_sharding, shape, shapes)); + TF_ASSIGN_OR_RETURN( + result.ifrt_array, + per_device_arrays.front() + ->client() + ->AssembleArrayFromSingleDeviceArrays( + std::move(shape), std::move(ifrt_sharding), + absl::MakeSpan(per_device_arrays), + xla::ifrt::ArrayCopySemantics::kReuseInput, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + return result; + } + tsl::profiler::TraceMe traceme("pmap_lib_shard_arg_python_fallback"); + auto py_array_or_bufs = python_fallback(arg, input_spec.array_sharding); + + auto py_array = nb::cast(py_array_or_bufs); + ShardArgResult result; + result.owning_sda = nb::borrow(py_array_or_bufs); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + return result; +} + +struct PmapCacheEntry { + explicit PmapCacheEntry(xla::PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + // The value `backend.local_devices()`. + nb::object py_devices; // To pass back to Python. + std::vector devices; + std::vector input_specs; + xla::PyTreeDef out_pytree_def; + // Objects necessary to build the out Array objects. + std::vector out_result_specs; + + std::vector out_array_shardings; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_committed; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. 
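
PmapCacheEntry above bundles a compiled executable with the input/output metadata needed to dispatch a call; the surrounding PmapFunction keeps one such entry per call signature, so distinct argument shapes or dtypes compile separately. A short sketch of the resulting caching behavior (illustrative only):

    import numpy as np
    import jax

    n = jax.local_device_count()
    f = jax.pmap(lambda x: x + 1.0)

    f(np.zeros((n, 2)))   # miss: trace + compile, a new PmapCacheEntry is created
    f(np.zeros((n, 2)))   # hit: same CallSignature, the cached executable is reused
    f(np.zeros((n, 3)))   # different shape -> different signature -> another entry
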
+ absl::Notification compilation_complete; + + bool fall_back_to_python = false; +}; + +} // namespace + +// A `PmapFunction` is associated to a `jax.pmap(f)` and takes care of the +// bookkeeping of the different signatures used and the dispatch of calls to +// the correct underlying `PyLoadedExecutable`. This class is thread-safe. +class PmapFunction { + public: + PmapFunction(nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, + nb::callable python_shard_arg_fallback, + xla::nb_class_ptr pytree_registry) + : fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + pytree_registry_(std::move(pytree_registry)), + python_shard_arg_fallback_(std::move(python_shard_arg_fallback)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + + function_name_ = + nb::cast(nb::str(nb::getattr(fun_, "__name__", fun_))); + } + PmapFunction(const PmapFunction&) = delete; + PmapFunction& operator=(const PmapFunction& other) = delete; + PmapFunction(PmapFunction&&) = default; + PmapFunction& operator=(PmapFunction&&) = default; + + // This function will: + // (a) flatten the inputs using pytree + // (b) get buffer objects from the arguments + // (c) call the executable + // (d) construct `Array` objects from the outputs + // (e) reconstruct the `PyTree`. + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + nb::object PythonSignature() { + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(fun_); + } + + int cache_size() { + nb::ft_lock_guard lock(mu_); + return executables_.size(); + } + void cache_clear() { + nb::ft_lock_guard lock(mu_); + return executables_.clear(); + } + const nb::callable& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const std::string& function_name() const { return function_name_; } + const xla::nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& python_shard_arg_fallback() const { + return python_shard_arg_fallback_; + } + const std::vector& static_argnums() const { return static_argnums_; } + + // nb::object typed subclass for PmapFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PmapFunction", + PmapFunction::IsPmapFunction); + pyobject() = default; + PmapFunction* func() const { + return PmapFunction::AsPmapFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PmapFunction. + static bool IsPmapFunction(nb::handle handle); + // Converts `handle` to a PmapFunction*. Does not do any checking. + static PmapFunction* AsPmapFunctionUnchecked(nb::handle handle); + + // Helper function used by the tp_clear GC method. + void ClearPythonReferences() { + nb::callable fun, cache_miss, python_shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(fun_, fun); + std::swap(cache_miss_, cache_miss); + std::swap(python_shard_arg_fallback_, python_shard_arg_fallback); + } + + // Updates the signature of arguments for a pmapped function. + // + // It deals with the arguments signatures and also of the global and + // thread-local jit context. 
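
As the comment above notes, the computed signature folds in the global and thread-local jit configuration in addition to the per-argument signatures, so flipping a setting such as jax_enable_x64 selects a different cache entry. A hedged sketch (enable_x64 is an experimental context manager and may change):

    import numpy as np
    import jax
    from jax.experimental import enable_x64

    f = jax.pmap(lambda x: x * 2)
    x = np.zeros((jax.local_device_count(), 2))

    f(x)          # signature recorded with jax_enable_x64=False
    with enable_x64():
        f(x)      # same shapes, different x64 flag -> a separate cache entry
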
+ absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState& global_state = jax::GlobalJitState(); + JitState& tls = jax::ThreadLocalJitState(); + const bool jax_enable_x64 = GetEnableX64(); + signature.jax_enable_x64 = jax_enable_x64; + for (nb::handle arg : flat_dynamic_args) { + auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64); + if (!signature_or_error.ok()) { + VLOG(2) << "PyArgSignatureOfValue failed: " + << signature_or_error.status(); + return signature_or_error.status(); + } + signature.dynamic_arg_signatures.push_back( + std::move(signature_or_error).value()); + } + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + return absl::Status(); + } + + // Returns, for debugging purposes (e.g. finding why some call misses the + // cache and recompiles), the list of the string representations of the keys. + // + // The format can change at any time. + std::string DebugCacheKeys() { + nb::ft_lock_guard lock(mu_); + std::vector key_strings = { + absl::StrCat("The cache contains ", executables_.size(), " elements:")}; + // We will be able to use auto& [key, _] when TF uses C++ 17. + for (auto& pair : executables_) { + key_strings.push_back(pair.first.DebugString()); + } + return absl::StrJoin(key_strings, "\n\n"); + } + + private: + // Mutates `cache_entry` in place. + void PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + bool always_fallback_to_python_ = false; + + nb::callable fun_; // The Python function to pmap. + std::string function_name_; + // See JAX _cpp_pmap in api.py for documentation. + nb::callable cache_miss_; + + // We need to know the static arguments to remove them from the arguments + // passed to the underlying PyLoadedExecutable. In sorted order. + std::vector static_argnums_; + xla::nb_class_ptr pytree_registry_; + // We need a `shared_ptr` here to ensure value pointer stability, and to + // ensure that the cache entry remains alive in the presence of concurrent + // removals. + absl::flat_hash_map> + executables_; + + // The fallback function to use with `ShardArgs`. + // TODO(jblespiau): Add support for more types from C++. + nb::callable python_shard_arg_fallback_; + + // Protect methods in FT: + nb::ft_mutex mu_; +}; + +void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + CHECK_EQ(out_and_fastpath_data.size(), 2); + if (out_and_fastpath_data[1].is_none()) { + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple pmap_data = nb::cast(out_and_fastpath_data[1]); + if (nb::cast(pmap_data.attr("version")) != 1) { + throw xla::XlaRuntimeError(absl::StrCat( + "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 " + "expected, but got ", + nb::cast(pmap_data.attr("version")), + "Upgrade jaxlib and jax. Provided data was:", + nb::cast(nb::str(nb::repr(pmap_data))))); + } + // See api.nb::_PmapFastpathData in the JAX code base for the expected + // namedtuple. 
+ std::shared_ptr executable; + try { + executable = nb::cast>( + pmap_data.attr("xla_executable")); + } catch (const nb::cast_error& e) { + // Backends that don't implement the C++ PjRt APIs + cache_entry.fall_back_to_python = true; + always_fallback_to_python_ = true; + return; + } + cache_entry.executable = std::move(executable); + const std::vector>& devices = + cache_entry.executable->AddressableDevices(); + cache_entry.devices.reserve(devices.size()); + for (auto& device : devices) { + cache_entry.devices.push_back(device->device()); + } + + // Inputs shard args details. + nb::list input_indices = pmap_data.attr("input_indices"); + + cache_entry.py_devices = pmap_data.attr("input_devices"); + auto input_devices = nb::cast>>( + pmap_data.attr("input_devices")); + + nb::list input_array_shardings = pmap_data.attr("input_array_shardings"); + + cache_entry.input_specs.reserve(input_array_shardings.size()); + + for (int i = 0; i < input_array_shardings.size(); ++i) { + cache_entry.input_specs.emplace_back(input_indices[i], + input_array_shardings[i]); + } + + // Outputs specs. + auto out_tree = nb::cast(pmap_data.attr("out_pytree_def")); + cache_entry.out_pytree_def = std::move(out_tree); + nb::list out_avals = pmap_data.attr("out_avals"); + + cache_entry.out_result_specs.reserve(out_avals.size()); + cache_entry.out_dtypes.reserve(out_avals.size()); + cache_entry.out_shapes.reserve(out_avals.size()); + + for (int i = 0; i < out_avals.size(); ++i) { + cache_entry.out_dtypes.push_back(out_avals[i].attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(out_avals[i].attr("shape"))); + cache_entry.out_result_specs.emplace_back(out_avals[i]); + } + + nb::list out_array_shardings = pmap_data.attr("out_array_shardings"); + + DCHECK(out_array_shardings.size() == 0 || + out_avals.size() == out_array_shardings.size()); + + cache_entry.out_array_shardings.reserve(out_array_shardings.size()); + for (nb::handle out_array_sharding : out_array_shardings) { + cache_entry.out_array_shardings.push_back( + nb::borrow(out_array_sharding)); + } + + nb::list out_committed = pmap_data.attr("out_committed"); + + DCHECK(out_committed.size() == 0 || out_avals.size() == out_committed.size()); + + cache_entry.out_committed.reserve(out_committed.size()); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } +} + +absl::StatusOr PmapFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. + auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + if (always_fallback_to_python_) { + return fallback_to_cache_miss(); + } + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? 
PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + CallSignature call_signature; + absl::InlinedVector flat_dynamic_args; + std::vector keep_alive_objects; + absl::Status status = + ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, + /*static_argnames=*/{}, pytree_registry_.get(), + call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + return fallback_to_cache_miss(); + } + + // Retrieve/Maybe add the executable to the cache. + bool inserted = false; + std::shared_ptr cache_entry_ptr; + { + nb::ft_lock_guard lock(mu_); + std::shared_ptr& entry_ref = executables_[call_signature]; + if (!entry_ref) { + inserted = true; + entry_ref = std::make_shared(pytree_registry_.get()); + } + cache_entry_ptr = entry_ref; + } + PmapCacheEntry& cache_entry = *cache_entry_ptr; + + if (!cache_entry.compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. + out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(cache_entry, out_tuple); + } catch (const std::exception& e) { + cache_entry.fall_back_to_python = true; + cache_entry.compilation_complete.Notify(); + throw; + } + cache_entry.compilation_complete.Notify(); + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry.compilation_complete.WaitForNotification(); + } + } + if (cache_entry.fall_back_to_python) { + return fallback_to_cache_miss(); + } + + // 1. Parse arguments. + std::vector& input_devices = cache_entry.devices; + std::vector& input_specs = cache_entry.input_specs; + const int num_args = flat_dynamic_args.size(); + + // We need [num_args] for the `Execute` call below. + std::vector> num_args_arrays(num_args); + for (int i = 0; i < num_args; ++i) { + TF_ASSIGN_OR_RETURN( + ShardArgResult sharded_arg, + ShardArg(flat_dynamic_args[i], input_devices, input_specs[i], + cache_entry.py_devices, python_shard_arg_fallback_)); + + num_args_arrays[i] = std::move(sharded_arg.ifrt_array); + if (sharded_arg.owning_sda) { + keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); + } + } + + xla::ifrt::ExecuteOptions execute_options = cache_entry.executable->options(); + execute_options.launch_id = cache_entry.executable->GetNextLaunchId(); + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + + // A vector of [num_outputs]. 
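
The inserted/notification logic above guarantees that when several threads race on the same uncompiled signature, only the first traces and compiles while the others release the GIL and wait on compilation_complete. Illustrative sketch (not part of this diff):

    import threading
    import numpy as np
    import jax

    f = jax.pmap(lambda x: x + 1.0)
    x = np.zeros((jax.local_device_count(), 4))

    # Only the first caller compiles; the rest block on the notification and
    # then reuse the same executable.
    threads = [threading.Thread(target=f, args=(x,)) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
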
+ std::vector> output_arrays; + { + nb::gil_scoped_release gil_release; + auto ifrt_executable = cache_entry.executable->ifrt_executable(); + TF_ASSIGN_OR_RETURN( + auto result, ifrt_executable->Execute(absl::MakeSpan(num_args_arrays), + execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + // TODO(jblespiau): We don't need to create the PyBuffer objects. + // Having a C++ `Array`, keeping internally the PjRtBuffer + // objects is sufficient, and we can lazily create the `PyBuffer` only if + // we access them from Python. + auto traceback = xla::Traceback::Get(); + // TODO(jblespiau): Change the `client` function to return a reference. + xla::nb_class_ptr client = cache_entry.executable->client(); + + // Convert the PjRtBuffer objects to PyBuffer, and invert the order from + // [num_devices, num_args] to [num_args, num_devices]. + const int num_outputs = output_arrays.size(); + std::vector flat_sharded_device_arrays; + flat_sharded_device_arrays.reserve(num_outputs); + + const auto& output_specs = cache_entry.out_result_specs; + + TF_RET_CHECK(cache_entry.out_array_shardings.size() == num_outputs); + for (int i = 0; i < num_outputs; ++i) { + const ResultSpec& result_spec = output_specs[i]; + xla::PyArray py_array( + result_spec.out_aval, result_spec.weak_type, cache_entry.out_dtypes[i], + cache_entry.out_shapes[i], cache_entry.out_array_shardings[i], client, + traceback, std::move(output_arrays[i]), cache_entry.out_committed[i], + /*skip_checks=*/true); + + flat_sharded_device_arrays.push_back(std::move(py_array)); + } + + nb::object out = + cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + + (*post_hook)(callable, args_tuple, kwargs, out); + } + + return out; +} + +struct JaxPmapFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. 
+#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PmapFunction fun; +}; + +PyObject* JaxPmapFunction_Type = nullptr; + +bool PmapFunction::IsPmapFunction(nb::handle handle) { + return handle.type().ptr() == JaxPmapFunction_Type; +} + +PmapFunction* PmapFunction::AsPmapFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +absl::StatusOr AsPmapFunction(nb::handle handle) { + if (!PmapFunction::IsPmapFunction(handle)) { + return xla::InvalidArgument("Expected a PmapFunction"); + } + return PmapFunction::AsPmapFunctionUnchecked(handle); +} + +namespace { + +extern "C" { + +PyObject* JaxPmapFunction_tp_vectorcall(PyObject* callable, + PyObject* const* args, size_t nargs, + PyObject* kwnames) { + JaxPmapFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("JaxPmapFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* JaxPmapFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + JaxPmapFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = JaxPmapFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void JaxPmapFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + JaxPmapFunctionObject* o = reinterpret_cast(self); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PmapFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int JaxPmapFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + JaxPmapFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.fun().ptr()); + Py_VISIT(o->fun.cache_miss().ptr()); + return 0; +} + +int JaxPmapFunction_tp_clear(PyObject* self) { + JaxPmapFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so PMAP-compiled functions can be +// used as bound methods. 
See: +// https://docs.python.org/3/howto/descriptor.html#functions-and-methods +PyObject* JaxPmapFunction_tp_descr_get(PyObject* self, PyObject* obj, + PyObject* type) { + if (obj == nullptr || obj == Py_None) { + Py_INCREF(self); + return self; + } + return PyMethod_New(self, obj); +} + +static PyGetSetDef JaxPmapFunction_tp_getset[] = { + // Having a __dict__ seems necessary to allow !functool.wraps to override + // __doc__. + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +PyMemberDef JaxPmapFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, weakrefs)), + READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot JaxPmapFunction_slots[] = { + {Py_tp_new, reinterpret_cast(JaxPmapFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(JaxPmapFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(JaxPmapFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(JaxPmapFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(JaxPmapFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(JaxPmapFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_members, reinterpret_cast(JaxPmapFunction_members)}, + {0, nullptr}, +}; + +} // extern "C" + +nb::object MakePmapFunction( + nb::callable fun, nb::callable cache_miss, std::vector static_argnums, + nb::callable python_shard_arg_fallback, + xla::nb_class_ptr pytree_registry) { + nb::object obj = nb::steal(JaxPmapFunction_tp_new( + reinterpret_cast(JaxPmapFunction_Type), nullptr, nullptr)); + JaxPmapFunctionObject* buf = + reinterpret_cast(obj.ptr()); + new (&buf->fun) PmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(python_shard_arg_fallback), std::move(pytree_registry)); + return obj; +} + +// Version numbers for the pickled representations. +// Increment these if changing them. 
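
The __dict__ getset above (mirroring the jit version earlier) exists so that the Python layer can attach metadata such as __doc__ and __wrapped__ to the C++ function object, typically via functools.wraps. A hedged sketch of that pattern (the exact attributes JAX copies may differ by version):

    import functools
    import jax

    def add_one(x):
        """Adds one elementwise."""
        return x + 1.0

    p = jax.pmap(add_one)
    # jax's Python layer does essentially this when building the returned object;
    # it works only because instances of the C++ type carry a writable __dict__.
    functools.wraps(add_one)(p)
    print(p.__doc__, p.__wrapped__)
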
+const int kPmapFunctionPickleVersion = 1; + +} // namespace + +void BuildPmapSubmodule(nb::module_& m) { + nb::module_ pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); + + nb::class_ no_sharding(pmap_lib, "NoSharding"); + no_sharding.def(nb::init<>()) + .def("__getstate__", + [](const NoSharding& self) { return nb::make_tuple(); }) + .def("__setstate__", + [](NoSharding& self, nb::tuple t) { new (&self) NoSharding(); }) + .def("__repr__", + [](const NoSharding& chuncked) { return "NoSharding()"; }) + .def("__eq__", + [](const NoSharding& self, nb::object obj) { + return nb::isinstance(obj); + }) + .def("__hash__", [](const NoSharding& self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + nb::class_ chunked(pmap_lib, "Chunked"); + chunked.def(nb::init>()) + .def("__getstate__", + [](const Chunked& self) { return nb::make_tuple(self.chunks); }) + .def("__setstate__", + [](Chunked& self, nb::tuple t) { + new (&self) Chunked{nb::cast>(t[0])}; + }) + .def_ro("chunks", &Chunked::chunks) + .def("__repr__", + [](const Chunked& chuncked) { + return absl::StrCat("Chunked(", + absl::StrJoin(chuncked.chunks, ","), ")"); + }) + .def("__eq__", [](const Chunked& self, nb::object other) { + if (!nb::isinstance(other)) { + return false; + } + return self == nb::cast(other); + }); + + nb::class_ unstacked(pmap_lib, "Unstacked"); + unstacked.def(nb::init()) + .def("__getstate__", + [](const Unstacked& self) { return nb::make_tuple(self.size); }) + .def("__setstate__", + [](Unstacked& self, nb::tuple t) { + new (&self) Unstacked{nb::cast(t[0])}; + }) + .def_ro("size", &Unstacked::size) + .def("__repr__", + [](const Unstacked& x) { + return absl::StrCat("Unstacked(", x.size, ")"); + }) + .def("__eq__", [](const Unstacked& self, nb::object other) { + if (!nb::isinstance(other)) { + return false; + } + return self == nb::cast(other); + }); + + nb::class_ sharded_axis(pmap_lib, "ShardedAxis"); + sharded_axis.def(nb::init()) + .def("__getstate__", + [](const ShardedAxis& self) { return nb::make_tuple(self.axis); }) + .def("__setstate__", + [](ShardedAxis& self, nb::tuple t) { + new (&self) ShardedAxis{nb::cast(t[0])}; + }) + .def_ro("axis", &ShardedAxis::axis) + .def("__repr__", + [](const ShardedAxis& x) { + return absl::StrCat("ShardedAxis(axis=", x.axis, ")"); + }) + .def("__eq__", [](const ShardedAxis& self, const ShardedAxis& other) { + return self == other; + }); + + nb::class_ replicated(pmap_lib, "Replicated"); + replicated.def(nb::init()) + .def("__getstate__", + [](const Replicated& self) { return nb::make_tuple(self.replicas); }) + .def("__setstate__", + [](Replicated& self, nb::tuple t) { + new (&self) Replicated{nb::cast(t[0])}; + }) + .def_ro("replicas", &Replicated::replicas) + .def("__repr__", + [](const Replicated& x) { + return absl::StrCat("Replicated(replicas=", x.replicas, ")"); + }) + .def("__eq__", [](const Replicated& self, const Replicated& other) { + return self == other; + }); + + nb::class_ sharding_spec(pmap_lib, "ShardingSpec"); + sharding_spec + .def(nb::init(), nb::arg("sharding"), + nb::arg("mesh_mapping")) + .def("__getstate__", + [](const ShardingSpec& self) { + auto sharding = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + auto mesh_mapping = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetMeshMapping())); + return nb::make_tuple(sharding, mesh_mapping); + }) + .def("__setstate__", + [](ShardingSpec& self, nb::tuple t) { + new (&self) + ShardingSpec{nb::cast>(t[0]), + nb::cast>(t[1])}; + }) + .def_prop_ro( + 
"sharding", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + }) + .def_prop_ro("mesh_mapping", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple( + absl::MakeConstSpan(self.GetMeshMapping())); + }) + .def("__eq__", [](const ShardingSpec& self, + const ShardingSpec& other) { return self == other; }) + .def("__hash__", [](const ShardingSpec& self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PmapFunction"); + PyType_Spec pmap_function_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(JaxPmapFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/JaxPmapFunction_slots, + }; + + JaxPmapFunction_Type = PyType_FromSpec(&pmap_function_spec); + if (!JaxPmapFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(JaxPmapFunction_Type); + + // Add PmapFunction to the xla_extension module so it can be pickled. + m.attr("PmapFunction") = cfun; + + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->PythonSignature(); + }); + // Required by `post_hook`. + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->cache_miss(); + }); + cfun.attr("__getstate__") = nb::cpp_function( + [](const PmapFunction::object& self) { + PmapFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPmapFunctionPickleVersion; + pickle["fun"] = fn->fun(); + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["python_shard_arg_fallback"] = fn->python_shard_arg_fallback(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](PmapFunction::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPmapFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PmapFunction pickle version, got %d, expected %d. 
" + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPmapFunctionPickleVersion)); + } + nb::callable fun = nb::cast(pickle["fun"]); + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + nb::callable python_shard_arg_fallback = + nb::cast(pickle["python_shard_arg_fallback"]); + xla::nb_class_ptr pytree_registry = + nb::cast>( + pickle["pytree_registry"]); + new (&(reinterpret_cast(self.ptr())->fun)) + PmapFunction(std::move(fun), std::move(cache_miss), + std::move(static_argnums), + std::move(python_shard_arg_fallback), + std::move(pytree_registry)); + }, + nb::is_method()); + + // This is only for testing/debugging purposes. + cfun.attr("_cache_size") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return nb::cast(fun->cache_size()); + }); + + cfun.attr("_cache_clear") = nb::cpp_function( + [](nb::handle self) { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + fun->cache_clear(); + }, + nb::is_method()); + + cfun.attr("_debug_cache_keys") = nb::cpp_function( + [](nb::handle self) -> std::string { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->DebugCacheKeys(); + }, + nb::is_method()); + + pmap_lib.def( + "pmap", + [](nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, nb::callable shard_arg_fallback, + nb::object pytree_registry) -> nb::object { + xla::nb_class_ptr registry = + nb::cast>(pytree_registry); + return MakePmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(shard_arg_fallback), std::move(registry)); + }, + nb::arg("fun"), nb::arg("cache_miss"), nb::arg("static_argnums"), + nb::arg("shard_arg_fallback"), nb::arg("pytree_registry")); +} + +} // namespace jax diff --git a/jaxlib/xla/pmap_lib.h b/jaxlib/xla/pmap_lib.h new file mode 100644 index 000000000000..2bad85e59671 --- /dev/null +++ b/jaxlib/xla/pmap_lib.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PMAP_LIB_H_ +#define JAXLIB_XLA_PMAP_LIB_H_ + + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). 
+ +namespace jax { + +void BuildPmapSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_PMAP_LIB_H_ diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc new file mode 100644 index 000000000000..e3321c4c88ce --- /dev/null +++ b/jaxlib/xla/py_array.cc @@ -0,0 +1,2168 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_array.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/guard_lib.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/traceback.h" +#include "jaxlib/xla/util.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/status_casters.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include 
"xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace { + +namespace nb = nanobind; + +PjRtBuffer* GetPjrtBuffer(ifrt::Array* ifrt_array) { + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers().front().get(); +} + +absl::StatusOr XlaDynamicShape(ifrt::Array* ifrt_array, + std::optional& scratch) { + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + + if (!scratch) { + absl::Span dims; + std::optional> logical_dims_storage; + if (pjrt_buffer->has_dynamic_dimensions()) { + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(std::vector logical_dims, + pjrt_buffer->logical_dimensions()); + logical_dims_storage.emplace(std::move(logical_dims)); + } + dims = *logical_dims_storage; + } else { + dims = pjrt_buffer->dimensions(); + } + Shape shape = ShapeUtil::MakeShape(pjrt_buffer->element_type(), dims); + // TODO(b/327524065): fix this + *shape.mutable_layout() = pjrt_buffer->layout()->xla_layout(); + scratch = std::move(shape); + } + return &scratch.value(); +} + +tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( + nb_dtype dtype, absl::Span shape, + absl::Span py_arrays, const nb::object& sharding) { + const ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding); + + std::vector> ifrt_arrays; + ifrt_arrays.reserve(py_arrays.size()); + absl::InlinedVector devices; + devices.reserve(py_arrays.size()); + absl::flat_hash_set device_set; + device_set.reserve(py_arrays.size()); + std::vector shapes; + shapes.reserve(py_arrays.size()); + + auto sharding_device_list = xla::GetIfrtDeviceList(sharding); + if (!sharding_device_list.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(sharding_device_list.status().ToString().c_str()); + } + ifrt::Device* device = sharding_device_list.value()->devices().front(); + + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_dst_memory_kind = + ifrt::CanonicalizeMemoryKind(dst_memory_kind, device); + for (const auto& py_array : py_arrays) { + if (py_array.num_shards() != 1) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays the input arrays " + "must have one shard each. An argument array had %d shard(s).", + py_array.num_shards()) + .c_str()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + ifrt::Device* const device = + ifrt_arrays.back()->sharding().devices()->devices().front(); + devices.push_back(device); + device_set.insert(device); + shapes.push_back(ifrt_arrays.back()->shape()); + if (canonical_dst_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_arrays.back()->sharding().memory_kind(), device)) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch with PjRtBuffers. 
Got sharding with " + "memory kind '%v' and a buffer with memory_kind '%v'", + dst_memory_kind, ifrt_arrays.back()->sharding().memory_kind()) + .c_str()); + } + } + ifrt::DeviceListRef device_list = device->client()->MakeDeviceList(devices); + if (device_set.size() != device_list->size()) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays, the input arrays " + "must be from distinct devices, but got %v", + *device_list) + .c_str()); + } + + auto ifrt_dtype = DtypeToIfRtDType(dtype); + if (!ifrt_dtype.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_dtype.status().ToString().c_str()); + } + + absl::StatusOr> ifrt_sharding = + sharding.type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape)); + if (!ifrt_sharding.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_sharding.status().ToString().c_str()); + } + // TODO(emilyaf): Always use `ifrt_dtype` once tokens are handled correctly. + ifrt::DType array_dtype = + ifrt_arrays.empty() ? ifrt_dtype.value() : ifrt_arrays[0]->dtype(); + absl::StatusOr> ifrt_array = + device->client()->AssembleArrayFromSingleDeviceArrays( + array_dtype, ifrt::Shape(shape), *std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_array.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_array.status().ToString().c_str()); + } + return *std::move(ifrt_array); +} + +struct PyBaseArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* weakrefs; +#endif // PY_VERSION_HEX < 0x030C0000 +}; + +extern "C" void PyBaseArray_tp_dealloc(PyBaseArrayObject* self) { + PyObject_GC_UnTrack(self); + PyObject_ClearWeakRefs((PyObject*)self); + PyTypeObject* tp = Py_TYPE(self); + tp->tp_free((PyObject*)self); + Py_DECREF(tp); +} + +extern "C" int PyBaseArray_tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + return 0; +} + +struct PyArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* weakrefs; + PyObject* dict; +#endif // PY_VERSION_HEX < 0x030C0000 + bool initialized; + alignas(PyArray::Storage) char array_storage[sizeof(PyArray::Storage)]; +}; +static_assert(std::is_standard_layout::value); + +PyArray::Storage* GetPyArrayStorageFromObject(PyArrayObject* py_array_object) { + return std::launder( + reinterpret_cast(py_array_object->array_storage)); +} + +extern "C" PyObject* PyArray_tp_new(PyTypeObject* type, PyObject*, PyObject*) { + PyObject* self = type->tp_alloc(type, 0); + auto* obj = reinterpret_cast(self); + obj->initialized = false; + return self; +} + +extern "C" void PyArray_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + auto* obj = reinterpret_cast(self); + + if (obj->initialized) { + GetPyArrayStorageFromObject(obj)->~PyArray_Storage(); + } + + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + + tp->tp_free(self); + Py_DECREF(tp); +} + +// dynamic_attr: Allow the garbage collector to traverse the internal instance +// `__dict__`. 
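
The assembly logic above (one single-shard input per distinct device, with memory kinds consistent with the target sharding) is what the public jax.make_array_from_single_device_arrays API relies on; that it bottoms out in this exact C++ path is an assumption. A small usage sketch (illustrative, not part of this patch):

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    devices = jax.local_devices()
    mesh = Mesh(np.array(devices), ("x",))
    sharding = NamedSharding(mesh, PartitionSpec("x"))

    # one single-device shard per device, as required by the checks above
    shards = [jax.device_put(np.arange(4.0) + 4 * i, d)
              for i, d in enumerate(devices)]
    arr = jax.make_array_from_single_device_arrays(
        (4 * len(devices),), sharding, shards)
    print(arr.shape)
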
+extern "C" int PyArray_tp_traverse(PyObject* self, visitproc visit, void* arg) { +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); + return 0; +} + +// dynamic_attr: Allow the GC to clear the dictionary. +extern "C" int PyArray_tp_clear(PyObject* self) { + switch (auto guard_level = jax::GetGarbageCollectArrayGuard(); guard_level) { + case jax::GarbageCollectionGuardLevel::kAllow: + break; + case jax::GarbageCollectionGuardLevel::kLog: + case jax::GarbageCollectionGuardLevel::kFatal: { + auto* obj = reinterpret_cast(self); + std::string traceback_str; + if (obj->initialized) { + auto traceback = GetPyArrayStorageFromObject(obj)->traceback; + if (traceback.has_value()) { + traceback_str = traceback.value()->ToString(); + } + } + auto error_msg = absl::StrCat( + "`jax.Array` was deleted by the Python garbage collector " + "instead of reference counting. Break the reference cycle " + "that delays the deletion of this `jax.Array` to avoid hogging " + "memory. Traceback: \n", + traceback_str.empty() ? "not available" : traceback_str); + if (guard_level == jax::GarbageCollectionGuardLevel::kFatal) { + Py_FatalError(error_msg.c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, error_msg.c_str()); + PyErr_Print(); + PyErr_Clear(); + } + break; + } + } +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + return 0; +} + +template +PyArray::Storage* Construct(PyArrayObject* self, Args&&... args) { + PyArray::Storage* out = + new (self->array_storage) PyArray::Storage(std::forward(args)...); + self->initialized = true; + return out; +} + +struct ShapedArrayCacheKey { + std::vector dims; + ifrt::DType dtype{ifrt::DType::kInvalid}; + bool weak_type; + + template + friend H AbslHashValue(H h, const ShapedArrayCacheKey& value) { + return H::combine(std::move(h), value.dims, value.dtype, value.weak_type); + } + bool operator==(const ShapedArrayCacheKey& other) const { + return dims == other.dims && dtype == other.dtype && + weak_type == other.weak_type; + } +}; + +// Constructing ShapedArrays has gotten slow. Cache it. +nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { + using CacheT = + LRUCache>>; + static nb::ft_mutex mu; + static auto* lru_list = new CacheT::LRUList(4096); + static auto* cache = new CacheT(lru_list); + + static const nb::object* shaped_array = []() -> nb::object* { + nb::object jax_core; + try { + jax_core = nb::module_::import_("jax.core"); + } catch (nb::python_error& e) { + return nullptr; + } + return new nb::object(jax_core.attr("ShapedArray")); + }(); + if (!shaped_array) { + return nb::none(); + } + + nb::ft_lock_guard lock(mu); + auto value = + cache->GetOrCreateIfAbsent(key, [](const ShapedArrayCacheKey& key) { + return std::make_shared>(); + }); + + if (!value->has_value()) { + nb_dtype dtype = + IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + nb::object aval = (*shaped_array)( + SpanToNbTuple(absl::Span( + key.dtype.kind() == ifrt::DType::kToken ? 
std::vector{0} + : key.dims)), + dtype, key.weak_type); + *value = aval; + return aval; + } + return **value; +} + +// Grouping key used by BatchedCopyToDeviceWithSharding. +// Defined outside of the function as required by templatized function +// `AbslHashValue`. +struct BatchedCopyToDeviceWithShardingKey { + ifrt::DeviceListRef src_devices; + ifrt::MemoryKind src_memory_kind; + ifrt::DeviceListRef dst_devices; + ifrt::MemoryKind dst_memory_kind; + ifrt::ArrayCopySemantics array_copy_semantics; + + bool operator==(const BatchedCopyToDeviceWithShardingKey& other) const { + return *src_devices == *other.src_devices && + src_memory_kind == other.src_memory_kind && + *dst_devices == *other.dst_devices && + dst_memory_kind == other.dst_memory_kind && + array_copy_semantics == other.array_copy_semantics; + } + + template + friend H AbslHashValue(H h, const BatchedCopyToDeviceWithShardingKey& key) { + return H::combine(std::move(h), key.src_devices, key.src_memory_kind, + key.dst_devices, key.dst_memory_kind, + key.array_copy_semantics); + } +}; + +} // namespace + +PyArray_Storage::PyArray_Storage( + nb::object aval, bool weak_type, xla::nb_dtype dtype, + std::vector shape, nb::object sharding, bool committed, + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, xla::PjRtFuture<> result_status) + : aval(std::move(aval)), + weak_type(weak_type), + dtype(std::move(dtype)), + shape(std::move(shape)), + sharding(std::move(sharding)), + committed(committed), + py_client(std::move(py_client)), + traceback(std::move(traceback)), + ifrt_array(std::move(ifrt_array)), + result_status(std::move(result_status)) { + static_assert(PyClient::kNumArraysShards < + std::numeric_limits::max()); + thread_id_bucket = std::hash()(std::this_thread::get_id()) % + PyClient::kNumArraysShards; + + PyClient::ArraysShard& shard = this->py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + next = shard.arrays; + shard.arrays = this; + if (next) { + next->prev = this; + } + prev = nullptr; +} + +void PyInit_helper(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed) { + auto dtype = nb::cast(aval.attr("dtype")); + auto shape = nb::cast>(aval.attr("shape")); + auto py_device_list = nb::cast( + sharding.attr("_internal_device_list")); + nb_class_ptr py_client = py_device_list->py_client(); + auto ifrt_array = CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype, shape, py_arrays, sharding); + Construct(reinterpret_cast(self.ptr()), aval, + nb::cast(aval.attr("weak_type")), std::move(dtype), + std::move(shape), std::move(sharding), committed, py_client, + Traceback::Get(), std::move(ifrt_array), xla::PjRtFuture<>()); +} + +void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks) { + if (skip_checks) { + PyInit_helper(self, aval, sharding, py_arrays, committed); + } else { + nb::object rearranged_arrays = + self.CheckAndRearrange(py_arrays, sharding, aval); + auto rearranged_py_arrays = + nb::cast>(rearranged_arrays); + PyInit_helper(self, aval, sharding, rearranged_py_arrays, committed); + } +} + +PyArray PyArray::MakeFromSingleDeviceArray( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture<> result_status) { + if (!llvm::isa(ifrt_array->sharding())) { + throw XlaRuntimeError( + InvalidArgument("Constructing single device jax.Array from non-single " + "device ifrt 
array.")); + } + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind(); + nb::object py_memory_kind = + (memory_kind.memory_kind().has_value()) + ? nb::object(nb::str(memory_kind.memory_kind()->data(), + memory_kind.memory_kind()->size())) + : nb::none(); + nb::object sharding = make_nb_class( + py_client, ifrt_array->sharding().devices(), std::move(py_memory_kind)); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + /*skip_checks=*/true, std::move(result_status)); +} + +PyArray PyArray::MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, nb::object sharding, + bool weak_type, bool committed, bool skip_checks) { + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + skip_checks); +} + +PyArrayResultHandler::PyArrayResultHandler(nb::object aval, nb::object sharding, + bool committed, bool skip_checks) + : aval_(std::move(aval)), + sharding_(std::move(sharding)), + committed_(committed), + skip_checks_(skip_checks) { + weak_type_ = nb::cast(aval_.attr("weak_type")); + dtype_ = nb::cast(aval_.attr("dtype")); + shape_ = nb::cast>(aval_.attr("shape")); +} + +PyArray PyArrayResultHandler::Call(absl::Span py_arrays) const { + auto py_device_list = jax::GetPyDeviceList(sharding_); + if (!py_device_list.ok()) { + throw nb::value_error( + absl::StrCat("Failed to get py device list from sharding: ", + py_device_list.status().ToString()) + .c_str()); + } + return Call(py_device_list.value()->py_client(), + CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype_, shape_, py_arrays, sharding_), + xla::PjRtFuture<>()); +} + +PyArray PyArrayResultHandler::Call(nb_class_ptr py_client, + tsl::RCReference ifrt_array, + xla::PjRtFuture<> result_status) const { + return PyArray(aval_, weak_type_, dtype_, shape_, sharding_, + std::move(py_client), Traceback::Get(), std::move(ifrt_array), + committed_, skip_checks_, std::move(result_status)); +} + +PyArray PyArrayResultHandler::Call(PyArray py_array) const { + return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()), + xla::PjRtFuture<>()); +} + +PyArray::PyArray(nb::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nb::object sharding, + nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, bool committed, + bool skip_checks, xla::PjRtFuture<> result_status) { + auto* self = + PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); + m_ptr = self; + Construct(reinterpret_cast(self), std::move(aval), weak_type, + 
std::move(dtype), std::move(shape), std::move(sharding), committed, + std::move(py_client), std::move(traceback), std::move(ifrt_array), + std::move(result_status)); + + if (!skip_checks) { + this->attr("_arrays") = this->attr("_check_and_rearrange")( + this->attr("_arrays"), this->attr("_sharding"), this->attr("aval")); + } +} + +PyArray::Storage& PyArray::GetStorage() { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +const PyArray::Storage& PyArray::GetStorage() const { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +nb::object PyArray::CheckAndRearrange(const absl::Span py_arrays, + const nb::object sharding, + const nb::object aval) { + return this->attr("_check_and_rearrange")(py_arrays, sharding, aval); +} + +void PyArray::SetIfrtArray(tsl::RCReference ifrt_array) { + GetStorage().ifrt_array = std::move(ifrt_array); +} + +const std::vector& PyArray::py_arrays_cached() { + auto& py_arrays = this->py_arrays(); + + if (py_arrays.empty()) { + auto ifrt_arrays = ifrt_array()->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_arrays.ok()) { + throw nb::value_error( + absl::StrCat("Failed to disassemble into single-device arrays: ", + ifrt_arrays.status().ToString()) + .c_str()); + } + py_arrays.reserve(ifrt_arrays->size()); + for (auto& ifrt_array : *ifrt_arrays) { + py_arrays.push_back(PyArray::MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(ifrt_array), weak_type(), + committed(), result_status())); + } + } + + return py_arrays; +} + +nb::object PyArray::arrays() { + // For performance, we only keep pjrt buffers by default. But on python side + // "_arrays" returns PyArrays instead, and subsequent calls to "_arrays" + // should return the same PyArrays (to avoid duplicate device to host + // transfers). So we create PyArrays the first time it is called and reuse + // them later. 
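+ // A deleted or donated array has nothing to expose, and a single-device
+ // array is its own sole shard, so both cases are handled before falling
+ // back to the cached per-shard PyArrays.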
+ if (ifrt_array() == nullptr || ifrt_array()->IsDeleted()) return nb::none(); + + if (llvm::isa(&ifrt_array()->sharding())) { + std::vector py_arrays; + py_arrays.push_back(*this); + return nb::cast(py_arrays); + } + + return nb::cast(py_arrays_cached()); +} + +absl::Status PyArray::set_arrays(nb::object obj) { + if (obj.is_none()) { + SetIfrtArray(tsl::RCReference()); + py_arrays().clear(); + return absl::OkStatus(); + } + + if (!nb::isinstance(obj)) { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + + nb::list list(obj); + + if (list.size() == 0) return absl::OkStatus(); + + SetIfrtArray(tsl::RCReference()); + py_arrays().clear(); + std::vector> ifrt_arrays; + ifrt_arrays.reserve(list.size()); + absl::InlinedVector devices; + devices.reserve(list.size()); + std::vector shapes; + shapes.reserve(list.size()); + for (nb::handle obj : list) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + if (py_array.py_client().get() != py_client().get()) { + return InvalidArgument("Client mismatch when assigning to _arrays."); + } + if (py_array.num_shards() != 1) { + return InvalidArgument("Wrong number of shards: %d", + py_array.num_shards()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + devices.push_back( + ifrt_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(ifrt_arrays.back()->shape()); + } else { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + } + const ifrt::MemoryKind first_memory_kind = + ifrt_arrays.front()->sharding().memory_kind(); + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_first_memory_kind = + ifrt::CanonicalizeMemoryKind( + first_memory_kind, + ifrt_arrays.front()->sharding().devices()->devices().front()); + for (const auto& ifrt_array : ifrt_arrays) { + if (canonical_first_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_array->sharding().memory_kind(), + ifrt_array->sharding().devices()->devices().front())) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch between single-device arrays. Got one " + "array with memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_array->sharding().memory_kind()) + .c_str()); + } + } + + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + sharding().type().is(jax::PmapSharding::type()) + ? 
xla::GetIfrtConcreteSharding(sharding(), ifrt::Shape(shape()), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding(), ifrt::Shape(shape()))); + TF_ASSIGN_OR_RETURN( + auto array, + py_client()->ifrt_client()->AssembleArrayFromSingleDeviceArrays( + ifrt::Shape(shape()), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards)); + SetIfrtArray(std::move(array)); + return absl::OkStatus(); +} + +absl::StatusOr PyArray::FullyReplicatedShard() { + auto& cached = GetStorage().fully_replicated_array; + if (!cached.is_none()) { + return nb::cast(cached); + } + + if (ifrt_array() == nullptr) { + return InvalidArgument( + "FullyReplicatedShard() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(auto fully_replicated_ifrt_shard, + ifrt_array()->FullyReplicatedShard( + ifrt::ArrayCopySemantics::kReuseInput)); + auto array = MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(fully_replicated_ifrt_shard), + weak_type(), committed(), result_status()); + cached = array; + return nb::cast(cached); +} + +absl::Status PyArray::BlockUntilReady() const { + nb::gil_scoped_release gil_release; + if (ifrt_array() == nullptr) { + return InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt::Array* ifrt_array = this->ifrt_array(); + return AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1)); +} + +absl::StatusOr PyArray::GetOnDeviceSizeInBytes() { + if (ifrt_array() == nullptr) { + return InvalidArgument( + "GetOnDeviceSizeInBytes() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(size_t shard_size, + GetPjrtBuffer(ifrt_array())->GetOnDeviceSizeInBytes()); + return shard_size * nb::len(nb::object(sharding().attr("device_set"))); +} + +absl::Status PyArray::BlockUntilResultStatusIsReady() { + auto& result_status = GetStorage().result_status; + // If the result_status future is not valid, this result did not come directly + // from a computation that returns tokens, so we don't wait for the status. + if (!result_status.IsValid()) { + return absl::OkStatus(); + } + if (!result_status.IsReady()) { + // Only release the gil if we need to Await(). 
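+ // Dropping the GIL while blocking lets other Python threads make progress
+ // until the result status resolves.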
+ nb::gil_scoped_release release_gil; + BlockUntilReadyWithCancel(result_status); + return result_status.Await(); + } + return result_status.Await(); +} + +absl::StatusOr> +PyArray::SingleDeviceArrayToNumpyArrayDidCopy() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + auto result = arr.GetStorage().host_value.AsNumPyArray( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); + TF_RETURN_IF_ERROR(arr.BlockUntilResultStatusIsReady()); + return result; +} + +absl::StatusOr PyArray::SingleDeviceArrayToNumpyArray() { + TF_ASSIGN_OR_RETURN(auto result, SingleDeviceArrayToNumpyArrayDidCopy()); + return result.first; +} + +absl::Status PyArray::CopySingleDeviceArrayToHostAsync() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + return arr.GetStorage().host_value.CopyToHostAsync( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); +} + +absl::StatusOr PyArray::AssertUnsharded(absl::string_view api) { + if (ifrt_array() == nullptr) { + return InvalidArgument("%s( called on deleted or donated buffer", api); + } + + if (llvm::isa(&ifrt_array()->sharding())) { + return *this; + } + + auto& py_arrays = py_arrays_cached(); + if (py_arrays.size() != 1) { + return InvalidArgument("%s() is supported only for unsharded arrays.", api); + } + return py_arrays[0]; +} + +absl::StatusOr PyArray::UnsafeBufferPointer() { + TF_ASSIGN_OR_RETURN(auto arr, AssertUnsharded("UnsafeBufferPointer")); + + return py_client()->pjrt_client()->UnsafeBufferPointer( + GetPjrtBuffer(arr.ifrt_array())); +} + +nb::dict PyArray::CudaArrayInterface() { + auto arr_or_error = AssertUnsharded("UnsafeBufferPointer"); + if (!arr_or_error.ok()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only supported for unsharded arrays."); + } + auto arr = *arr_or_error; + + ifrt::Array* ifrt_array = arr.ifrt_array(); + std::optional& scratch = arr.GetStorage().dynamic_shape; + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + if (pjrt_buffer->client()->platform_id() != CudaId()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for NVidia GPU buffers."); + } + if (pjrt_buffer->IsTuple()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for array buffers."); + } + + switch (pjrt_buffer->element_type()) { + case PrimitiveType::PRED: + case PrimitiveType::S8: + case PrimitiveType::S16: + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U8: + case PrimitiveType::U16: + case PrimitiveType::U32: + case PrimitiveType::U64: + case PrimitiveType::F16: + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::C64: + case PrimitiveType::C128: + break; + + default: + throw nb::attribute_error( + absl::StrFormat( + "__cuda_array_interface__ is not supported for %s buffers.", + PrimitiveType_Name(pjrt_buffer->element_type())) + .c_str()); + } + + nb::str typestr = + ValueOrThrow(TypeDescriptorForPrimitiveType(pjrt_buffer->element_type())); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + if (!LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + throw nb::attribute_error( + "__cuda_array_interface__ is only currently supported for " + "buffers in row-major order."); + } + + nb::dict result; + const auto* dynamic_shape = + ValueOrThrow(XlaDynamicShape(ifrt_array, scratch)); + result["shape"] = SpanToNbTuple(dynamic_shape->dimensions()); + result["typestr"] = std::move(typestr); + std::unique_ptr external_reference_hold = + 
ValueOrThrow(pjrt_buffer->AcquireExternalReference()); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb::tuple data = + nb::make_tuple(nb::int_(absl::bit_cast(root_ptr)), + nb::bool_(true) /* read-only */ + ); + result["data"] = std::move(data); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nb::dict& cai, nb_class_ptr client, + std::optional device_id) { + if (!cai.contains("data")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `data`"); + } + if (!cai.contains("shape")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `shape`"); + } + if (!cai.contains("typestr")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `typestr`"); + } + if (!cai.contains("version")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `version`"); + } + auto version = nb::cast(cai["version"]); + if (version < 2 || version > 3) { + LOG(WARNING) << "CUDA Array Interface version " << version + << " support is undefined"; + } + auto data = nb::cast(cai["data"]); + auto data_value = nb::cast(data[0]); + void* data_ptr = reinterpret_cast(data_value); + auto dimensions = nb::cast>(cai["shape"]); + if (data_value == 0 && absl::c_find(dimensions, 0) == dimensions.end()) { + return absl::InvalidArgumentError( + "CUDA Array Interface `data`(=NULL) and `shape`(no zero-valued " + "dimensions) are inconsistent"); + } + auto ndim = dimensions.size(); + TF_ASSIGN_OR_RETURN( + PrimitiveType element_type, + DtypeToPrimitiveType(nb_dtype::from_args(cai["typestr"]))); + + if (!device_id.has_value()) { + throw XlaRuntimeError( + "This operation requires CUDA support from jaxlib or jax cuda plugin."); + } + TF_ASSIGN_OR_RETURN(auto device, + client->DeviceFromLocalHardwareId(*device_id)); + bool is_default_stream = + data_value == 0 || version == 2 || + (version == 3 && (!cai.contains("stream") || cai["stream"].is_none())); + TF_ASSIGN_OR_RETURN( + std::intptr_t stream, + ([is_default_stream, cai, device]() -> absl::StatusOr { + if (is_default_stream) { + return device->GetStreamForExternalReadyEvents(); + } else { + auto stream_ = nb::cast(cai["stream"]); + if (stream_ == 0) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not allow zero stream value"); + } + return stream_; + } + }())); + + std::vector minor_to_major(ndim); + if (cai.contains("strides") && !cai["strides"].is_none() && data_value != 0) { + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + auto strides = nb::cast>(cai["strides"]); + if (strides.size() != ndim) { + return absl::InvalidArgumentError( + "CUDA Array Interface `shape` and `strides` dimensionalities are " + "inconsistent"); + } + absl::c_sort(minor_to_major, [&](int a, int b) { + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return (strides[a] == strides[b] ? b < a : strides[a] < strides[b]); + }); + int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + for (int64_t d : minor_to_major) { + if (dimensions[d] > 1 && strides[d] != stride) { + return absl::UnimplementedError(absl::StrCat( + "Only arrays with trivial (compact) striding are supported; " + "i.e., arrays whose striding represents a transposition of the " + "underlying buffer but not broadcasting. 
Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dimensions, ","), absl::StrJoin(strides, ","))); + } + stride *= dimensions[d]; + } + } else { + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + std::function on_delete_callback = []() {}; + auto* pjrt_device = + llvm::dyn_cast_or_null(device->device()); + if (pjrt_device == nullptr) { + return InvalidArgument( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_RET_CHECK(pjrt_device->IsAddressable()); + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + device->client()->pjrt_client()->CreateViewOfDeviceBuffer( + static_cast(data_ptr), shape, + *pjrt_device->pjrt_device()->default_memory_space(), + on_delete_callback, + stream <= 2 ? std::nullopt : std::make_optional(stream))); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::Status PyArray::Delete() { + for (auto& arr : py_arrays()) { + TF_RETURN_IF_ERROR(arr.Delete()); + } + py_arrays().clear(); + if (ifrt_array() != nullptr) { + // We do not wait for the deletion to complete here. + // + // (1) Skipping blocking does not affect the correctness of deletion as long + // as the runtime preserves dispatch ordering of deletion w.r.t. other + // operations. + // + // (2) Synchronously waiting for the deletion to complete is very expensive + // when the deletion can return a status only after the underlying physical + // buffer has been deleted or a request must be processed via RPC, + // especially as this deletion is done per array. + ifrt_array()->Delete(); + SetIfrtArray(tsl::RCReference()); + } + return absl::OkStatus(); +} + +bool PyArray::IsDeleted() const { + if (ifrt_array() == nullptr) { + return true; + } + + return ifrt_array()->IsDeleted(); +} + +PyArray PyArray::Clone() const { + auto array = tsl::FormRef(ifrt_array()); + auto* ifrt_client = py_client()->ifrt_client(); + tsl::RCReference out = + ifrt_client + ->CopyArrays(absl::MakeSpan(&array, 1), /*devices=*/std::nullopt, + /*memory_kind=*/std::nullopt, + ifrt::ArrayCopySemantics::kReuseInput) + .value() + .front(); + return PyArray(aval(), weak_type(), dtype(), + std::vector(shape().begin(), shape().end()), + sharding(), py_client(), traceback(), std::move(out), + committed(), /*skip_checks=*/true, result_status()); +} + +nb::handle PyArray::Storage::AsHandle() { + return reinterpret_cast(reinterpret_cast(this) - + offsetof(PyArrayObject, array_storage)); +} + +PyArray::Storage::~PyArray_Storage() { + CHECK(PyGILState_Check()); + if (py_client) { + PyClient::ArraysShard& shard = py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + if (shard.arrays == this) { + shard.arrays = next; + } + if (prev) { + prev->next = next; + } + if (next) { + next->prev = prev; + } + } + // Release GIL and then explicitly destroy `ifrt_array` to prevent deadlock on + // CPU backend caused by interactions between argument donations and host + // callbacks. 
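+ // The gil_scoped_release below stays in scope for the remainder of the
+ // destructor, so the ifrt_array reference is dropped without the GIL held.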
+ nb::gil_scoped_release gil_release; + ifrt_array.reset(); +} + +absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics) { + if (py_arrays.empty()) { + return std::vector(); + } + + TF_RET_CHECK(py_arrays.size() == dst_device_lists.size()); + TF_RET_CHECK(py_arrays.size() == dst_shardings.size()); + + ifrt::Client* const client = py_arrays.front().ifrt_array()->client(); + std::vector results(py_arrays.size()); + + // Arrays to be copied, grouped by source/destination devices and memory + // kinds. The grouping is enforced by `ifrt::Client::CopyArrays()`. + struct Batch { + std::vector indexes; + std::vector> ifrt_arrays; + }; + absl::flat_hash_map batches; + + for (int i = 0; i < py_arrays.size(); ++i) { + const auto& py_array = py_arrays[i]; + const auto& dst_sharding = dst_shardings[i]; + const auto& array_cs = array_copy_semantics[i]; + + auto* ifrt_array_ptr = py_array.ifrt_array(); + const ifrt::DeviceListRef& src_devices = + ifrt_array_ptr->sharding().devices(); + const ifrt::DeviceListRef& dst_devices = dst_device_lists[i]; + + ifrt::MemoryKind src_memory_kind = + ifrt::CanonicalizeMemoryKind(ifrt_array_ptr->sharding().memory_kind(), + src_devices->devices().front()); + ifrt::MemoryKind dst_memory_kind = ifrt::CanonicalizeMemoryKind( + xla::GetMemoryKind(dst_sharding), dst_devices->devices().front()); + + if (*src_devices == *dst_devices && src_memory_kind == dst_memory_kind && + array_cs == ifrt::ArrayCopySemantics::kReuseInput) { + results[i] = py_arrays[i]; + continue; + } + + auto transfer_guard_formatter = [&py_array, &dst_sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(py_array.aval())), + ", sharding=", + nb::cast(nb::repr(py_array.sharding())), + ", dst_sharding=", + nb::cast(nb::repr(dst_sharding))); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + + Batch& batch = batches[BatchedCopyToDeviceWithShardingKey{ + src_devices, src_memory_kind, dst_devices, dst_memory_kind, array_cs}]; + batch.indexes.push_back(i); + batch.ifrt_arrays.push_back(tsl::FormRef(ifrt_array_ptr)); + } + + std::vector>> ifrt_arrays; + { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + + for (auto& [key, batch] : batches) { + TF_ASSIGN_OR_RETURN( + auto copied, + client->CopyArrays( + absl::MakeSpan(batch.ifrt_arrays), + // All arrays in `batch` have the same `key.dst_devices` and + // `key.dst_memory_kind` due to the grouping above. 
+ key.dst_devices, key.dst_memory_kind, key.array_copy_semantics)); + for (int i = 0; i < batch.indexes.size(); ++i) { + ifrt_arrays.push_back( + std::make_pair(batch.indexes[i], std::move(copied[i]))); + } + } + } + + auto traceback = Traceback::Get(); + for (auto& [i, ifrt_array] : ifrt_arrays) { + const auto& py_array = py_arrays[i]; + absl::Span shape_span = py_array.shape(); + results[i] = + PyArray(py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_shardings[i], py_array.py_client(), traceback, + std::move(ifrt_array), py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } + return results; +} + +absl::StatusOr PyArray::BatchedDevicePut( + nb::object aval, nb::object sharding, std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64) { + if (dst_devices.size() != xs.size()) { + throw nb::value_error( + absl::StrCat("Argument sizes (xs and devices) must match %zu vs %zu", + dst_devices.size(), xs.size()) + .c_str()); + } + for (const PyDevice* device : dst_devices) { + if (device->client().get() == nullptr) { + return InvalidArgument("Cannot copy to unattached devices."); + } + } + auto transfer_guard_formatter = [&aval, &sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(aval)), + ", dst_sharding=", nb::cast(nb::repr(sharding))); + }; + + GlobalPyRefManager()->CollectGarbage(); + + auto n_devices = dst_devices.size(); + + DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + if (!dst_devices.empty()) { + options.ifrt_user_context = + dst_devices.front()->client()->ifrt_client()->CreateUserContext(); + } + + nb::list owning_pylist; + std::vector> ifrt_arrays; + + absl::InlinedVector devices; + devices.reserve(n_devices); + std::vector shapes; + shapes.reserve(n_devices); + + ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding); + + std::vector device_put_fns; + device_put_fns.reserve(xs.size()); + size_t i = 0; + for (auto& x : xs) { + if (PyArray::IsPyArray(x)) { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + } else { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + } + TF_ASSIGN_OR_RETURN( + device_put_fns.emplace_back(), + DevicePut(x, dst_devices[i]->client()->ifrt_client(), + dst_devices[i]->device(), options, dst_memory_kind)); + ++i; + } + std::vector device_puts; + device_puts.reserve(device_put_fns.size()); + { + nb::gil_scoped_release gil_release; + for (auto& device_put_fn : device_put_fns) { + TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)()); + device_puts.push_back(std::move(device_put)); + } + } + for (auto& device_put : device_puts) { + ifrt_arrays.push_back(std::move(device_put.ifrt_array)); + devices.push_back( + ifrt_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(ifrt_arrays.back()->shape()); + if (device_put.owning_pybuffer) { + owning_pylist.append(device_put.owning_pybuffer); + } + } + + // TODO(phawkins): it's highly suspicious to me that owning_pylist isn't + // consumed here. Look into this. 
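+ // Assemble the per-device arrays produced above into a single IFRT array
+ // that matches the requested sharding, then wrap the result in a PyArray.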
+ + auto weak_type = nb::cast(aval.attr("weak_type")); + auto dtype = aval.attr("dtype"); + auto shape = nb::cast>(aval.attr("shape")); + + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + sharding.type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape))); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, DtypeToIfRtDType(dtype)); + // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are + // supported. + ifrt::DType array_dtype = + ifrt_arrays.empty() ? ifrt_dtype : ifrt_arrays.front()->dtype(); + TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding)); + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + py_device_list->py_client() + ->ifrt_client() + ->AssembleArrayFromSingleDeviceArrays( + array_dtype, ifrt::Shape(shape), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), + xla::ifrt::ArrayCopySemantics::kReuseInput, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + + return PyArray(aval, weak_type, dtype, std::move(shape), sharding, + py_device_list->py_client(), Traceback::Get(), + std::move(ifrt_array), committed, /*skip_checks=*/true); +} + +absl::StatusOr PyArray::ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + xla::ifrt::Array* ifrt_array_ptr = x.ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return absl::InvalidArgumentError( + "Reorder() called on deleted or donated buffer"); + } + + ifrt::Client* const client = ifrt_array_ptr->client(); + + const auto& device_list = ifrt_array_ptr->sharding().devices(); + TF_ASSIGN_OR_RETURN(auto dst_device_list, GetIfrtDeviceList(dst_sharding)); + if (device_list->AddressableDeviceList()->size() != + dst_device_list->AddressableDeviceList()->size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array is expected to have ", + dst_device_list->AddressableDeviceList()->size(), + " addressable shards, but has ", + device_list->AddressableDeviceList()->size(), " addressable shards")); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr dst_ifrt_sharding, + GetIfrtConcreteEvenSharding(dst_sharding, ifrt_array_ptr->dtype(), + ifrt_array_ptr->shape())); + + tsl::RCReference new_ifrt_array; + { + nb::gil_scoped_release gil_release; + + const absl::Span addressable_devices = + device_list->AddressableDeviceList()->devices(); + const absl::Span dst_addressable_devices = + dst_device_list->AddressableDeviceList()->devices(); + + absl::flat_hash_map device_id_to_array_shard_index; + device_id_to_array_shard_index.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + const int device_id = dst_addressable_devices[i]->Id().value(); + const bool inserted = + device_id_to_array_shard_index.insert({device_id, i}).second; + if (!inserted) { + return absl::InvalidArgumentError( + absl::StrCat("Sharding contains duplicate device id=", device_id)); + } + } + + std::vector from_shard_indices; + from_shard_indices.reserve(addressable_devices.size()); + std::vector to_shard_indices; + to_shard_indices.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + from_shard_indices.push_back(i); + const int shard_device_id = addressable_devices[i]->Id().value(); + const auto it = device_id_to_array_shard_index.find(shard_device_id); + if (it == device_id_to_array_shard_index.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array shard ", 
i, " is on device id=", shard_device_id, + ", but sharding does not have a shard on that device.")); + } + to_shard_indices.push_back(it->second); + } + + auto mappings = + std::make_shared>(); + { + auto& mapping = mappings->emplace_back(); + mapping.in_array = 0; + mapping.out_array = 0; + mapping.from.reserve(dst_addressable_devices.size()); + mapping.to.reserve(dst_addressable_devices.size()); + for (int64_t i = 0; i < dst_addressable_devices.size(); ++i) { + mapping.from.push_back(xla::ifrt::RemapPlan::Interval{ + from_shard_indices[i], from_shard_indices[i] + 1, 1}); + mapping.to.push_back(xla::ifrt::RemapPlan::Interval{ + to_shard_indices[i], to_shard_indices[i] + 1, 1}); + } + } + + xla::ifrt::RemapPlan plan = { + /*input_specs=*/{xla::ifrt::ArraySpec{ + /*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/ifrt_array_ptr->shared_ptr_sharding()}}, + /*output_specs=*/ + {xla::ifrt::ArraySpec{/*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/std::move(dst_ifrt_sharding)}}, + /*mappings=*/std::move(mappings), + }; + DCHECK_OK(plan.Validate()); + std::vector> input; + input.push_back(tsl::FormRef(ifrt_array_ptr)); + TF_ASSIGN_OR_RETURN( + auto remapped, + client->RemapArrays(plan, absl::MakeSpan(input), array_copy_semantics)); + + TF_RET_CHECK(remapped.size() == 1); + new_ifrt_array = std::move(remapped.front()); + } + + return xla::PyArray(nb::borrow(x.aval().ptr()), x.weak_type(), + nb::borrow(x.dtype().ptr()), + std::vector(x.shape().begin(), x.shape().end()), + std::move(dst_sharding), x.py_client(), x.traceback(), + std::move(new_ifrt_array), + /*committed=*/true, + /*skip_checks=*/true); +} + +absl::Status PyArray::BatchedBlockUntilReady(std::vector objs) { + // Create ready futures for all arrays before blocking on their readiness. + // This helps reduce the latency in some backend implementations where + // querying readiness of an array is not free. 
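+ // The loop below only validates the inputs and collects the underlying
+ // ifrt::Array pointers; the GIL is released once, around the single
+ // AwaitBuffersReady() call.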
+ + std::vector ifrt_arrays; + ifrt_arrays.reserve(objs.size()); + for (nb::handle obj : objs) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + ifrt::Array* const ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return absl::InvalidArgumentError( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt_arrays.push_back(ifrt_array); + } else { + return absl::InvalidArgumentError( + "PyArray::BatchedBlockUntilReady can take PyArray only"); + } + } + + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + return AwaitBuffersReady(absl::MakeConstSpan(ifrt_arrays)); +} + +absl::Status PyArray::ReplaceWithAlias(PyArray o) { + auto& storage = GetStorage(); + auto& o_storage = o.GetStorage(); + if (storage.py_client.get() != o_storage.py_client.get()) { + return absl::InvalidArgumentError( + "Unable to replace a PyArray with a PyArray from a different client."); + } + storage.aval = o_storage.aval; + storage.weak_type = o_storage.weak_type; + storage.dtype = o_storage.dtype; + storage.shape = o_storage.shape; + storage.sharding = o_storage.sharding; + storage.npy_value = o_storage.npy_value; + storage.committed = o_storage.committed; + storage.traceback = o_storage.traceback; + storage.ifrt_array = o_storage.ifrt_array; + storage.fully_replicated_array = o_storage.fully_replicated_array; + storage.py_arrays = o_storage.py_arrays; + storage.host_value.Clear(); + storage.dynamic_shape = o_storage.dynamic_shape; + storage.result_status = o_storage.result_status; + + return absl::OkStatus(); +} + +std::vector PyClient::LiveArrays() const { + std::vector result; + for (auto& shard : arrays_) { + nb::ft_lock_guard lock(shard.mutex); + for (PyArray::Storage* array = shard.arrays; array; array = array->next) { + bool all_deleted = + (array->ifrt_array == nullptr || array->ifrt_array->IsDeleted()); + if (!all_deleted) { + result.push_back(nb::borrow(array->AsHandle())); + } + } + } + return result; +} + +// PEP 3118 buffer protocol implementation. + +namespace { + +// Extra data to be kept alive by the consumer of the buffer protocol. +struct ExtraBufferInfo { + explicit ExtraBufferInfo( + std::shared_ptr buffer, + std::unique_ptr external_reference_hold) + : buffer(std::move(buffer)), + external_reference_hold(std::move(external_reference_hold)) {} + + std::vector strides; + // We keep an external reference hold to the PjRtBuffer. This prevents a + // use-after-free in the event that Delete() is called on a buffer with an + // live buffer protocol view. It does however mean that Delete() sometimes + // won't actually delete immediately. + std::shared_ptr buffer; + std::unique_ptr external_reference_hold; +}; + +// The default layout of a non-tuple array should have major-to-minor layout +// and no tiles. +bool HasDefaultLayout(const Layout& layout) { + return LayoutUtil::IsMonotonicWithDim0Major(layout) && layout.tiles().empty(); +} + +int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { + absl::Status status = [&]() -> absl::Status { + PyArray py_array = nb::borrow(exporter); + if (py_array.ifrt_array() == nullptr) { + // TODO(phawkins): why is this happening? 
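+ // A null ifrt_array means the jax.Array was already deleted or donated;
+ // failing here surfaces a BufferError instead of exporting a dangling view.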
+ return InvalidArgument("Array is null"); + } + if (!llvm::isa(py_array.ifrt_array())) { + return InvalidArgument("Only local arrays are supported, got %s", + py_array.ifrt_array()->DebugString()); + } + auto* array = + static_cast(py_array.ifrt_array()); + absl::Span> buffers = + array->pjrt_buffers(); + + PjRtBuffer& buffer = *buffers.front(); + if (!buffer.IsOnCpu()) { + return InvalidArgument( + "Python buffer protocol is only defined for CPU buffers."); + } + + if (buffers.size() != 1) { + return InvalidArgument( + "Python buffer protocol is only defined for buffers with a single " + "shard."); + } + if (!py_array.sharding().type().is(jax::SingleDeviceSharding::type())) { + return InvalidArgument( + "Python buffer protocol is only defined for single-device sharded " + "buffers."); + } + + const char* format = + PEP3118FormatDescriptorForPrimitiveType(buffer.element_type()); + // It isn't an option for us to export unknown types as, say, bytes. When + // converting an object to an ndarray, NumPy tries the buffer protocol + // first. We very much want NumPy to fail and fall back to using + // __array__, which allows us to handle custom dtypes correctly. + if (!format) { + return InvalidArgument( + "Buffers of type %s are not supported by the Python buffer protocol.", + PrimitiveType_Name(buffer.element_type())); + } + + std::unique_ptr external_reference_hold; + { + // We call BlockHostUntilReady() below, which may block. + nb::gil_scoped_release gil_release; + + if (buffer.IsTuple()) { + return InvalidArgument( + "Python buffer protocol is only defined for array buffers."); + } + if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) { + return InvalidArgument("XLA buffers are read-only."); + } + TF_ASSIGN_OR_RETURN(external_reference_hold, + buffer.AcquireExternalReference()); + if (buffer.IsDeleted()) { + return InvalidArgument("Deleted buffer used in buffer protocol."); + } + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = buffer.layout()->xla_layout(); + + if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS || + (flags & PyBUF_STRIDES) == PyBUF_ND) && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + return InvalidArgument("Buffer is not in C-contiguous layout."); + } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in F-contiguous layout."); + } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout) && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in contiguous layout."); + } else if (!HasDefaultLayout(xla_layout)) { + // Fail and fall back to using __array__ if the CPU buffer has a device + // specific layout. For instance, this happens for host buffers in + // pinned memories of the TPU device. + return InvalidArgument( + "Buffer is potentially a device buffer with non default layout."); + } + TF_RETURN_IF_ERROR(buffer.GetReadyFuture().Await()); + } + + // We must hold the GIL (or at least prevent Python GC) while writing to the + // view object, see https://github.com/python/cpython/issues/130409. 
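+ // The ExtraBufferInfo stored in view->internal keeps the PjRtBuffer and its
+ // external reference alive for the lifetime of the view; it is freed again
+ // in PyArray_bf_releasebuffer.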
+ std::memset(view, 0, sizeof(Py_buffer)); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + view->buf = const_cast(root_ptr); + auto extra = std::make_unique( + buffers.front(), std::move(external_reference_hold)); + view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(buffer.element_type()); + TF_ASSIGN_OR_RETURN(view->len, buffer.GetOnDeviceSizeInBytes()); + view->readonly = 1; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(format); + } + if ((flags & PyBUF_ND) == PyBUF_ND) { + view->ndim = buffer.dimensions().size(); + static_assert(sizeof(int64_t) == sizeof(Py_ssize_t), + "Py_ssize_t must be 64 bits"); + if (view->ndim != 0) { + view->shape = reinterpret_cast( + const_cast(buffer.dimensions().data())); + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + extra->strides = + ByteStridesForShape(buffer.element_type(), buffer.dimensions(), + buffer.layout()->xla_layout()); + view->strides = reinterpret_cast( + const_cast(extra->strides.data())); + } + } + } + view->internal = extra.release(); + return absl::OkStatus(); + }(); + if (!status.ok()) { + // numpy.asarray(...) eats the PyExc_BufferError. Adding a log here helps + // debugging when the error really occurs. + VLOG(1) << "Buffer Protocol Error: " << status; + PyErr_SetString(PyExc_BufferError, status.ToString().c_str()); + return -1; + } + view->obj = exporter; + Py_INCREF(view->obj); + return 0; +} + +void PyArray_bf_releasebuffer(PyObject*, Py_buffer* buffer) { + auto extra = static_cast(buffer->internal); + delete extra; +} + +// Returns if shape has a major-to-minor layout. +bool HasMajorToMinorLayout(const xla::Shape& shape) { + if (shape.has_layout()) { + for (int i = 0; i < shape.layout().minor_to_major_size(); ++i) { + if (shape.layout().minor_to_major(i) != + shape.layout().minor_to_major_size() - 1 - i) { + return false; + } + } + } + return true; +} + +// Returns byte_strides if shape has a non-major-to-minor layout. +std::optional> ByteStridesOrDefaultForShapeInt64( + const Shape& shape) { + if (!shape.has_layout() || HasMajorToMinorLayout(shape)) { + return std::nullopt; + } + return ByteStridesForShape(shape); +} + +bool IsZeroCopyableCpuBuffer(const PjRtBuffer* buf) { + // For CPU buffers with device-specific layouts, we must delinearize + // to unpack the array. This could happen for the host buffer + // pre-mapped to the TPU device, a.k.a., pinned host buffers for the + // device. + bool has_default_layout = + buf->layout() == nullptr || HasDefaultLayout(buf->layout()->xla_layout()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + return buf->IsOnCpu() && + !primitive_util::IsSubByteNonPredType(buf->element_type()) && + has_default_layout; +} +} // namespace + +PyHostValue::PyHostValue() = default; +PyHostValue::~PyHostValue() = default; + +absl::StatusOr> PyHostValue::AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ifrt_array->IsDeleted()) { + return InvalidArgument("DeviceArray has been deleted."); + } + // The only `jax.Array` with token-shape buffer is the one wrapped by + // `jax.core.Token`. Since it is an internal implementation detail, we + // don't support converting it to a numpy array. 
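+ // For all other dtypes the conversion returns a (ndarray, did_copy) pair:
+ // default-layout CPU buffers are aliased zero-copy, anything else is copied
+ // to the host first.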
+ if (ifrt_array->dtype().kind() == ifrt::DType::kToken) { + return InvalidArgument( + "Cannot convert a token-shape buffer to a numpy array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr) { + auto* pjrt_buffer = arr->pjrt_buffers().front().get(); + TF_RET_CHECK(!pjrt_buffer->IsTuple()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + if (IsZeroCopyableCpuBuffer(pjrt_buffer)) { + TF_ASSIGN_OR_RETURN(const auto* shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(shape->element_type())); + // Objects that must be kept alive while the array is alive. + struct Hold { + tsl::RCReference buffer; + std::unique_ptr external_reference_hold; + }; + auto hold = std::make_unique(); + hold->buffer = tsl::FormRef(ifrt_array); + auto* hold_ptr = hold.release(); + nb::capsule hold_capsule( + hold_ptr, [](void* h) noexcept { delete static_cast(h); }); + { + // Release the GIL as `AcquireExternalReference` may block. + nb::gil_scoped_release gil; + TF_ASSIGN_OR_RETURN(hold_ptr->external_reference_hold, + pjrt_buffer->AcquireExternalReference()); + auto fut = ifrt_array->GetReadyFuture(); + BlockUntilReadyWithCancel(fut); + TF_RETURN_IF_ERROR(fut.Await()); + } + void* data = + hold_ptr->external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb_numpy_ndarray array(dtype, shape->dimensions(), + ByteStridesForShape(*shape), data, hold_capsule); + array.attr("flags").attr("writeable") = nb::bool_(false); + return std::make_pair(array, false); + } + } + + TF_RETURN_IF_ERROR(CopyToHostAsync(dynamic_shape_holder, ifrt_array)); + if (!ready_.IsReady()) { + nb::gil_scoped_release gil; + BlockUntilReadyWithCancel(ready_); + TF_RETURN_IF_ERROR(ready_.Await()); + } else { + TF_RETURN_IF_ERROR(ready_.Await()); + } + if (string_array_contents_ != nullptr) { + TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array)); + } + return std::make_pair(value_, true); +} + +absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray( + ifrt::Array* ifrt_array) { +#ifdef NPY_2_0_API_VERSION + if (PyArray_RUNTIME_VERSION < NPY_2_0_API_VERSION) { + return absl::FailedPreconditionError( + absl::StrCat("String arrays are not supported in NumPy version: ", + PyArray_RUNTIME_VERSION)); + } + auto numpy_dtype = nb::steal( + reinterpret_cast(PyArray_DescrFromType(NPY_VSTRING))); + value_ = nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(), + /*strides=*/std::nullopt); + + auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(value_.ptr()); + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(dst_py_array_obj))); + for (auto& cord : *string_array_contents_) { + absl::string_view input_str_view = cord.Flatten(); + auto py_unicode = nb::steal(PyUnicode_FromStringAndSize( + input_str_view.data(), input_str_view.size())); + if (py_unicode.ptr() == nullptr) { + return absl::InternalError("PyUnicode_FromStringAndSize failed"); + } + if (PyArray_SETITEM(dst_py_array_obj, + static_cast(PyArray_ITER_DATA(iter.ptr())), + py_unicode.ptr()) != 0) { + return absl::InternalError("PyArray_SETITEM failed"); + } + PyArray_ITER_NEXT(iter.ptr()); + } + + value_.attr("flags").attr("writeable") = nb::bool_(false); + + string_array_contents_.reset(); + + return absl::OkStatus(); +#else + return absl::FailedPreconditionError( + "String arrays are not supported in this NumPy version."); +#endif +} + +absl::Status 
PyHostValue::CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype())); + auto shape = ifrt_array->shape(); + + // Allocate a vector of cords to hold the contents of the array until + // they are until they are ultimately converted to a numpy array as part + // of the `AsNumPyArray` call. + string_array_contents_ = + std::make_shared>(shape.num_elements()); + ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(), + /*byte_strides=*/std::nullopt, + ifrt::ArrayCopySemantics::kAlwaysCopy); + + ready_.OnReady( + [string_array_contents = string_array_contents_](absl::Status) { + }); // Keeps the cords alive until the copy is done. + + return absl::OkStatus(); +} + +absl::Status PyHostValue::CopyToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ready_.IsValid()) { + // The array value has been populated, so CopyToHostAsync has been called. + return absl::OkStatus(); + } + + // Copying in Arrays of type kString requires some special handling + if (ifrt_array->dtype().kind() == ifrt::DType::kString) { + return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array); + } + + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() && + IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) { + return absl::OkStatus(); + } + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + // TODO(b/182461453): This is a blocking call. If we further implemented + // populating dynamic shape metadata while fetching the literal, we wouldn't + // need this static approach. + const xla::Shape* dynamic_shape; + std::optional shape_holder; + if (llvm::isa(ifrt_array)) { + TF_ASSIGN_OR_RETURN(dynamic_shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + } else { + // Skip querying the dynamic shape for a non-PjRt Array. + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + shape_holder = ShapeUtil::MakeShapeWithDescendingLayout( + type, ifrt_array->shape().dims()); + dynamic_shape = &*shape_holder; + } + + xla::Shape host_shape = ShapeUtil::DeviceShapeToHostShape(*dynamic_shape); + + auto strides = ByteStridesOrDefaultForShapeInt64(host_shape); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(host_shape.element_type())); + value_ = nb_numpy_ndarray(dtype, host_shape.dimensions(), strides); + // TODO(hyeontaek): Several PjRt runtimes assume that the host buffer uses + // the same transposition as the device buffer. This is different from + // PjRtBuffer::ToLiteral()'s semantics that the runtime respects the layout + // of the host buffer literal. On the other hand, the runtime often knows + // better about an efficient layout for the host buffer. 
It will be useful + // to revisit the semantics of PjRtBuffer::ToLiteral() to see if it is + // desirable for the runtime to choose the layout. + ready_ = ifrt_array->CopyToHostBuffer(value_.mutable_data(), strides, + ifrt::ArrayCopySemantics::kReuseInput); + // Make sure the destination of the copy remains alive until the copy is done. + value_.inc_ref(); + ready_.OnReady([array{value_.ptr()}](absl::Status status) { + GlobalPyRefManager()->AddGarbage(nb::steal(array)); + }); + value_.attr("flags").attr("writeable") = nb::bool_(false); + return absl::OkStatus(); +} + +void PyHostValue::Clear() { + ready_ = {}; + value_ = {}; + string_array_contents_ = {}; +} + +namespace { +PyMemberDef PyBaseArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyBaseArrayObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PyBaseArray_slots[] = { + {Py_tp_dealloc, reinterpret_cast(PyBaseArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyBaseArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyBaseArray_tp_traverse)}, + {Py_tp_hash, reinterpret_cast(PyObject_HashNotImplemented)}, + {0, nullptr}, +}; + +PyGetSetDef PyArray_tp_getset[] = { + {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, + nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}, +}; + +PyMemberDef PyArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, weakrefs)), READONLY, + nullptr}, + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, dict)), READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; // namespace xla + +PyType_Slot PyArray_slots[] = { + {Py_tp_new, reinterpret_cast(PyArray_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PyArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyArray_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PyArray_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PyArray_tp_getset)}, + {Py_bf_getbuffer, reinterpret_cast(PyArray_bf_getbuffer)}, + {Py_bf_releasebuffer, reinterpret_cast(PyArray_bf_releasebuffer)}, + {0, nullptr}, +}; + +} // namespace + +absl::Status PyArray::RegisterTypes(nb::module_& m) { + // We are not using nanobind to avoid having a non-standard metaclass, which + // would make Array incompatible with abc.ABCMeta. + std::string base_name = + absl::StrCat(nb::cast(m.attr("__name__")), ".Array"); + PyType_Spec PyBaseArray_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. 
+ /*.name=*/strdup(base_name.c_str()), +#else + /*.name=*/base_name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PyBaseArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyBaseArray_slots}; + auto* base_type = PyType_FromSpec(&PyBaseArray_spec); + if (!base_type) { + throw nb::python_error(); + } + m.attr("Array") = nb::borrow(base_type); + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayImpl"); + + PyType_Spec PyArray_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PyArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_DICT | Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyArray_slots, + }; + + type_ = PyType_FromSpecWithBases(&PyArray_spec, base_type); + if (!type_) { + throw nb::python_error(); + } + auto type = nb::borrow(type_); + m.attr("ArrayImpl") = type; + + type.attr("__init__") = nb::cpp_function( + [](PyArray self, nb::object aval, nb::object sharding, nb::list arrays, + bool committed, bool skip_checks) { + if (!(arrays.size() == 0 || arrays[0].type().is(PyArray::type()))) { + throw nb::type_error( + absl::StrCat( + "Unsupported type for elements in `arrays`: ", + nb::cast(nb::str(arrays[0].type()))) + .c_str()); + } + auto py_arrays = nb::cast>(arrays); + PyArray::PyInit(self, std::move(aval), std::move(sharding), py_arrays, + committed, skip_checks); + }, + nb::is_method(), nb::arg("aval"), nb::arg("sharding"), nb::arg("arrays"), + nb::arg("committed"), nb::arg("_skip_checks") = false); + type.attr("delete") = nb::cpp_function( + [](PyArray& self) { xla::ThrowIfError(self.Delete()); }, nb::is_method()); + type.attr("_sharding") = nb_property_readonly(&PyArray::sharding); + type.attr("aval") = nb_property(&PyArray::aval, &PyArray::set_aval); + type.attr("_arrays") = + nb_property(&PyArray::arrays, [](PyArray& self, nb::object obj) { + xla::ThrowIfError(self.set_arrays(obj)); + }); + type.attr("_fully_replicated_shard") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.FullyReplicatedShard()); + }, + nb::is_method()); + type.attr("_npy_value") = + nb_property(&PyArray::npy_value, &PyArray::set_npy_value); + type.attr("_committed") = nb_property_readonly(&PyArray::committed); + type.attr("unsafe_buffer_pointer") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.UnsafeBufferPointer()); + }, + nb::is_method()); + type.attr("__cuda_array_interface__") = nb_property_readonly( + [](PyArray self) { return self.CudaArrayInterface(); }); + type.attr("_pjrt_layout") = + nb_property_readonly(xla::ValueOrThrowWrapper(&PyArray::layout)); + type.attr("on_device_size_in_bytes") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::GetOnDeviceSizeInBytes), + nb::is_method()); + 
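Note on the binding pattern used in `RegisterTypes` above: the `Array` and `ArrayImpl` classes are built with the raw CPython `PyType_FromSpec` / `PyType_FromSpecWithBases` API rather than `nb::class_`, so no nanobind metaclass is involved and the types stay compatible with `abc.ABCMeta`; the C++-implemented methods are then attached one by one as `nb::cpp_function` attributes. A minimal, self-contained sketch of that hybrid, assuming only public CPython and nanobind APIs (the `demo.Counter` names are illustrative, not part of this change):

// Illustrative only: "demo", "Counter" and "increment" are made-up names.
// The class object comes from PyType_FromSpec (no nanobind metaclass), and a
// C++ method is bolted on afterwards with nb::cpp_function + nb::is_method().
#include <Python.h>
#include <nanobind/nanobind.h>

namespace nb = nanobind;

namespace {

struct CounterObject {
  PyObject_HEAD
  long value;  // zero-initialized by the default allocator
};

void Counter_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  tp->tp_free(self);
  Py_DECREF(tp);  // instances of heap types own a reference to their type
}

PyType_Slot counter_slots[] = {
    {Py_tp_dealloc, reinterpret_cast<void*>(Counter_dealloc)},
    {0, nullptr},
};

PyType_Spec counter_spec = {
    /*.name=*/"demo.Counter",
    /*.basicsize=*/static_cast<int>(sizeof(CounterObject)),
    /*.itemsize=*/0,
    /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
    /*.slots=*/counter_slots,
};

}  // namespace

NB_MODULE(demo, m) {
  PyObject* raw = PyType_FromSpec(&counter_spec);
  if (!raw) throw nb::python_error();
  nb::object type = nb::steal(raw);
  m.attr("Counter") = type;

  // Same trick as ArrayImpl: a lambda wrapped in nb::cpp_function and marked
  // nb::is_method() becomes an ordinary method on the hand-built heap type.
  type.attr("increment") = nb::cpp_function(
      [](nb::handle self) {
        return ++reinterpret_cast<CounterObject*>(self.ptr())->value;
      },
      nb::is_method());
}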
type.attr("_single_device_array_to_np_array_did_copy") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::SingleDeviceArrayToNumpyArrayDidCopy), + nb::is_method()); + type.attr("_copy_single_device_array_to_host_async") = nb::cpp_function( + [](PyArray& self) { + xla::ThrowIfError(self.CopySingleDeviceArrayToHostAsync()); + }, + nb::is_method()); + type.attr("_replace_with") = nb::cpp_function( + [](PyArray& self, PyArray& o) { + xla::ThrowIfError(self.ReplaceWithAlias(o)); + }, + nb::is_method()); + type.attr("block_until_ready") = nb::cpp_function( + [](PyArray self) -> nb::object { + xla::ThrowIfError(self.BlockUntilReady()); + return self; + }, + nb::is_method()); + type.attr("platform") = nb::cpp_function( + [](PyArray self) { + if (self.ifrt_array()->client()->platform_name() == "cuda" || + self.ifrt_array()->client()->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return self.ifrt_array()->client()->platform_name(); + } + }, + nb::is_method()); + type.attr("is_ready") = nb::cpp_function( + [](PyArray self) { return xla::ValueOrThrow(self.IsReady()); }, + nb::is_method()); + type.attr("is_deleted") = + nb::cpp_function(&PyArray::IsDeleted, nb::is_method()); + type.attr("traceback") = nb_property_readonly(&PyArray::traceback); + type.attr("clone") = nb::cpp_function(&PyArray::Clone, nb::is_method()); + type.attr("__module__") = m.attr("__name__"); + + m.attr("batched_copy_array_to_devices_with_sharding") = nb::cpp_function( + [](absl::Span arrays, + absl::Span> dst_device_lists, + absl::Span shardings, + absl::Span array_copy_semantics) { + if (arrays.empty()) { + return std::vector(); + } + auto* client = arrays[0].ifrt_array()->client(); + std::vector device_lists; + device_lists.reserve(dst_device_lists.size()); + for (const auto& dst_devices : dst_device_lists) { + absl::InlinedVector devices; + devices.reserve(dst_devices.size()); + for (auto& d : dst_devices) { + devices.push_back(d->device()); + } + device_lists.push_back(client->MakeDeviceList(devices)); + } + return xla::ValueOrThrow(PyArray::BatchedCopyToDeviceWithSharding( + arrays, device_lists, shardings, array_copy_semantics)); + }); + m.attr("array_result_handler") = nb::cpp_function( + [](nb::object aval, nb::object sharding, bool committed, + bool skip_checks) -> nb_class_ptr { + return make_nb_class( + std::move(aval), std::move(sharding), committed, skip_checks); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("committed"), + nb::arg("_skip_checks") = false); + + nb::class_(m, "ResultHandler") + .def("__call__", [](const PyArrayResultHandler& self, + PyArray arg) { return self.Call(arg); }) + .def("__call__", + [](const PyArrayResultHandler& self, + std::vector py_arrays) { return self.Call(py_arrays); }); + + return absl::OkStatus(); +} + +} // namespace xla diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h new file mode 100644 index 000000000000..7c7a6fefe3a2 --- /dev/null +++ b/jaxlib/xla/py_array.h @@ -0,0 +1,365 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_ARRAY_H_ +#define JAXLIB_XLA_PY_ARRAY_H_ + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/traceback.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +// Private to PyArray, but you cannot forward declare member classes. +// Not thread safe; assumes the GIL is held. +class PyHostValue { + public: + PyHostValue(); + ~PyHostValue(); + + PyHostValue(const PyHostValue&) = delete; + PyHostValue(PyHostValue&&) = delete; + PyHostValue& operator=(const PyHostValue&) = delete; + PyHostValue& operator=(PyHostValue&&) = delete; + + absl::Status CopyToHostAsync(std::optional& dynamic_shape_holder, + ifrt::Array* ifrt_array); + + absl::StatusOr> AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + + void Clear(); + + private: + absl::Status CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + + absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array); + + ifrt::Future<> ready_; + nb_numpy_ndarray value_; + + // Optional field, only used for arrays of type kString. This vector of cords + // serves as input buffer for the CopyToHostBuffer call. It holds these + // contents until it is lazily converted it to a numpy array when the user + // calls `AsNumPyArray`. + std::shared_ptr> string_array_contents_; +}; + +// Private to PyArray, but you cannot forward declare member classes. +struct PyArray_Storage { + PyArray_Storage(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + bool committed, nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, + xla::PjRtFuture<> result_status); + + ~PyArray_Storage(); + nanobind::handle AsHandle(); + + nanobind::object aval; + bool weak_type = false; + nb_dtype dtype; + std::vector shape; + + nanobind::object sharding; + nanobind::object npy_value = nanobind::none(); + bool committed = false; + + nb_class_ptr py_client; + std::optional traceback; + tsl::RCReference ifrt_array; + nanobind::object fully_replicated_array = nanobind::none(); + + // optional field, used only in python + std::vector py_arrays; + PyHostValue host_value; // Protected by the GIL. + std::optional dynamic_shape = std::nullopt; + // Only set if this Array was generated by a computation that has effects. + // This is the result status of the XLA computation that generated this + // array. + xla::PjRtFuture<> result_status; + + // Doubly-linked list of all PyArrays known to the client. Protected by the + // GIL. 
Since multiple PyArrays may share the same PjRtBuffer, there may be + // duplicate PjRtBuffers in this list. + PyArray_Storage* next; + PyArray_Storage* prev; + + uint8_t thread_id_bucket; +}; + +// The C++ implementation of jax.Array. A few key methods and data members are +// implemented in C++ for performance, while most of the functionalities are +// still implemented in python. +class PyArray : public nanobind::object { + public: + NB_OBJECT(PyArray, nanobind::object, "Array", PyArray::IsPyArray); + PyArray() = default; + + // "__init__" methods. Only used in python + static void PyInit(PyArray self, nanobind::object aval, + nanobind::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks); + + // Only used in C++. `skip_checks` should only be set for Arrays created by + // jax that cannot possibly have consistency issues (e.g. `sharding` devices + // different than `ifrt_array` devices). Arrays created by users should be + // checked. + PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, bool committed, + bool skip_checks, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromSingleDeviceArray( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, nanobind::object sharding, + bool weak_type, bool committed, bool skip_checks); + + static absl::Status RegisterTypes(nanobind::module_& m); + + static PyArray borrow(PyObject* ptr) { + return nanobind::borrow(ptr); + } + + using Storage = PyArray_Storage; + + const nanobind::object& aval() const { return GetStorage().aval; } + void set_aval(nanobind::object aval) { GetStorage().aval = std::move(aval); } + + bool weak_type() const { return GetStorage().weak_type; } + + const nb_dtype& dtype() const { return GetStorage().dtype; } + absl::Span shape() const { return GetStorage().shape; } + + const nanobind::object& sharding() const { return GetStorage().sharding; } + + absl::StatusOr> layout() { + return ifrt_array()->layout(); + } + + bool committed() const { return GetStorage().committed; } + + const nanobind::object& npy_value() const { return GetStorage().npy_value; } + void set_npy_value(nanobind::object v) { + GetStorage().npy_value = std::move(v); + } + + const nb_class_ptr& py_client() const { + return GetStorage().py_client; + } + + const std::optional& traceback() const { + return GetStorage().traceback; + } + + // Returns xla::InvalidArgument if the buffer has been deleted. + // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`. + absl::StatusOr IsReady() { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr->IsDeleted()) { + return InvalidArgument("Array has been deleted."); + } + return ifrt_array_ptr->GetReadyFuture().IsReady(); + } + + const xla::PjRtFuture<>& result_status() const { + return GetStorage().result_status; + } + + ifrt::Array* ifrt_array() const { return GetStorage().ifrt_array.get(); } + + // Short-term escape hatch to get PjRtBuffers from PyArray. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. 
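The `next`/`prev` members above make `PyArray_Storage` an intrusive doubly-linked list node: every live array links itself into a per-client list (sharded under free-threading, see `ArraysShard` in py_client.h) so that `LiveArrays()` and `HeapProfile()` can walk all live arrays without owning them, and an array can unlink itself in O(1) on destruction. A stripped-down sketch of that registry idea, with illustrative names and none of the GIL or sharding details:

// Not the jaxlib code: a minimal intrusive registry in the spirit of
// PyArray_Storage::next/prev. Nodes push themselves onto a head pointer at
// construction and unlink in O(1) at destruction; the owner only iterates.
#include <cstdio>

struct Registry;

struct Node {
  explicit Node(Registry& r);
  ~Node();
  int payload = 0;
  Registry* registry;
  Node* next = nullptr;  // intrusive links, as in PyArray_Storage
  Node* prev = nullptr;
};

struct Registry {
  Node* head = nullptr;  // plays the role of the per-shard `arrays` head
  template <typename F>
  void ForEach(F&& f) const {
    for (Node* n = head; n != nullptr; n = n->next) f(*n);
  }
};

Node::Node(Registry& r) : registry(&r) {
  next = r.head;  // push-front in O(1)
  if (next != nullptr) next->prev = this;
  r.head = this;
}

Node::~Node() {  // O(1) unlink, no search required
  if (prev != nullptr) prev->next = next; else registry->head = next;
  if (next != nullptr) next->prev = prev;
}

int main() {
  Registry client;
  Node a(client), b(client);
  a.payload = 1;
  b.payload = 2;
  client.ForEach([](const Node& n) { std::printf("live node %d\n", n.payload); });
  return 0;
}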
+ absl::Span> pjrt_buffers() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return {}; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers(); + } + + int num_addressable_shards() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + // TODO(hyeontaek): Add num_addressable_shards to ifrt. + return num_shards(); + } + return arr->pjrt_buffers().size(); + } + + std::vector& py_arrays() { return GetStorage().py_arrays; } + const std::vector& py_arrays() const { + return GetStorage().py_arrays; + } + const std::vector& py_arrays_cached(); + + nanobind::object arrays(); + absl::Status set_arrays(nanobind::object obj); + absl::StatusOr FullyReplicatedShard(); + + int num_shards() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + return ifrt_array_ptr->sharding().devices()->size(); + } + + static nanobind::handle type() { + DCHECK(type_); + return nanobind::handle(type_); + } + + static bool IsPyArray(nanobind::handle arg) { + return arg.type().is(PyArray::type()); + } + + absl::Status BlockUntilReady() const; + + absl::Status BlockUntilResultStatusIsReady(); + + absl::StatusOr GetOnDeviceSizeInBytes(); + absl::StatusOr> + SingleDeviceArrayToNumpyArrayDidCopy(); + absl::StatusOr SingleDeviceArrayToNumpyArray(); + absl::Status CopySingleDeviceArrayToHostAsync(); + nanobind::dict CudaArrayInterface(); + absl::StatusOr UnsafeBufferPointer(); + + absl::Status Delete(); + + bool IsDeleted() const; + + PyArray Clone() const; + + static absl::StatusOr> BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics); + + static absl::StatusOr BatchedDevicePut( + nanobind::object aval, nanobind::object sharding, + std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64); + + static absl::StatusOr ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics); + + static absl::Status BatchedBlockUntilReady( + std::vector objs); + + absl::Status ReplaceWithAlias(PyArray o); + + private: + absl::StatusOr AssertUnsharded(absl::string_view api); + + nanobind::object CheckAndRearrange(absl::Span py_arrays, + nanobind::object sharding, + nanobind::object aval); + + void SetIfrtArray(tsl::RCReference ifrt_array); + + Storage& GetStorage(); + const Storage& GetStorage() const; + + inline static PyObject* type_ = nullptr; +}; + +class PyArrayResultHandler { + public: + PyArrayResultHandler(nanobind::object aval, nanobind::object sharding, + bool committed, bool skip_checks); + + PyArray Call(absl::Span py_arrays) const; + PyArray Call(PyArray py_array) const; + + PyArray Call(nb_class_ptr py_client, + tsl::RCReference ifrt_array, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()) const; + + private: + nanobind::object aval_; + nanobind::object sharding_; + bool weak_type_; + bool committed_; + bool skip_checks_; + + nb_dtype dtype_; + std::vector shape_; +}; + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nanobind::dict& cai, nb_class_ptr cuda_client, + std::optional device_id); + +} // 
namespace xla + +#endif // JAXLIB_XLA_PY_ARRAY_H_ diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc new file mode 100644 index 000000000000..c4e8449b85c3 --- /dev/null +++ b/jaxlib/xla/py_client.cc @@ -0,0 +1,829 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_client.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/guard_lib.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_host_callback.h" +#include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/traceback.h" +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/python/types.h" +#include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include 
"xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla { + +namespace nb = nanobind; + +/*static*/ nb_class_ptr PyClient::Make( + std::shared_ptr ifrt_client) { + auto client = make_nb_class(std::move(ifrt_client)); + Initialize(client); + return client; +} + +PyClient::PyClient(std::shared_ptr ifrt_client) + : ifrt_client_(std::move(ifrt_client)), + client_attributes_(ifrt_client_->Attributes()) { + CHECK(ifrt_client_); +} + +/* static */ void PyClient::Initialize(nb_class_ptr client) { + for (ifrt::Device* device : client->ifrt_client()->devices()) { + client->devices_[device] = make_nb_class(client, device); + + for (ifrt::Memory* memory : device->Memories()) { + auto& py_memory = client->memory_spaces_[memory]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class(client, memory); + } + } + } +} + +PyClient::~PyClient() { + nb::gil_scoped_release gil; + ifrt_client_ = nullptr; +} + +nb_class_ptr PyClient::GetPyDevice(ifrt::Device* device) { + auto& py_device = devices_[device]; + if (py_device.get() == nullptr) { + py_device = make_nb_class( + nb::borrow>(nb::find(this)), device); + } + return py_device; +} + +nb_class_ptr PyClient::GetPyMemorySpace( + ifrt::Memory* memory_space) { + auto& py_memory = memory_spaces_[memory_space]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class( + nb::borrow>(nb::find(this)), memory_space); + } + return py_memory; +} + +std::vector> PyClient::Devices() { + std::vector> devices; + auto span = ifrt_client_->devices(); + devices.reserve(span.size()); + for (ifrt::Device* device : span) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::LocalDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_client_->addressable_devices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::GetAllDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->GetAllDevices().size()); + for (ifrt::Device* device : ifrt_client_->GetAllDevices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +absl::StatusOr> PyClient::DeviceFromLocalHardwareId( + int local_hardware_id) { + TF_ASSIGN_OR_RETURN(ifrt::Device * device, + ifrt_client_->LookupAddressableDevice(local_hardware_id)); + return GetPyDevice(device); +} + +nb::list PyClient::LiveExecutables() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(executables_mutex_); + nb::list executables; + for (PyLoadedExecutable* exec = executables_; exec; exec = exec->next_) { + if (!exec->is_deleted()) { + executables.append(nb::find(exec)); + } + } + return executables; +} + +absl::Status PyClient::Defragment() { + CHECK(PyGILState_Check()); + if (!llvm::isa(ifrt_client_.get())) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + ifrt::PlatformId platform_id = ifrt_client_->platform_id(); + bool is_gpu_client = platform_id == CudaId() || platform_id == RocmId() || + platform_id == SyclId(); + + if (!is_gpu_client) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + + // TODO(b/399879011): This is a GPU-specific implementation of `Defragment`. + // Ideally, this would be replaced with some kind of auto-defrag-on-OOM, or at + // least would not live in this file. + + struct TmpBuffer { + // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays + // can reference the same PjRtBuffer. 
+ std::vector*> pjrt_buffer_ptrs; + // TODO(skyewm): maybe use py_buffer's HostValue + std::shared_ptr host_copy; + }; + + // Synchronously copy all buffers to host + absl::flat_hash_map pjrt_buf_to_tmp_buffer; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + // TODO(hyeontaek): Support non-PjRt Arrays. + // TODO(hyeontaek): Re-construct ifrt::Array with new PjRtBuffer so that + // std::shared_ptr does not need to be updated in-place. + if (array.ifrt_array() == nullptr) { + continue; + } + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + TF_ASSIGN_OR_RETURN(absl::Span> pjrt_buffers, + arr->mutable_pjrt_buffers()); + for (int i = 0; i < pjrt_buffers.size(); ++i) { + std::shared_ptr& pjrt_buf_ptr = pjrt_buffers[i]; + if (pjrt_buf_ptr->IsDeleted()) { + continue; + } + auto [iter, inserted] = + pjrt_buf_to_tmp_buffer.insert({pjrt_buf_ptr.get(), TmpBuffer()}); + if (inserted) { + TF_ASSIGN_OR_RETURN(iter->second.host_copy, + pjrt_buf_ptr->ToLiteralSync()); + } + iter->second.pjrt_buffer_ptrs.push_back(&pjrt_buf_ptr); + } + } + + // All buffers successfully copied to host, delete on-device copies. + // + // Use blocking delete operation to ensure all memory is actually cleared + // before we start rewriting buffers. + // + // Die instead of returning a bad status because program presumably can't + // continue if we fail to reconstitute device buffers. + for (const auto& it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer* pjrt_buf = it.first; + TF_CHECK_OK(pjrt_buf + ->ReleaseDeviceMemoryOwnership( + /*wait_for_operations_to_complete=*/true) + .status()); + } + + // Copy host copies back to device and update PyArrays in-place. + for (auto& it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer* pjrt_buf = it.first; + TmpBuffer& tmp_buffer = it.second; + std::unique_ptr new_copy = + pjrt_client() + ->BufferFromHostLiteral(*tmp_buffer.host_copy, + pjrt_buf->memory_space()) + .value(); + TF_CHECK_OK(new_copy->GetReadyFuture().Await()); + + std::shared_ptr new_pjrt_buf_ptr(new_copy.release()); + for (std::shared_ptr* pjrt_buffer_ptr : + tmp_buffer.pjrt_buffer_ptrs) { + *pjrt_buffer_ptr = new_pjrt_buf_ptr; + } + } + + // TODO(skyewm): delete executables? + return absl::OkStatus(); +} + +/* static */ absl::StatusOr PyClient::BufferFromPyval( + nb_class_ptr client, nb::handle argument, ifrt::Device* device, + bool force_copy, ifrt::Client::HostBufferSemantics host_buffer_semantics) { + if (device == nullptr) { + TF_RET_CHECK(!client->ifrt_client_->addressable_devices().empty()); + device = client->ifrt_client_->addressable_devices().front(); + } + CHECK(device != nullptr); + + auto transfer_guard_formatter = [&argument, dst_device = device] { + auto type = nb::cast(nb::str(argument.type())); + // Catch exceptions because shape and dtype properties convertible to str + // are not guaranteed to present in an arbitrary argument. 
+ std::string shape; + std::string dtype; + try { + shape = + nb::cast(nb::str(nb::object(argument.attr("shape")))); + } catch (const std::exception& e) { + shape = ""; + } + try { + dtype = + nb::cast(nb::str(nb::object(argument.attr("dtype")))); + } catch (const std::exception& e) { + dtype = ""; + } + return absl::StrCat("type=", type, ", shape=", shape, ", dtype=", dtype, + ", dst_device=", dst_device->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(ifrt::Device * found_device, + client->ifrt_client_->LookupDevice(device->Id())); + if (found_device != device) { + return InvalidArgument("Cannot copy value to device '%s' with '%s' backend", + device->DebugString(), + client->ifrt_client_->platform_name()); + } + GlobalPyRefManager()->CollectGarbage(); + + DevicePutOptions options; + options.squash_64bit_types = false; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + TF_ASSIGN_OR_RETURN(auto put_fn, + DevicePut(argument, client->ifrt_client_.get(), device, + options, ifrt::MemoryKind())); + TF_ASSIGN_OR_RETURN(auto put, [&]() { + // Must release the GIL before calling IFRT because backends may + // decide to block/sleep for device buffer allocation. + nb::gil_scoped_release gil_release; + return std::move(put_fn)(); + }()); + + if (put.ifrt_array) { + auto traceback = Traceback::Get(); + return PyArray::MakeFromSingleDeviceArray( + std::move(client), std::move(traceback), std::move(put.ifrt_array), + /*weak_type=*/false, + /*committed=*/false); + } else { + return put.owning_pybuffer; + } +} + +namespace { + +// Makes IFRT `CompileOptions` from XLA `CompileOptions` and optional host +// callbacks. +std::unique_ptr MakeIfrtCompileOptions( + CompileOptions options, std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +} + +// Makes IFRT `DeserializeExecutableOptions` from XLA `CompileOptions` and +// optional host callbacks. +std::unique_ptr +MakeIfrtDeserializeExecutableOptions(std::optional options, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +} + +} // namespace + +/* static */ absl::StatusOr> +PyClient::CompileIfrtProgram( + nb_class_ptr client, std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options) { + auto* pjrt_compatible_client = + llvm::dyn_cast_or_null( + client->ifrt_client_.get()); + auto* ifrt_xla_options = + llvm::dyn_cast_or_null(ifrt_options.get()); + // For XLA programs, pass allocated device memory size to compile options for + // pjrt compatible backends. 
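As the comment in `BufferFromPyval` above notes, the GIL must be released before the blocking `DevicePut` work because backends may block or sleep during device buffer allocation; the code does this by running the blocking call inside an immediately-invoked lambda under `nb::gil_scoped_release`. A minimal sketch of that pattern, with a stand-in `slow_backend_call` that is not a real API:

// Sketch of the GIL-release-around-blocking-work pattern: the blocking C++
// call runs inside an immediately-invoked lambda under nb::gil_scoped_release,
// so other Python threads can run, and the GIL is re-acquired automatically
// before any Python object is touched again.
#include <chrono>
#include <string>
#include <thread>

#include <nanobind/nanobind.h>

namespace nb = nanobind;

static std::string slow_backend_call() {
  // Stand-in for work that may block or sleep (e.g. a device allocation).
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  return "ok";
}

static nb::str guarded_call() {
  std::string result = []() {
    nb::gil_scoped_release release;  // do not touch Python state in this scope
    return slow_backend_call();
  }();
  return nb::str(result.c_str());    // back under the GIL here
}

NB_MODULE(gil_demo, m) { m.def("guarded_call", &guarded_call); }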
+ if (pjrt_compatible_client != nullptr && ifrt_xla_options != nullptr) { + xla::CompileOptions& options = ifrt_xla_options->compile_options; + auto addressable_devices = + pjrt_compatible_client->pjrt_client()->addressable_devices(); + if (!addressable_devices.empty()) { + int device_ordinal = options.executable_build_options.device_ordinal(); + if (device_ordinal < 0) { + device_ordinal = 0; + } + CHECK_LT(device_ordinal, addressable_devices.size()); + auto stats = addressable_devices[device_ordinal]->GetAllocatorStats(); + if (stats.ok() && stats->bytes_limit) { + options.executable_build_options.set_device_memory_size( + *stats->bytes_limit); + } + } + + if (pjrt_compatible_client->pjrt_client()->key_value_store().has_value()) { + options.executable_build_options.set_key_value_store( + *pjrt_compatible_client->pjrt_client()->key_value_store()); + } + } + + std::unique_ptr ifrt_loaded_executable; + std::optional fingerprint; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->Compile( + std::move(ifrt_program), std::move(ifrt_options))); + TF_RETURN_IF_ERROR(ifrt_loaded_executable->GetReadyFuture().Await()); + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + } + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +/* static */ absl::StatusOr> PyClient::Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + return CompileIfrtProgram( + client, std::make_unique(module.get()), + MakeIfrtCompileOptions(std::move(options), std::move(host_callbacks))); +} + +/* static */ absl::StatusOr> PyClient::Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. 
+ for (auto& host_callback : host_callbacks) { + auto callback = tsl::MakeRef( + client->ifrt_client(), std::move(host_callback)); + ifrt_loaded_host_callbacks.push_back(callback); + } + auto compile_options = std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); + return CompileIfrtProgram( + client, std::make_unique(module.get()), + std::move(compile_options)); +} + +absl::StatusOr PyClient::SerializeExecutable( + const PyLoadedExecutable& executable) const { + TF_ASSIGN_OR_RETURN(auto serialized, + executable.ifrt_loaded_executable()->Serialize()); + return nb::bytes(serialized.data(), serialized.size()); +} + +/* static */ absl::StatusOr> +PyClient::DeserializeExecutable(nb_class_ptr client, + nb::bytes serialized, + std::optional options, + std::vector host_callbacks) { + std::unique_ptr ifrt_loaded_executable; + std::optional fingerprint; + auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( + std::move(options), std::move(host_callbacks)); + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( + absl::string_view(serialized.c_str(), serialized.size()), + std::move(ifrt_deserialize_options))); + } + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +namespace { + +struct HeapProfileKey { + Traceback* traceback; + int64_t size; + xla::PjRtDevice* device; + bool operator==(const HeapProfileKey& other) const; +}; + +bool HeapProfileKey::operator==(const HeapProfileKey& other) const { + if (size != other.size || device != other.device) { + return false; + } + if ((traceback == nullptr) != (other.traceback == nullptr)) { + return false; + } + if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) { + return false; + } + return true; +} + +template +H AbslHashValue(H h, const HeapProfileKey& key) { + if (key.traceback) { + h = H::combine(std::move(h), key.traceback->raw_frames()); + } + h = H::combine(std::move(h), key.size, key.device); + return h; +} + +} // namespace + +absl::StatusOr PyClient::HeapProfile() { + CHECK(PyGILState_Check()); + absl::flat_hash_set buffer_set; + absl::flat_hash_map entries; + + auto add_buffer_to_profile = [&](PjRtBuffer* buffer, Traceback* traceback) { + // We only wish to count each PjRtBuffer once, even though they may be + // shared by multiple PyArrays. + if (!buffer->IsDeleted() && buffer_set.insert(buffer).second) { + TF_ASSIGN_OR_RETURN(size_t size, buffer->GetOnDeviceSizeInBytes()); + HeapProfileKey key{traceback, static_cast(size), + buffer->device()}; + ++entries[key]; + } + return absl::OkStatus(); + }; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + if (array.ifrt_array() == nullptr) { + continue; + } + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + // TODO(hyeontaek): Support non-PjRt Arrays. + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + for (const auto& buffer : arr->pjrt_buffers()) { + TF_RETURN_IF_ERROR(add_buffer_to_profile( + buffer.get(), + array.traceback() ? 
array.traceback()->get() : nullptr)); + } + } + + for (PyLoadedExecutable* executable = executables_; executable; + executable = executable->next_) { + if (!executable->is_deleted()) { + HeapProfileKey key{ + executable->traceback() ? executable->traceback()->get() : nullptr, + executable->SizeOfGeneratedCodeInBytes(), nullptr}; + ++entries[key]; + } + } + + PprofProfileBuilder builder; + auto* allocations = builder.profile().add_sample_type(); + allocations->set_type(builder.StringId("allocations")); + allocations->set_unit(builder.StringId("count")); + auto* space = builder.profile().add_sample_type(); + space->set_type(builder.StringId("space")); + space->set_unit(builder.StringId("bytes")); + + const int kind_string_id = builder.StringId("kind"); + const int buffer_string_id = builder.StringId("buffer"); + const int executable_string_id = builder.StringId("executable"); + const int device_string_id = builder.StringId("device"); + for (const auto& entry : entries) { + auto* sample = builder.profile().add_sample(); + if (entry.first.traceback) { + for (const auto& frame : entry.first.traceback->raw_frames()) { + sample->add_location_id(builder.LocationId(frame.first, frame.second)); + } + } + sample->add_value(entry.second); + sample->add_value(entry.first.size * entry.second); + + auto* kind_label = sample->add_label(); + kind_label->set_key(kind_string_id); + if (entry.first.device) { + kind_label->set_str(buffer_string_id); + auto* device_label = sample->add_label(); + device_label->set_key(device_string_id); + std::string device_label_str(entry.first.device->DebugString()); + device_label->set_str(builder.StringId(device_label_str)); + } else { + kind_label->set_str(executable_string_id); + } + } + std::string serialized = builder.profile().SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); +} + +absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( + nb::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN( + auto loaded_host_callback, + PyHostSendAndRecvLoadedHostCallback::Create( + ifrt_client(), std::move(callable), operand_shapes, result_shapes, + send_channel_ids, recv_channel_ids, std::move(serializer))); + nb::capsule callback_capsule( + loaded_host_callback.release(), [](void* ptr) noexcept { + static_cast(ptr)->DropRef(); + }); + return callback_capsule; +} + +/* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyClient* c = nb::inst_ptr(self); + for (const auto& [ifrt_device, py_device] : c->devices_) { + Py_VISIT(py_device.ptr()); + } + for (const auto& [ifrt_memory, py_memory] : c->memory_spaces_) { + Py_VISIT(py_memory.ptr()); + } + return 0; +} + +/* static */ int PyClient::tp_clear(PyObject* self) { + PyClient* c = nb::inst_ptr(self); + absl::flat_hash_map> devices; + std::swap(devices, c->devices_); + absl::flat_hash_map> memory_spaces; + std::swap(memory_spaces, c->memory_spaces_); + return 0; +} + +PyType_Slot PyClient::slots_[] = { + {Py_tp_traverse, (void*)PyClient::tp_traverse}, + {Py_tp_clear, (void*)PyClient::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyClient::RegisterPythonTypes(nb::module_& m) { + nb::enum_(m, "HostBufferSemantics") + .value("IMMUTABLE_ONLY_DURING_CALL", + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) + 
.value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) + .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + + nb::class_ py_local_client(m, "Client", nb::is_weak_referenceable(), + nb::type_slots(PyClient::slots_)); + py_local_client.def_prop_ro("platform", &PyClient::platform_name) + .def_prop_ro("_raw_platform", &PyClient::raw_platform_name) + .def_prop_ro("platform_version", &PyClient::platform_version) + .def_prop_ro("runtime_type", &PyClient::runtime_type) + .def("device_count", &PyClient::device_count) + .def("local_device_count", &PyClient::addressable_device_count) + .def("devices", &PyClient::Devices) + .def("local_devices", &PyClient::LocalDevices) + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + .def("_get_all_devices", &PyClient::GetAllDevices) + .def("device_from_local_hardware_id", + xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) + .def("live_executables", &PyClient::LiveExecutables) + .def("live_arrays", &PyClient::LiveArrays) + .def("live_buffers", &PyClient::LiveArrays) + .def("process_index", &PyClient::process_index) + .def("host_id", &PyClient::process_index) + .def("task_id", &PyClient::process_index) + .def( + "buffer_from_pyval", + [](nb_class_ptr client, nb::handle argument, + PyDevice* device, bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) { + return ValueOrThrow( + PyClient::BufferFromPyval(std::move(client), argument, + device ? device->device() : nullptr, + force_copy, host_buffer_semantics)); + }, + nb::arg("argument"), nb::arg("device").none() = nullptr, + nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), std::move(mlir_module), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), std::move(mlir_module), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def("compile_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileIfrtProgram)) + 
.def("serialize_executable", + xla::ValueOrThrowWrapper(&PyClient::SerializeExecutable)) + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + std::optional options, + std::vector host_callbacks) { + return ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("serialized"), nb::arg("compile_options").none() = nb::none(), + nb::arg("host_callbacks") = std::vector()) + .def("heap_profile", xla::ValueOrThrowWrapper(&PyClient::HeapProfile)) + // TODO(zhangqiaorjc): Experimental. + .def("defragment", + [](PyClient& self) { xla::ThrowIfError(self.Defragment()); }) + .def("make_python_callback_from_host_send_and_recv", + xla::ValueOrThrowWrapper( + &PyClient::MakePythonCallbackUsingHostSendAndRecv), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes"), nb::arg("send_channel_ids"), + nb::arg("recv_channel_ids"), + nb::arg("serializer").none() = nb::none()) + .def( + "get_default_layout", + [](PyClient& self, nb_dtype dtype, nb::sequence shard_shape, + nb_class_ptr device) + -> std::shared_ptr { + ifrt::DType ifrt_type = xla::ValueOrThrow(DtypeToIfRtDType(dtype)); + std::vector dims = SequenceToVector(shard_shape); + return xla::ValueOrThrow(self.ifrt_client()->GetDefaultLayout( + ifrt_type, dims, device->device(), xla::ifrt::MemoryKind())); + }, + nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) + .def("__getattr__", + [](PyClient& client, absl::string_view name) -> nb::object { + const auto& attrs = client.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); +} + +} // namespace xla diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h new file mode 100644 index 000000000000..29a506d48864 --- /dev/null +++ b/jaxlib/xla/py_client.h @@ -0,0 +1,252 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_CLIENT_H_ +#define JAXLIB_XLA_PY_CLIENT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/shape.h" + +namespace xla { + +class PyClient; +class PyLoadedExecutable; +class PyArray; +class PyDevice; +class PyMemorySpace; +struct PyArray_Storage; + +// Python wrapper around PjRtClient. +// We use a wrapper class to add Python-specific functionality. +class PyClient { + public: + static nb_class_ptr Make(std::shared_ptr ifrt_client); + + // Do not call the constructor directly. Use `PyClient::Make` instead. + explicit PyClient(std::shared_ptr ifrt_client); + virtual ~PyClient(); + + ifrt::Client* ifrt_client() const { return ifrt_client_.get(); } + const std::shared_ptr& shared_ptr_ifrt_client() const { + return ifrt_client_; + } + + // Short-term escape hatch to get PjRtClient from PyClient. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + xla::PjRtClient* pjrt_client() const { + auto* pjrt_client = + llvm::dyn_cast_or_null(ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->pjrt_client(); + } + std::shared_ptr shared_ptr_pjrt_client() { + auto* pjrt_client = + llvm::dyn_cast_or_null(ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->shared_ptr_pjrt_client(); + } + + // Legacy alises. + std::shared_ptr shared_pjrt_client() { + return shared_ptr_pjrt_client(); + } + + absl::string_view platform_name() const { + // TODO(phawkins): this is a temporary backwards compatibility shim. We + // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but + // we haven't yet updated JAX clients that expect "gpu". Migrate users and + // remove this code. + if (ifrt_client_->platform_name() == "cuda" || + ifrt_client_->platform_name() == "rocm") { + return "gpu"; + } else { + return ifrt_client_->platform_name(); + } + } + absl::string_view raw_platform_name() const { + // TODO(parkers): Once platform_name() is the same, remove this. + return ifrt_client_->platform_name(); + } + absl::string_view platform_version() const { + return ifrt_client_->platform_version(); + } + absl::string_view runtime_type() const { + return ifrt_client_->runtime_type(); + } + + // Returns implementation-specific attributes about this client, e.g. the PJRT + // C API version if applicable. 
+ const xla::ifrt::AttributeMap& Attributes() const { + return client_attributes_; + } + + int addressable_device_count() const { + return ifrt_client_->addressable_device_count(); + } + int device_count() const { return ifrt_client_->device_count(); } + int process_index() const { return ifrt_client_->process_index(); } + + std::vector> Devices(); + std::vector> LocalDevices(); + // Returns all devices in the client. Private API; only use this method for + // implementing backend._get_all_devices(). + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + std::vector> GetAllDevices(); + absl::StatusOr> DeviceFromLocalHardwareId( + int local_hardware_id); + + // Returns the PyDevice associated with the given ifrt::Device. + nb_class_ptr GetPyDevice(ifrt::Device* device); + + // Returns the PyMemorySpace associated with the given ifrt::Memory. + nb_class_ptr GetPyMemorySpace(ifrt::Memory* memory_space); + + // Returns a vector of live PyArray objects. PyArray objects may share + // PjRtBuffers, so there may be duplicates of the same underlying device + // buffer. + std::vector LiveBuffersOnDevice(ifrt::Device* device); + + nanobind::list LiveExecutables(); + + // TODO(zhangqiaorjc): Remove when we have transparent defragmentation. + absl::Status Defragment(); + + static absl::StatusOr BufferFromPyval( + nb_class_ptr client, nanobind::handle argument, + ifrt::Device* device, bool force_copy, + ifrt::Client::HostBufferSemantics host_buffer_semantics); + + static absl::StatusOr> CompileIfrtProgram( + nb_class_ptr client, + std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options); + + static absl::StatusOr> Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks); + + static absl::StatusOr> Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks); + + absl::StatusOr SerializeExecutable( + const PyLoadedExecutable& executable) const; + static absl::StatusOr> DeserializeExecutable( + nb_class_ptr client, nanobind::bytes serialized, + std::optional options, + std::vector host_callbacks); + + absl::StatusOr HeapProfile(); + + // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable + // that takes in arguments of shapes `operand_shapes` and returns results of + // shapes `result_shapes`. The arguments correspond to Send ops in the HLO + // program through `send_channel_ids` and the results correspond to Recv ops + // through `recv_channel_ids`. It returns the host callback as an opaque + // object whose reference will keep the Python callback alive. The host + // callback can be passed to `PyClient::Compile` or + // `PyClient::DeserializeExecutable`. The corresponding Send/Recv ops in the + // XLA computation can trigger the execution of this host callback. + // `serializer` is a function that takes `callable` as an argument and returns + // a serialized callable as a string. + // + // The callable receives as arguments NumPy arrays for arguments with array + // types, and None for Token argument. The callable must return a tuple of + // either arrays or None values. 
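The "opaque object" mentioned in the comment above is an `nb::capsule`: the loaded host callback is released into a capsule whose cleanup function drops the reference when Python garbage-collects it, which is what keeps the Python callable alive for as long as the capsule is referenced. A minimal sketch of that ownership pattern, using an illustrative `Payload` type rather than the real loaded-host-callback class:

// Sketch of the capsule-owns-the-callback contract: ownership of a heap object
// moves into an nb::capsule, and the capsule's cleanup function runs when
// Python drops the last reference (DropRef() in the real code, delete here).
#include <memory>

#include <nanobind/nanobind.h>

namespace nb = nanobind;

struct Payload {
  int channel = 7;
};

NB_MODULE(capsule_demo, m) {
  m.def("make_handle", []() {
    auto payload = std::make_unique<Payload>();
    // Ownership moves into the capsule; the lambda runs at capsule teardown.
    return nb::capsule(payload.release(),
                       [](void* p) noexcept { delete static_cast<Payload*>(p); });
  });
  m.def("read_handle", [](nb::capsule c) {
    return static_cast<Payload*>(c.data())->channel;
  });
}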
+ absl::StatusOr MakePythonCallbackUsingHostSendAndRecv( + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + std::vector LiveArrays() const; + + static void RegisterPythonTypes(nanobind::module_& m); + + protected: + static void Initialize(nb_class_ptr client); + + private: + friend class PyLoadedExecutable; + friend class PyArray; + friend struct PyArray_Storage; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + std::shared_ptr ifrt_client_; + xla::ifrt::AttributeMap client_attributes_; + // Pointers to intrusive doubly-linked lists of arrays and executables, used + // to iterate over all known objects when heap profiling. The list structure + // is protected by the GIL. + + nanobind::ft_mutex executables_mutex_; + // List guarded by executables_mutex_. + PyLoadedExecutable* executables_ = nullptr; + +#ifdef NB_FREE_THREADING + static constexpr size_t kNumArraysShards = 16; +#else + static constexpr size_t kNumArraysShards = 1; +#endif + struct ArraysShard { + mutable nanobind::ft_mutex mutex; + PyArray_Storage* arrays; + }; + std::array arrays_; + + absl::flat_hash_map> devices_; + absl::flat_hash_map> + memory_spaces_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_CLIENT_H_ diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc new file mode 100644 index 000000000000..91f4e4ee42b9 --- /dev/null +++ b/jaxlib/xla/py_client_cpu.cc @@ -0,0 +1,185 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/py_client_cpu.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +struct CpuTransposePlanCache { + static ffi::TypeId id; + explicit CpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; + +ffi::TypeId CpuTransposePlanCache::id = {}; + +XLA_FFI_REGISTER_TYPE(ffi::GetXlaFfiApi(), "CpuTransposePlanCache", + &CpuTransposePlanCache::id); + +static ffi::ErrorOr> +CpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(/*capacity=*/16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kCpuTransposePlanCacheInstantiate, CpuTransposePlanCacheInstantiate, + ffi::Ffi::BindInstantiate().Attr("index")); + +ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, + CpuTransposePlanCache* transpose_cache, + uint64_t index, ffi::RemainingArgs args, + ffi::RemainingRets rets) { + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + auto nb_args = nb::steal(PyTuple_New(args.size())); + for (size_t i = 0; i < args.size(); ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == TOKEN) { + PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr()); + continue; + } + auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + // We pass in data using default numpy layout i.e., std::nullopt. + auto array = + nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); + } + + EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + nb::tuple result_tuple; + try { + auto result_object = callback(*nb::borrow(nb_args)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + LeaveHostCallback(); + + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. 
+ if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. + auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return ffi::Error::Internal(maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = ByteStridesForShape(expected_shape); + if (strides == expected_strides) { + std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); + continue; + } + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions_size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + plan->Execute(array.data(), ret->untyped_data()); + } + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback, XlaFfiPythonCpuCallback, + ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_ffi_python_cpu_callback", + "HOST", + {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonCpuCallback}); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "xla_ffi_partitioned_python_cpu_callback", "HOST", + {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonCpuCallback}); +} // namespace xla diff --git a/jaxlib/xla/py_client_cpu.h b/jaxlib/xla/py_client_cpu.h new file mode 100644 index 000000000000..0035b0a361fa --- /dev/null +++ b/jaxlib/xla/py_client_cpu.h @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_CLIENT_CPU_H_ +#define JAXLIB_XLA_PY_CLIENT_CPU_H_ + +#include "xla/ffi/api/ffi.h" + +namespace xla { + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kCpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_CLIENT_CPU_H_ diff --git a/jaxlib/xla/py_compile_only_client.cc b/jaxlib/xla/py_compile_only_client.cc new file mode 100644 index 000000000000..673dfc214346 --- /dev/null +++ b/jaxlib/xla/py_compile_only_client.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_compile_only_client.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/compile_only_ifrt/client.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +class CompileOnlyPyClient : public PyClient { + public: + using PyClient::PyClient; + + static nb_class_ptr Make( + std::shared_ptr topology) { + auto client = + nb::borrow>(make_nb_class( + std::make_unique(std::move(topology)))); + CompileOnlyPyClient::Initialize(client); + return client; + } + + absl::StatusOr> CompileUnloaded( + absl::string_view mlir_module, CompileOptions options, + std::vector host_callbacks) { + if (!host_callbacks.empty()) { + return Unimplemented( + "Compiling with host_callbacks not available with compile-only " + "client."); + } + nb::gil_scoped_release gil_release; + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. 
+ TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + auto* ifrt_client = + llvm::dyn_cast_or_null(this->ifrt_client()); + CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a " + "CompileOnlyIfRtClient"; + auto xla_options = std::make_unique(options); + TF_ASSIGN_OR_RETURN(auto executable, + PjRtCompile(std::move(options), module.get(), + *ifrt_client->topology().description())); + TF_ASSIGN_OR_RETURN(auto ifrt_executable, + ifrt::PjRtExecutable::Create(std::move(executable))); + return std::shared_ptr(std::move(ifrt_executable)); + } + + private: + static void Initialize(nb_class_ptr client) { + PyClient::Initialize(client); + } +}; + +} // namespace + +nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr topology) { + return CompileOnlyPyClient::Make(std::move(topology)); +} + +void RegisterCompileOnlyClient(nb::module_& m) { + nb::class_(m, "CompileOnlyPyClient") + .def( + "compile", + [](CompileOnlyPyClient& self, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(self.CompileUnloaded( + absl::string_view(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()); +} + +} // namespace xla diff --git a/jaxlib/xla/py_compile_only_client.h b/jaxlib/xla/py_compile_only_client.h new file mode 100644 index 000000000000..6cc700e1d3a9 --- /dev/null +++ b/jaxlib/xla/py_compile_only_client.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_ +#define JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_ + +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" + +namespace xla { + +// This is a workaround for AOT compilation until topologies and device +// descriptions are better integrated into jax's Python code. It returns a +// PyClient that will return errors for all non-AOT methods. It also exposes a +// different compile method that returns an unloaded executable (vs. PyClient +// usually returns a loaded executable). RegisterCompileOnlyClient() overloads +// the Python "compile" method to return the unloaded executable, and we rely on +// Python duck typing to treat the unloaded executable like a loaded executable +// (except it will raise errors if you try to run it, which is what we want for +// AOT environments). 
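+//
+// Illustrative sketch only; the Python names below are hypothetical and not a
+// public API. A compile-only client supports AOT compilation and inspection,
+// but the resulting executable cannot be run:
+//
+//   client = make_compile_only_client(topology)   # hypothetical helper
+//   exe = client.compile(mlir_module, options)    # unloaded executable
+//   # AOT queries (shardings, layouts, ...) work; executing `exe` raises.
+//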
+nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr); + +void RegisterCompileOnlyClient(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_ diff --git a/jaxlib/xla/py_device.cc b/jaxlib/xla/py_device.cc new file mode 100644 index 000000000000..253bfd439a9b --- /dev/null +++ b/jaxlib/xla/py_device.cc @@ -0,0 +1,350 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_device.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyDevice::PyDevice(nb_class_ptr client, ifrt::Device* device) + : client_(std::move(client)), device_(device) {} + +int PyDevice::id() const { return device_->Id().value(); } + +int PyDevice::process_index() const { return device_->ProcessIndex(); } + +absl::string_view PyDevice::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyDevice::device_kind() const { return device_->Kind(); } + +std::optional PyDevice::local_hardware_id() const { + // TODO(phawkins): consider supporting this for non-PJRT devices. 
+ ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return std::nullopt; + } + int local_hardware_id = device->pjrt_device()->local_hardware_id().value(); + if (local_hardware_id == -1) { + return std::nullopt; + } + return local_hardware_id; +} + +absl::string_view PyDevice::Str() const { return device_->DebugString(); } + +absl::string_view PyDevice::Repr() const { return device_->ToString(); } + +absl::Status PyDevice::TransferToInfeed(LiteralSlice literal) { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferToInfeed is only supported for PjRt devices."); + } + return client->TransferToInfeed(device, literal); +} + +absl::StatusOr PyDevice::TransferFromOutfeed(Shape shape) { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal; + { + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferFromOutfeed is only supported for PjRt devices."); + } + ShapeUtil::ForEachMutableSubshape( + &shape, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + literal = std::make_shared(shape); + TF_RETURN_IF_ERROR(client->TransferFromOutfeed(device, literal.get())); + } + return LiteralToPython(std::move(literal)); +} + +absl::StatusOr> PyDevice::Memory( + absl::string_view kind) const { + ifrt::Memory* result_memory_space = nullptr; + for (auto* memory_space : device_->Memories()) { + if (memory_space->Kind().memory_kind() == kind) { + if (result_memory_space != nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Found more than one addressable memory for " + "kind %s which is not allowed. There can only " + "be one memory for each " + "kind. Device %s can address the following " + "memory kinds: %s", + kind, device_kind, memories); + } + result_memory_space = memory_space; + } + } + if (result_memory_space == nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Could not find memory addressable by device %s. Device %s " + "can address the following memory kinds: %s. 
" + "Got memory kind: %s", + device_kind, device_kind, memories, kind); + } + return client_->GetPyMemorySpace(result_memory_space); +} + +absl::StatusOr> PyDevice::DefaultMemory() const { + TF_ASSIGN_OR_RETURN(auto* memory_space, device_->DefaultMemory()); + return client_->GetPyMemorySpace(memory_space); +} + +nb::list PyDevice::AddressableMemories() const { + nb::list memory_spaces; + for (auto* memory_space : device_->Memories()) { + memory_spaces.append(client_->GetPyMemorySpace(memory_space)); + } + return memory_spaces; +} + +absl::StatusOr> PyDevice::MemoryStats() const { + GlobalPyRefManager()->CollectGarbage(); + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "MemoryStats is only supported for addressable PjRt devices."); + } + absl::StatusOr maybe_stats = + device->pjrt_device()->GetAllocatorStats(); + if (absl::IsUnimplemented(maybe_stats.status())) { + return std::nullopt; + } + // Raise error if any status other than Unimplemented is returned. + ThrowIfError(maybe_stats.status()); + + nb::dict result; + result["num_allocs"] = maybe_stats->num_allocs; + result["bytes_in_use"] = maybe_stats->bytes_in_use; + result["peak_bytes_in_use"] = maybe_stats->peak_bytes_in_use; + result["largest_alloc_size"] = maybe_stats->largest_alloc_size; + if (maybe_stats->bytes_limit) { + result["bytes_limit"] = *maybe_stats->bytes_limit; + } + result["bytes_reserved"] = maybe_stats->bytes_reserved; + result["peak_bytes_reserved"] = maybe_stats->peak_bytes_reserved; + if (maybe_stats->bytes_reservable_limit) { + result["bytes_reservable_limit"] = *maybe_stats->bytes_reservable_limit; + } + result["largest_free_block_bytes"] = maybe_stats->largest_free_block_bytes; + if (maybe_stats->pool_bytes) { + result["pool_bytes"] = *maybe_stats->pool_bytes; + } + if (maybe_stats->peak_pool_bytes) { + result["peak_pool_bytes"] = *maybe_stats->peak_pool_bytes; + } + return result; +} + +absl::StatusOr PyDevice::GetStreamForExternalReadyEvents() + const { + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "GetStreamForExternalReadyEvents is only supported for addressable " + "PjRt devices."); + } + return device->pjrt_device()->GetStreamForExternalReadyEvents(); +} + +/* static */ int PyDevice::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyDevice* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyDevice::tp_clear(PyObject* self) { + PyDevice* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyDevice::slots_[] = { + {Py_tp_traverse, (void*)PyDevice::tp_traverse}, + {Py_tp_clear, (void*)PyDevice::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyDevice::RegisterPythonType(nb::module_& m) { + nb::class_ device( + m, "Device", nb::type_slots(PyDevice::slots_), + "A descriptor of an available device.\n\nSubclasses are used to " + "represent specific types of devices, e.g. CPUs, GPUs. 
Subclasses may " + "have additional properties specific to that device type."); + device + .def_prop_ro( + "id", &PyDevice::id, + "Integer ID of this device.\n\nUnique across all available devices " + "of this type, including remote devices on multi-host platforms.") + .def_prop_ro("process_index", &PyDevice::process_index, + "Integer index of this device's process.\n\n" + "This is always 0 except on multi-process platforms.") + .def_prop_ro("host_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("task_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("platform", &PyDevice::platform) + .def_prop_ro("device_kind", &PyDevice::device_kind) + .def_prop_ro("client", &PyDevice::client) + .def_prop_ro( + "local_hardware_id", &PyDevice::local_hardware_id, + "Opaque hardware ID, e.g., the CUDA device number. In general, not " + "guaranteed to be dense, and not guaranteed to be defined on all " + "platforms.") + .def("__str__", &PyDevice::Str) + .def("__repr__", &PyDevice::Repr) + .def("transfer_to_infeed", + ThrowIfErrorWrapper(&PyDevice::TransferToInfeed)) + .def("transfer_from_outfeed", + ValueOrThrowWrapper(&PyDevice::TransferFromOutfeed)) + .def("memory", ValueOrThrowWrapper(&PyDevice::Memory), nb::arg("kind")) + .def("default_memory", ValueOrThrowWrapper(&PyDevice::DefaultMemory), + "Returns the default memory of a device.") + .def("addressable_memories", &PyDevice::AddressableMemories, + "Returns all the memories that a device can address.") + + .def("live_buffers", + [](nb::handle device) { + PythonDeprecationWarning( + /*stacklevel=*/1, + "Per device live_buffers() is deprecated. Please " + "use the jax.live_arrays() for jax.Arrays instead."); + return nb::list(); + }) + .def( + "memory_stats", ValueOrThrowWrapper(&PyDevice::MemoryStats), + "Returns memory statistics for this device keyed by name. May not " + "be implemented on all platforms, and different platforms may return " + "different stats, or -1 for unavailable stats. 'bytes_in_use' is " + "usually available. Intended for diagnostic use.") + .def( + "get_stream_for_external_ready_events", + xla::ValueOrThrowWrapper(&PyDevice::GetStreamForExternalReadyEvents)); + static PyMethodDef get_attr_method = { + "__getattr__", + +[](PyObject* self, PyObject* args) -> PyObject* { + PyObject* key; + if (!PyArg_ParseTuple(args, "O", &key)) { + PyErr_SetString(PyExc_TypeError, "__getattr__ must take 1 argument."); + return nullptr; + } + try { + auto device = nb::cast(nb::handle(self)); + auto name = nb::cast(nb::handle(key)); + const auto& attrs = device->device_->Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + auto result = std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + return result.release().ptr(); + } + PyErr_SetNone(PyExc_AttributeError); + return nullptr; + } catch (std::exception& e) { + PyErr_Format(PyExc_SystemError, "Unhandled nanobind exception: %s", + e.what()); + return nullptr; + } catch (...) 
{ + PyErr_SetString(PyExc_SystemError, "Unhandled nanobind exception."); + return nullptr; + } + }, + METH_VARARGS, + nullptr, + }; + device.attr("__getattr__") = nb::steal(PyDescr_NewMethod( + reinterpret_cast(device.ptr()), &get_attr_method)); +} + +} // namespace xla diff --git a/jaxlib/xla/py_device.h b/jaxlib/xla/py_device.h new file mode 100644 index 000000000000..4e74992fb2ee --- /dev/null +++ b/jaxlib/xla/py_device.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_DEVICE_H_ +#define JAXLIB_XLA_PY_DEVICE_H_ + +#include + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "xla/literal.h" +#include "xla/python/ifrt/device.h" +#include "xla/shape.h" + +namespace xla { + +class PyDevice { + public: + PyDevice(nb_class_ptr client, ifrt::Device* device); + + // Devices are compared using Python object identity, so we don't allow them + // to be copied or moved. + PyDevice(const PyDevice&) = delete; + PyDevice(PyDevice&&) = delete; + PyDevice& operator=(const PyDevice&) = delete; + PyDevice& operator=(PyDevice&&) = delete; + + const nb_class_ptr& client() const { return client_; } + ifrt::Device* device() const { return device_; } + + int id() const; + int process_index() const; + absl::string_view platform() const; + absl::string_view device_kind() const; + std::optional local_hardware_id() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + absl::Status TransferToInfeed(LiteralSlice literal); + absl::StatusOr TransferFromOutfeed(Shape shape); + + absl::StatusOr> Memory( + absl::string_view kind) const; + absl::StatusOr> DefaultMemory() const; + nanobind::list AddressableMemories() const; + absl::StatusOr> MemoryStats() const; + + absl::StatusOr GetStreamForExternalReadyEvents() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Device* device_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_DEVICE_H_ diff --git a/jaxlib/xla/py_device_list.cc b/jaxlib/xla/py_device_list.cc new file mode 100644 index 000000000000..205c971b9317 --- /dev/null +++ b/jaxlib/xla/py_device_list.cc @@ -0,0 +1,470 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_device_list.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/make_iterator.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/types.h" +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +PyDeviceList::PyDeviceList(xla::nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list) + : py_client_(std::move(py_client)), device_list_(std::move(device_list)) {} + +PyDeviceList::PyDeviceList(nb::tuple py_device_assignment) + : device_list_(py_device_assignment) { + // Attempt to convert to Python devices into `ifrt::DeviceList`. + if (py_device_assignment.size() == 0) { + return; + } + absl::InlinedVector devices; + devices.reserve(py_device_assignment.size()); + for (nb::handle obj : py_device_assignment) { + if (!nb::isinstance(obj.ptr())) { + // Non-`xla::PyDevice` is used on an alternative JAX backend with device + // duck typing. Use Python device objects already set in `device_list_`. + return; + } + auto py_device = nb::cast(obj); + if (py_client_.get() == nullptr) { + py_client_ = py_device->client(); + } else if (py_device->client().get() != py_client_.get()) { + // If the list contains multiple clients, fall back to device duck typing. + return; + } + devices.push_back(py_device->device()); + } + device_list_ = py_client_->ifrt_client()->MakeDeviceList(devices); +} + +PyDeviceList::~PyDeviceList() { + if (device_list_.index() == 1) { + xla::GlobalPyRefManager()->AddGarbage( + std::move(std::get<1>(std::move(device_list_)))); + } +} + +absl::StatusOr PyDeviceList::ifrt_device_list() + const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_); + case 1: + return xla::InvalidArgument("DeviceList contains non-IFRT devices"); + default: + return xla::InvalidArgument("Unrecognized DeviceList type"); + } +} + +int64_t PyDeviceList::Hash() { + if (!hash_.has_value()) { + switch (device_list_.index()) { + case 0: + hash_ = absl::HashOf(std::get<0>(device_list_)); + break; + case 1: + hash_ = nb::hash(std::get<1>(device_list_)); + break; + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *hash_; +} + +/*static*/ bool PyDeviceList::Equal(xla::nb_class_ptr self, + nb::handle other) { + if (!nb::isinstance(other)) { + return false; + } + auto o = nb::cast(other); + // Fast-path using a pointer equality check. 
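+  // Otherwise compare the cached hashes first (each computed under its
+  // object's own lock) and fall back to a deep comparison of the underlying
+  // device lists, or of their Python tuples for duck-typed devices, only
+  // when the hashes match.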
+ if (self.get() == o) { + return true; + } + int64_t h1, h2; + { + nb::ft_object_guard lock(self); + h1 = self->Hash(); + } + { + nb::ft_object_guard lock(other); + h2 = o->Hash(); + } + if (h1 != h2) { + return false; + } + if (self->device_list_.index() == 0 && o->device_list_.index() == 0) { + nb::gil_scoped_release gil_release; + return *std::get<0>(self->device_list_) == *std::get<0>(o->device_list_); + } else { + return self->AsTuple().equal(o->AsTuple()); + } +} + +/*static*/ bool PyDeviceList::NotEqual(xla::nb_class_ptr self, + nb::handle other) { + return !Equal(std::move(self), other); +} + +int PyDeviceList::Len() const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_)->size(); + case 1: + return nb::len(std::get<1>(device_list_)); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetItem(int index) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + if (index < -device_list->size() || index >= device_list->size()) { + throw nb::index_error(); + } else if (index < 0) { + index += device_list->size(); + } + return py_client_->GetPyDevice(device_list->devices()[index]); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(index); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetSlice(nb::slice slice) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + const absl::Span devices = + device_list->devices(); + Py_ssize_t start, stop, step, slicelength; + if (PySlice_GetIndicesEx(slice.ptr(), devices.size(), &start, &stop, + &step, &slicelength) != 0) { + throw nb::python_error(); + } + nb::tuple out = nb::steal(PyTuple_New(slicelength)); + for (size_t i = 0; i < slicelength; ++i) { + nb::object d = py_client_->GetPyDevice(devices[start]); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + start += step; + } + return std::move(out); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(slice); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::tuple PyDeviceList::AsTuple() const { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + nb::tuple out = nb::steal(PyTuple_New(device_list->size())); + int i = 0; + for (xla::ifrt::Device* device : device_list->devices()) { + nb::object d = py_client_->GetPyDevice(device); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + ++i; + } + return out; + } + case 1: + return std::get<1>(device_list_); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::iterator PyDeviceList::Iter() { + switch (device_list_.index()) { + case 0: { + // Iterator whose deference converts `xla::ifrt::Device*` into JAX + // `PjRtDevice`. 
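+      // Dereferencing wraps the underlying ifrt::Device* as a PyDevice via
+      // py_client->GetPyDevice(); nb::make_iterator then exposes it to
+      // Python, and the keep_alive policy on __iter__ (see Register) keeps
+      // this PyDeviceList alive while iteration is in progress.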
+ struct Iterator { + void operator++() { ++it; } + bool operator==(const Iterator& other) const { return it == other.it; } + xla::nb_class_ptr operator*() const { + return py_client->GetPyDevice(*it); + } + xla::nb_class_ptr py_client; + absl::Span::const_iterator it; + }; + return nb::make_iterator( + nb::type(), "ifrt_device_iterator", + Iterator{py_client_, std::get<0>(device_list_)->devices().cbegin()}, + Iterator{py_client_, std::get<0>(device_list_)->devices().cend()}); + } + case 1: + return nb::make_iterator( + nb::type(), "python_device_iterator", + std::get<1>(device_list_).begin(), std::get<1>(device_list_).end()); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +std::string PyDeviceList::Str() { + return nb::cast(nb::str(AsTuple())); +} + +nb::tuple PyDeviceList::Dump() const { return AsTuple(); } + +bool PyDeviceList::IsFullyAddressable() { + if (!is_fully_addressable_.has_value()) { + is_fully_addressable_ = true; + switch (device_list_.index()) { + case 0: { + const int process_index = py_client_ ? py_client_->process_index() : 0; + for (const xla::ifrt::Device* device : + std::get<0>(device_list_)->devices()) { + if (device->ProcessIndex() != process_index) { + is_fully_addressable_ = false; + break; + } + } + break; + } + case 1: { + for (nb::handle device : std::get<1>(device_list_)) { + if (nb::cast(device.attr("process_index")) != + nb::cast(device.attr("client").attr("process_index")())) { + is_fully_addressable_ = false; + break; + } + } + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *is_fully_addressable_; +} + +/*static*/ xla::nb_class_ptr PyDeviceList::AddressableDeviceList( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (self->IsFullyAddressable()) { + // Do not cache this result in `addressable_device_list_`. Otherwise, it + // will create a cycle that prevents deletion of this object. + return self; + } + if (!self->addressable_device_list_.has_value()) { + switch (self->device_list_.index()) { + case 0: { + absl::InlinedVector addressable_devices; + const int process_index = + self->py_client_ ? self->py_client_->process_index() : 0; + for (xla::ifrt::Device* device : + std::get<0>(self->device_list_)->devices()) { + if (device->ProcessIndex() == process_index) { + addressable_devices.push_back(device); + } + } + self->addressable_device_list_ = xla::make_nb_class( + self->py_client_, self->py_client_->ifrt_client()->MakeDeviceList( + addressable_devices)); + break; + } + case 1: { + auto device_list = std::get<1>(self->device_list_); + std::vector addressable_devices; + for (size_t i = 0; i < device_list.size(); ++i) { + nb::object device = device_list[i]; + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_devices.push_back(std::move(device)); + } + } + self->addressable_device_list_ = xla::make_nb_class( + xla::MutableSpanToNbTuple(absl::MakeSpan(addressable_devices))); + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *self->addressable_device_list_; +} + +void PyDeviceList::PopulateMemoryKindInfo() { + if (device_list_.index() == 1) { + // Handle Python duck-type devices in a separate function for readability. 
+ PopulateMemoryKindInfoForDuckTypedDevices(); + return; + } + if (device_list_.index() != 0) { + throw nb::value_error("Unrecognized DeviceList type"); + } + MemoryKindInfo info; + xla::ifrt::Device* addressable_device = nullptr; + const int process_index = py_client_ ? py_client_->process_index() : 0; + for (xla::ifrt::Device* device : std::get<0>(device_list_)->devices()) { + if (device->ProcessIndex() == process_index) { + addressable_device = device; + break; + } + } + if (addressable_device == nullptr) { + info.default_memory_kind = nb::none(); + memory_kind_info_ = std::move(info); + return; + } + + auto default_memory = addressable_device->DefaultMemory(); + if (!default_memory.ok()) { + // Cache the error. + memory_kind_info_ = default_memory.status(); + return; + } + info.default_memory_kind = nb::cast(*(*default_memory)->Kind().memory_kind()); + nb::tuple memory_kinds = + nb::steal(PyTuple_New(addressable_device->Memories().size())); + for (size_t i = 0; i < addressable_device->Memories().size(); ++i) { + auto* memory = addressable_device->Memories()[i]; + nb::str s = nb::str(memory->Kind().memory_kind()->data(), + memory->Kind().memory_kind()->size()); + PyTuple_SET_ITEM(memory_kinds.ptr(), i, s.release().ptr()); + } + info.memory_kinds = std::move(memory_kinds); + memory_kind_info_ = std::move(info); +} + +void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { + MemoryKindInfo info; + try { + nb::handle addressable_device; + for (nb::handle device : std::get<1>(device_list_)) { + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_device = device; + break; + } + } + if (!addressable_device) { + info.default_memory_kind = nb::none(); + // info.memory_kinds is default-initialized to an empty tuple. + memory_kind_info_ = std::move(info); + return; + } + auto default_memory = addressable_device.attr("default_memory")(); + info.default_memory_kind = default_memory.attr("kind"); + info.memory_kinds = nb::tuple( + nb::object(addressable_device.attr("addressable_memories")())); + memory_kind_info_ = std::move(info); + } catch (nb::python_error& e) { + // Cache the error. 
+ memory_kind_info_ = xla::InvalidArgument("%s", e.what()); + } +} + +/*static*/ absl::StatusOr PyDeviceList::MemoryKinds( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->memory_kinds; +} + +/*static*/ absl::StatusOr PyDeviceList::DefaultMemoryKind( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->default_memory_kind; +} + +/*static*/ void PyDeviceList::Register(nb::module_& m) { + nb::class_(m, "DeviceList") + .def(nb::init()) + .def("__hash__", &PyDeviceList::Hash, nb::lock_self()) + .def("__eq__", &PyDeviceList::Equal) + .def("__ne__", &PyDeviceList::NotEqual) + .def("__len__", &PyDeviceList::Len) + .def("__getitem__", &PyDeviceList::GetItem) + .def("__getitem__", &PyDeviceList::GetSlice) + .def("__iter__", &PyDeviceList::Iter, nb::keep_alive<0, 1>()) + .def("__str__", &PyDeviceList::Str) + .def("__repr__", &PyDeviceList::Str) + .def("__getstate__", [](const PyDeviceList& l) { return l.Dump(); }) + .def("__setstate__", + [](PyDeviceList& self, nb::tuple t) { + new (&self) PyDeviceList(std::move(t)); + }) + .def_prop_ro("is_fully_addressable", &PyDeviceList::IsFullyAddressable, + nb::lock_self()) + .def_prop_ro("addressable_device_list", + &PyDeviceList::AddressableDeviceList) + // `xla::ValueOrThrowWrapper` does not work with + // `def_prop_ro()`. Manually convert an error into an exception. + .def_prop_ro("default_memory_kind", + [](xla::nb_class_ptr l) { + auto kind = DefaultMemoryKind(l); + if (!kind.ok()) { + throw nb::value_error(kind.status().ToString().c_str()); + } + return *kind; + }) + .def_prop_ro("memory_kinds", [](xla::nb_class_ptr l) { + auto kinds = MemoryKinds(l); + if (!kinds.ok()) { + throw nb::value_error(kinds.status().ToString().c_str()); + } + return *kinds; + }); +} + +} // namespace jax diff --git a/jaxlib/xla/py_device_list.h b/jaxlib/xla/py_device_list.h new file mode 100644 index 000000000000..0fa9b3965dfe --- /dev/null +++ b/jaxlib/xla/py_device_list.h @@ -0,0 +1,136 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_DEVICE_LIST_H_ +#define JAXLIB_XLA_PY_DEVICE_LIST_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/ifrt/device_list.h" + +namespace jax { + +// Device list with various caching and direct access to IFRT DeviceList. 
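+//
+// The list is backed either by a C++ ifrt::DeviceList (fast path) or by a
+// Python tuple of duck-typed device objects (fallback); see the device_list_
+// variant below. A minimal construction sketch, assuming an existing
+// `py_client` and `ifrt_device_list`:
+//
+//   auto py_devices = xla::make_nb_class<jax::PyDeviceList>(
+//       py_client, std::move(ifrt_device_list));
+//   // Exposed to Python as `DeviceList` via PyDeviceList::Register(m).
+//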
+class PyDeviceList { + public: + PyDeviceList(xla::nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list); + explicit PyDeviceList(nanobind::tuple py_device_assignment); + ~PyDeviceList(); + + PyDeviceList(const PyDeviceList&) = delete; + PyDeviceList(PyDeviceList&&) = delete; + PyDeviceList& operator=(const PyDeviceList&) = delete; + PyDeviceList& operator=(PyDeviceList&&) = delete; + + static nanobind::handle type() { + static auto type = nanobind::type(); + return type; + } + + // These two methods are safe to call from C++ without GIL. + xla::nb_class_ptr py_client() const { return py_client_; } + absl::StatusOr ifrt_device_list() const; + + int Len() const; // Requires the GIL in GIL mode. + nanobind::object GetItem(int index); // Requires the GIL in GIL mode. + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static xla::nb_class_ptr AddressableDeviceList( + xla::nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr DefaultMemoryKind( + xla::nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr MemoryKinds( + xla::nb_class_ptr self); + + // go/pywald-pybind-annotation BEGIN + // refs { + // module_path: "third_party/py/jax/jaxlib/xla/xla.cc" + // module_arg {} + // } + // go/pywald-pybind-annotation END + static void Register(nanobind::module_& m); + + private: + nanobind::tuple AsTuple() const; + + // Methods below require GIL. + nanobind::object GetSlice(nanobind::slice slice); + nanobind::iterator Iter(); + + std::string Str(); + + nanobind::tuple Dump() const; + + int64_t Hash(); // Mutates hash_, needs self lock. + + static bool Equal(xla::nb_class_ptr self, + nanobind::handle other); + static bool NotEqual(xla::nb_class_ptr self, + nanobind::handle other); + + // Finds the memory kind info from an addressable device. Requires the GIL + // or self lock. + void PopulateMemoryKindInfo(); + // Same as `PopulateMemoryKindInfo()`, but uses `py_device_assignment_` + // instead of `ifrt_device_list_` to support duck-typed device objects. + // Requires the GIL or self lock. + void PopulateMemoryKindInfoForDuckTypedDevices(); + + // Requires the self lock or GIL is held. + bool IsFullyAddressable(); + + // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and + // non-empty. + xla::nb_class_ptr py_client_; + + // Either C++ `ifrt::DeviceList` or Python duck-type devices. + // TODO(hyeontaek): Remove support for Python duck-type devices once all + // JAX backends and tests are migrated to use an `xla::ifrt::Device` type + // for JAX devices. + // Immutable after constructor; no locking needed. + std::variant device_list_; + + // Populated on demand. Guarded by the object's self lock. + std::optional hash_; + // TODO(hyeontaek): Make the following property cached within + // `xla::ifrt::DeviceList`. + // Populated on demand. Guarded by the object's self lock. + std::optional is_fully_addressable_; + // Populated on demand. Guarded by the object's self lock. + std::optional> addressable_device_list_; + + struct MemoryKindInfo { + nanobind::object default_memory_kind; + nanobind::tuple memory_kinds; + }; + // Populated on demand. Guarded by the object's self lock. 
+ std::optional> memory_kind_info_; +}; + +} // namespace jax + +#endif // JAXLIB_XLA_PY_DEVICE_LIST_H_ diff --git a/jaxlib/xla/py_executable.cc b/jaxlib/xla/py_executable.cc new file mode 100644 index 000000000000..eaf5af34f883 --- /dev/null +++ b/jaxlib/xla/py_executable.cc @@ -0,0 +1,427 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_executable.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/traceback.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/fingerprint.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla { + +namespace nb = nanobind; + +absl::Status PyToken::Await() { + CHECK(future_.IsValid()); + nb::gil_scoped_release gil_release; + return future_.Await(); +} + +absl::Status PyShardedToken::Await() { + nb::gil_scoped_release gil_release; + absl::Status status = absl::OkStatus(); + for (auto& future : futures_) { + auto s = future.Await(); + if (!s.ok()) status = std::move(s); + } + return status; +} + +PyLoadedExecutable::PyLoadedExecutable( + nb_class_ptr client, + std::shared_ptr ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint) + : client_(std::move(client)), + ifrt_loaded_executable_(std::move(ifrt_loaded_executable)), + traceback_(std::move(traceback)), + fingerprint_(std::move(fingerprint)), + next_launch_id_( + fingerprint_.has_value() ? 
tsl::Fingerprint32(*fingerprint_) : 1) { + CHECK(PyGILState_Check()); + if (fingerprint_) { + VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() + << ": " << *fingerprint_; + } + nb::ft_lock_guard lock(client_->executables_mutex_); + next_ = client_->executables_; + client_->executables_ = this; + prev_ = nullptr; + if (next_) { + next_->prev_ = this; + } +} + +PyLoadedExecutable::~PyLoadedExecutable() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(client_->executables_mutex_); + if (client_->executables_ == this) { + client_->executables_ = next_; + } + if (prev_) { + prev_->next_ = next_; + } + if (next_) { + next_->prev_ = prev_; + } +} + +std::vector> PyLoadedExecutable::AddressableDevices() + const { + std::vector> devices; + devices.reserve(ifrt_loaded_executable_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_loaded_executable_->addressable_devices()) { + devices.push_back(client_->GetPyDevice(device)); + } + return devices; +} + +namespace { + +static int GetNumDevices(const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return std::get(arg).num_addressable_shards(); + } else { + return std::get>(arg).size(); + } +} +static tsl::RCReference GetIfRtArray( + const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return tsl::FormRef(std::get(arg).ifrt_array()); + } + auto& arg_vector = std::get>(arg); + + // TODO(hyeontaek): This on-demand Array creation is not efficient and has + // insufficient information about the shape (a dummy shape is used). This + // should be removed if possible and only be used in the context where the + // shape information is unused. + std::vector> ifrt_arrays; + ifrt_arrays.reserve(arg_vector.size()); + absl::InlinedVector devices; + devices.reserve(arg_vector.size()); + for (auto& arr : arg_vector) { + CHECK_EQ(arr.ifrt_array()->sharding().devices()->size(), 1) + << arr.ifrt_array()->sharding().DebugString(); + ifrt_arrays.push_back(tsl::FormRef(arr.ifrt_array())); + devices.push_back( + arr.ifrt_array()->sharding().devices()->devices().front()); + } + CHECK(!ifrt_arrays.empty()); + // Use a dummy shape. + // TODO(hyeontaek): Find a way to compute a correct shape. + // TODO(yashkatariya): Plumb sharding or memory_kind here. 
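+  // The per-shard arrays gathered above are assembled into a single
+  // ifrt::Array below, using an OpaqueSharding over the shards' devices and
+  // the first shard's shape as a stand-in (see the TODOs above).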
+ ifrt::Client* client = ifrt_arrays.front()->client(); + auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays( + ifrt_arrays.front()->shape(), + ifrt::OpaqueSharding::Create(client->MakeDeviceList(devices), + ifrt::MemoryKind()), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(ifrt_array.status()); + return *ifrt_array; +} + +void PopulateExecuteShardedResults( + const nb_class_ptr& client, + std::vector> ifrt_arrays, + const PjRtFuture<>& result_status, int num_computations, + std::vector>& outputs) { + auto traceback = Traceback::Get(); + DCHECK_GT(num_computations, 0); + int num_output_buffers = ifrt_arrays.size(); + outputs.resize(num_output_buffers); + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + outputs[buffer_id].reserve(num_computations); + auto exploded_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(exploded_arrays.status()); + for (auto& exploded_array : *exploded_arrays) { + outputs[buffer_id].push_back(PyArray::MakeFromSingleDeviceArray( + client, traceback, std::move(exploded_array), false, true, + result_status)); + } + } +} + +absl::StatusOr ExecuteShardedOnLocalDevicesInternal( + const ifrt::ExecuteOptions& options, const nb_class_ptr& client, + ifrt::LoadedExecutable* ifrt_loaded_executable, + absl::Span args, + std::optional>>& returned_futures) { + std::vector> output_arrays; + std::unique_ptr> returned_future; + int num_computations = ifrt_loaded_executable->addressable_devices().size(); + PjRtFuture<> result_status; + { + nb::gil_scoped_release gil_release; + for (const auto& arg : args) { + if (GetNumDevices(arg) != num_computations) { + return InvalidArgument( + "Expected args to execute_sharded_on_local_devices to have %d " + "shards, got: [%s]", + num_computations, + absl::StrJoin(args, ", ", + [](std::string* out, const ExecuteShardedArg& arg) { + out->append(std::to_string(GetNumDevices(arg))); + })); + } + } + std::vector> arg_arrays(args.size()); + absl::c_transform(args, arg_arrays.begin(), + [&](const ExecuteShardedArg& arg) mutable { + return GetIfRtArray(arg); + }); + TF_ASSIGN_OR_RETURN(auto result, ifrt_loaded_executable->Execute( + absl::MakeSpan(arg_arrays), options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + // options.fill_status is only supposed to be true when the computation has + // tokens. + if (options.fill_status) { + result_status = result.status; + if (returned_futures.has_value()) { + returned_futures->resize(num_computations, std::move(result.status)); + } + } + } + + // TODO(b/240696624): Although the PjRt interface require `returned_futures` + // to be resized correctly if it is not nullopt, some implementation does not + // implement this. So we have to check whether returned_futures is empty. + // Remove this check once the implementation is fixed. + auto py_sharded_token = returned_futures.has_value() + ? 
PyShardedToken(std::move(*returned_futures)) + : PyShardedToken(); + + return PyExecuteResults(client, std::move(output_arrays), num_computations, + std::move(py_sharded_token), result_status); +} + +} // namespace + +PyExecuteResults::PyExecuteResults( + const nb_class_ptr& client, + std::vector> ifrt_arrays, + int num_computations, PyShardedToken token, PjRtFuture<> result_status) + : client_(client), + ifrt_arrays_(std::move(ifrt_arrays)), + num_computations_(num_computations), + token_(std::move(token)), + result_status_(std::move(result_status)) {} + +void PyExecuteResults::CheckNotDisassembled() const { + if (is_exploded_) { + throw nb::value_error("ExecuteResults already exploded."); + } +} + +std::vector> PyExecuteResults::Consume() { + CheckNotDisassembled(); + is_exploded_ = true; + return std::move(ifrt_arrays_); +} + +PyShardedToken PyExecuteResults::ConsumeToken() { + if (token_consumed_) { + throw nb::value_error("ExecuteResults token already consumed."); + } + token_consumed_ = true; + return std::move(token_); +} + +std::vector> +PyExecuteResults::DisassembleIntoSingleDeviceArrays() { + std::vector> outputs; + PopulateExecuteShardedResults( + client_, Consume(), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector> +PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { + CheckNotDisassembled(); + if (n > ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("In DisassemblePrefixIntoSingleDeviceArrays: ", n, " > ", + ifrt_arrays_.size()) + .c_str()); + } + std::vector> ifrt_arrays; + ifrt_arrays.reserve(ifrt_arrays_.size() - n); + for (size_t i = n; i < ifrt_arrays_.size(); ++i) { + ifrt_arrays.push_back(std::move(ifrt_arrays_[i])); + } + ifrt_arrays_.erase(ifrt_arrays_.begin() + n, ifrt_arrays_.end()); + std::swap(ifrt_arrays_, ifrt_arrays); + std::vector> outputs; + PopulateExecuteShardedResults( + client_, std::move(ifrt_arrays), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector PyExecuteResults::ConsumeWithHandlers( + std::vector> + out_handlers) { + std::vector outputs; + auto ifrt_arrays = Consume(); + auto traceback = Traceback::Get(); + DCHECK_GT(num_computations_, 0); + int num_output_buffers = ifrt_arrays.size(); + outputs.reserve(num_output_buffers); + if (out_handlers.size() != num_output_buffers) { + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " vs ", num_output_buffers) + .c_str()); + } + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + auto& handler = out_handlers[buffer_id]; + if (std::holds_alternative(handler)) { + outputs.push_back(std::get(handler)->Call( + client_, std::move(ifrt_arrays[buffer_id]), + result_status_.IsValid() ? result_status_ : PjRtFuture<>())); + } else { + tsl::profiler::TraceMe traceme("ConsumeWithHandlers fallback."); + auto disassembled_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(disassembled_arrays.status()); + nb::list bufs = + nb::steal(PyList_New(disassembled_arrays->size())); + int i = 0; + for (auto& disassembled_array : *disassembled_arrays) { + nb::object array = PyArray::MakeFromSingleDeviceArray( + client_, traceback, std::move(disassembled_array), false, true, + result_status_.IsValid() ? 
result_status_ : PjRtFuture<>()); + PyList_SET_ITEM(bufs.ptr(), i, array.release().ptr()); + ++i; + } + outputs.push_back(std::get(handler)(std::move(bufs))); + } + } + return outputs; +} + +absl::StatusOr PyLoadedExecutable::ExecuteSharded( + std::vector args, bool with_tokens) { + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = with_tokens; + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + std::optional>> returned_futures; + if (with_tokens) { + returned_futures.emplace(); + } + absl::Span span_args = args; + return ExecuteShardedOnLocalDevicesInternal(options, client_, + ifrt_loaded_executable_.get(), + span_args, returned_futures); +} + +absl::StatusOr>> +PyLoadedExecutable::HloModules() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetHloModules(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputMemoryKinds() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputMemoryKinds(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetParameterLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterLayouts(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputLayouts(); +} + +std::optional> +PyLoadedExecutable::GetParameterShardings() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterShardings(); +} + +std::optional> PyLoadedExecutable::GetOutputShardings() + const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputShardings(); +} + +int64_t PyLoadedExecutable::GetNextLaunchId() { + return next_launch_id_.fetch_add(1, std::memory_order_relaxed); +} + +void PyLoadedExecutable::KeepAlive(nb::object obj) { + keepalives_.push_back(std::move(obj)); +} + +} // namespace xla diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h new file mode 100644 index 000000000000..9c8ce8010c90 --- /dev/null +++ b/jaxlib/xla/py_executable.h @@ -0,0 +1,254 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_EXECUTABLE_H_ +#define JAXLIB_XLA_PY_EXECUTABLE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/traceback.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/status.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class PyToken { + public: + PyToken() = default; + explicit PyToken(PjRtFuture<> future) : future_(std::move(future)) {} + + static PyToken ReadyPyToken() { + return PyToken(PjRtFuture<>(absl::OkStatus())); + } + + absl::Status Await(); + + private: + PjRtFuture<> future_; +}; + +// PyShardedToken contains a PyToken for each device's execution. +class PyShardedToken { + public: + // Default construction creates a always-ready token. + PyShardedToken() = default; + explicit PyShardedToken(std::vector> futures) + : futures_(std::move(futures)) {} + + PyToken GetPyToken(int device_id) const { + if (futures_.empty()) return PyToken::ReadyPyToken(); + return PyToken(futures_.at(device_id)); + } + + absl::Status Await(); + + private: + std::vector> futures_; +}; + +class PyExecuteResults { + public: + PyExecuteResults(const nb_class_ptr& client, + std::vector> ifrt_arrays, + int num_computations, PyShardedToken token, + PjRtFuture<> result_status = PjRtFuture<>()); + + std::vector> DisassembleIntoSingleDeviceArrays(); + + std::vector> DisassemblePrefixIntoSingleDeviceArrays( + size_t n); + + std::vector ConsumeWithHandlers( + std::vector> + out_handlers); + + std::vector> Consume(); + + PyShardedToken ConsumeToken(); + + size_t Size() const { + CheckNotDisassembled(); + return ifrt_arrays_.size(); + } + + void CheckNotDisassembled() const; + + private: + bool is_exploded_ = false; + bool token_consumed_ = false; + nb_class_ptr client_; + std::vector> ifrt_arrays_; + int num_computations_; + PyShardedToken token_; + // Only set if the computation has tokens. + PjRtFuture<> result_status_; +}; + +using ExecuteShardedArg = std::variant>; + +// Python wrapper around PjRtExecutable. We use a wrapper class: +// a) to keep the PyClient alive via a std::shared_ptr<> +// b) to add Python-specific functionality. 
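// ExecuteSharded() on the wrapper below returns a PyExecuteResults; as a
// minimal sketch only (the helper name is hypothetical and not part of this
// header), draining those consume-once results typically looks like:
//
//   absl::Status AwaitExecuteResults(PyExecuteResults& results,
//                                    std::vector<std::vector<PyArray>>& out) {
//     // Disassembling consumes the arrays; a second call (or Consume())
//     // throws.
//     out = results.DisassembleIntoSingleDeviceArrays();
//     // The sharded token is likewise single-use; Await() blocks until the
//     // per-device executions have completed.
//     return results.ConsumeToken().Await();
//   }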
+class PyLoadedExecutable { + public: + PyLoadedExecutable( + nb_class_ptr client, + std::shared_ptr ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint); + ~PyLoadedExecutable(); + + nb_class_ptr client() const { return client_; } + ifrt::LoadedExecutable* ifrt_loaded_executable() const { + return ifrt_loaded_executable_.get(); + } + + std::shared_ptr shared_ifrt_loaded_executable() { + return ifrt_loaded_executable_; + } + + std::vector> AddressableDevices() const; + + int64_t SizeOfGeneratedCodeInBytes() const { + return ifrt_loaded_executable_->SizeOfGeneratedCodeInBytes(); + } + + absl::StatusOr GetCompiledMemoryStats() const { + nanobind::gil_scoped_release scope; + return ifrt_loaded_executable_->GetCompiledMemoryStats(); + } + + absl::StatusOr GetCostAnalysis() const { + return ifrt_loaded_executable_->GetCostAnalysis(); + } + + void Delete() { + // TODO(hyeontaek): Return absl::Status. + TF_CHECK_OK(ifrt_loaded_executable_->Delete().Await()); + } + + bool is_deleted() { return ifrt_loaded_executable_->IsDeleted(); } + + // Takes args indexed by argid then deviceid, transposes them, and passes to + // PjRtExecutable::Execute. The result is similarly transposed back into the + // argid,deviceid format. + // args is [num_args x num_devices]. + absl::StatusOr ExecuteSharded( + std::vector args, bool with_tokens); + + absl::StatusOr>> HloModules() const; + + absl::StatusOr>> + GetOutputMemoryKinds() const; + + absl::StatusOr>> + GetParameterLayouts() const; + + absl::StatusOr>> + GetOutputLayouts() const; + + std::optional> GetParameterShardings() const; + + std::optional> GetOutputShardings() const; + + const std::optional& traceback() { return traceback_; } + + ifrt::LoadedExecutable* ifrt_executable() const { + return ifrt_loaded_executable_.get(); + } + + // Short-term escape hatch to get PjRtLoadedExecutable from PyExecutable. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + std::shared_ptr shared_ptr_pjrt_executable() { + auto* exec = llvm::dyn_cast_or_null( + ifrt_loaded_executable_.get()); + if (exec == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return exec->shared_ptr_pjrt_loaded_executable(); + } + + // Returns a template of execute options to pass to + // `ifrt_executable()->Execute()`. Note that the caller may need to override + // some options such as `launch_id` that change at each execution. + const ifrt::ExecuteOptions& options() const { return options_; } + + // Returns a unique launch ID to use for the next execution. + int64_t GetNextLaunchId(); + + const std::optional& fingerprint() const { return fingerprint_; } + + // Keep `obj` alive as long as PyLoadedExecutable. + void KeepAlive(nanobind::object obj); + + private: + friend class PyClient; + + nb_class_ptr client_; + std::shared_ptr ifrt_loaded_executable_; + std::optional traceback_; + + // Identical executables (i.e. representing the same program) will have the + // same fingerprint. nullopt on platforms or executables where fingerprints + // aren't implemented. + std::optional fingerprint_; + + // Launch ID to use for the next execution. + std::atomic next_launch_id_; + + // The options to pass to `executable_.Execute`. + ifrt::ExecuteOptions options_; + + // Python objects to keep alive as requested by user. + std::vector keepalives_; + + // Doubly-linked list of all executables known to the client. Protected by the + // GIL. 
+ PyLoadedExecutable* next_; + PyLoadedExecutable* prev_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_EXECUTABLE_H_ diff --git a/jaxlib/xla/py_host_callback.cc b/jaxlib/xla/py_host_callback.cc new file mode 100644 index 000000000000..fdb40c04b517 --- /dev/null +++ b/jaxlib/xla/py_host_callback.cc @@ -0,0 +1,259 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_host_callback.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/callback.h" +#include "jaxlib/xla/py_host_callback.pb.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "xla/layout_util.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/xla_host_callback.pb.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +char PyFfiLoadedHostCallback::ID = 0; +char PyHostSendAndRecvLoadedHostCallback::ID = 0; + +namespace { + +absl::StatusOr> CreateCallbackArgs( + absl::Span operand_shapes) { + std::vector callback_args(operand_shapes.size()); + for (int i = 0; i < operand_shapes.size(); ++i) { + Shape shape = operand_shapes[i]; + + if (shape.IsArray()) { + Shape layout = + (shape.has_layout() ? shape + : LayoutUtil::GetWithDefaultLayout(shape)); + callback_args[i].dims.resize(shape.dimensions_size()); + absl::c_copy(shape.dimensions(), callback_args[i].dims.begin()); + callback_args[i].strides = ByteStridesForShape(layout); + callback_args[i].type = shape.element_type(); + callback_args[i].size_in_bytes = ShapeUtil::ByteSizeOf(layout); + TF_ASSIGN_OR_RETURN(callback_args[i].dtype, + PrimitiveTypeToNbDtype(shape.element_type())); + } else if (shape.IsToken()) { + callback_args[i].type = TOKEN; + } else { + return InvalidArgument( + "Only array and token arguments to Python callbacks are supported, " + "got %s", + shape.ToString()); + } + } + return callback_args; +} + +absl::StatusOr> CreateCallbackResults( + absl::Span result_shapes) { + std::vector callback_results(result_shapes.size()); + for (int i = 0; i < result_shapes.size(); ++i) { + if (result_shapes[i].IsArray()) { + const Shape& shape = + result_shapes[i].has_layout() + ? 
result_shapes[i] + : LayoutUtil::GetWithDefaultLayout(result_shapes[i]); + callback_results[i].expected_dims.resize(shape.dimensions_size()); + absl::c_copy(shape.dimensions(), + callback_results[i].expected_dims.begin()); + callback_results[i].expected_strides = ByteStridesForShape(shape); + callback_results[i].type = shape.element_type(); + callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape); + callback_results[i].reversed_layout.resize(shape.dimensions_size()); + absl::c_reverse_copy(shape.layout().minor_to_major(), + callback_results[i].reversed_layout.begin()); + } else if (result_shapes[i].IsToken()) { + callback_results[i].type = TOKEN; + } else { + return InvalidArgument( + "Only array and token return values from Python callbacks are " + "supported, got %s", + result_shapes[i].ToString()); + } + } + return callback_results; +} + +} // namespace + +PyFfiLoadedHostCallback::~PyFfiLoadedHostCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::StatusOr> +PyHostSendAndRecvLoadedHostCallback::Create( + ifrt::Client* ifrt_client, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes)); + TF_ASSIGN_OR_RETURN(auto callback_results, + CreateCallbackResults(result_shapes)); + + // `callable` will be destroyed safely with `PythonRefManager` when + // `CpuCallback` is destroyed. + auto cpu_callback = + std::make_shared(callable, callback_args, callback_results); + + auto host_callback = std::make_unique(); + + auto assign_arg_info = [](absl::Span shapes, + absl::Span channel_ids, + std::vector& arg_infos) { + DCHECK_EQ(shapes.size(), channel_ids.size()); + arg_infos.reserve(shapes.size()); + for (int i = 0; i < shapes.size(); ++i) { + HostCallbackArgInfo host_callback_arg_info; + host_callback_arg_info.channel_id = channel_ids[i]; + const auto& shape = shapes[i]; + Shape layout = + (shape.has_layout() ? 
shape + : LayoutUtil::GetWithDefaultLayout(shape)); + host_callback_arg_info.shape = layout; + arg_infos.push_back(std::move(host_callback_arg_info)); + } + }; + + assign_arg_info(operand_shapes, send_channel_ids, host_callback->operands); + assign_arg_info(result_shapes, recv_channel_ids, host_callback->results); + + host_callback->callback = [cpu_callback = std::move(cpu_callback)]( + void** outputs, void** inputs) { + return cpu_callback->PrepareAndCall(outputs, inputs); + }; + return tsl::RCReference( + tsl::MakeRef( + ifrt_client, std::move(host_callback), callable, operand_shapes, + result_shapes, send_channel_ids, recv_channel_ids, + std::move(serializer))); +} + +PyHostSendAndRecvLoadedHostCallback::PyHostSendAndRecvLoadedHostCallback( + ifrt::Client* ifrt_client, + std::unique_ptr xla_host_callback, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) + : llvm::RTTIExtends( + ifrt_client, std::move(xla_host_callback)), + callable_(std::move(callable)), + operand_shapes_(operand_shapes.begin(), operand_shapes.end()), + result_shapes_(result_shapes.begin(), result_shapes.end()), + send_channel_ids_(send_channel_ids.begin(), send_channel_ids.end()), + recv_channel_ids_(recv_channel_ids.begin(), recv_channel_ids.end()), + serializer_(serializer) {} + +PyHostSendAndRecvLoadedHostCallback::~PyHostSendAndRecvLoadedHostCallback() { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&callable_), 1)); + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&serializer_), 1)); +} + +absl::StatusOr PyHostSendAndRecvLoadedHostCallback::Serialize() + const { + if (serializer_.is_none()) { + return InvalidArgument( + "Host callback cannot be serialized because serializer was not " + "provided by JAX"); + } + ifrt::XlaHostCallbackProto xla_host_callback_proto; + + TF_RET_CHECK(operand_shapes_.size() == send_channel_ids_.size()); + for (int i = 0; i < operand_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo* const operand = + xla_host_callback_proto.add_operands(); + operand->set_channel_id(send_channel_ids_[i]); + *operand->mutable_shape() = operand_shapes_[i].ToProto(); + } + + TF_RET_CHECK(result_shapes_.size() == recv_channel_ids_.size()); + for (int i = 0; i < result_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo* const result = + xla_host_callback_proto.add_results(); + result->set_channel_id(recv_channel_ids_[i]); + *result->mutable_shape() = result_shapes_[i].ToProto(); + } + + std::string callable; + { + nb::gil_scoped_acquire gil_acquire; + try { + nb::bytes bytes = nb::cast(serializer_(callable_)); + callable = std::string(bytes.c_str(), bytes.size()); + } catch (const nb::python_error& e) { + return absl::InternalError(absl::StrCat( + "Unable to pickle the host_callback callable: ", e.what())); + } catch (const std::exception& e) { + std::exception_ptr p = std::current_exception(); + return absl::InternalError(absl::StrCat( + "Exception while pickling the host_callback callable: ", e.what())); + } catch (...) { + // Ensure to avoid leaking any exception because this method could have + // been called outside of a Python context where C++ exceptions are not + // necessarily enabled. 
+ return absl::InternalError( + "Unknown exception while pickling the host_callback callable."); + } + } + PyHostCallbackProto py_host_callback_proto; + py_host_callback_proto.set_callable(std::move(callable)); + if (!xla_host_callback_proto.mutable_serialized_callback()->PackFrom( + py_host_callback_proto)) { + return absl::InternalError("Could not serialize a Python host callback"); + } + xla_host_callback_proto.set_use_major_to_minor_data_layout_for_callbacks( + true); + return xla_host_callback_proto.SerializeAsString(); +} + +} // namespace xla diff --git a/jaxlib/xla/py_host_callback.h b/jaxlib/xla/py_host_callback.h new file mode 100644 index 000000000000..1a1402a4eee2 --- /dev/null +++ b/jaxlib/xla/py_host_callback.h @@ -0,0 +1,119 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_HOST_CALLBACK_H_ +#define JAXLIB_XLA_PY_HOST_CALLBACK_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +using PyLoadedHostCallback = ::xla::ifrt::LoadedHostCallback; + +class PyFfiLoadedHostCallback final + : public llvm::RTTIExtends { + public: + PyFfiLoadedHostCallback(ifrt::Client* ifrt_client, + nanobind::callable callable) + : llvm::RTTIExtends(ifrt_client, + callable.ptr()), + callable_(std::move(callable)) {} + ~PyFfiLoadedHostCallback() override; + + ifrt::Client* client() const override { return ifrt_client_; } + absl::StatusOr Serialize() const override { + return Unimplemented( + "PyFfiLoadedHostCallback::Serialize() is not supported"); + }; + + static char ID; // NOLINT + + private: + ifrt::Client* ifrt_client_; + nanobind::callable callable_; +}; + +// `PyHostSendAndRecvLoadedHostCallback` implements a Python host callback that +// uses XLA host send and recv. This object should be passed to the compiler +// when creating `xla::ifrt::LoadedExecutable`. +// +// Serialization is supported if the Python host callback using the +// `cloudpickle` third-party library. +// +// TODO(hyeontaek): Update the comment ("compiler" to "client") after splitting +// compilation and loading. +class PyHostSendAndRecvLoadedHostCallback final + : public llvm::RTTIExtends { + public: + static absl::StatusOr> + Create(ifrt::Client* ifrt_client, nanobind::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + // PjRtLoadedHostCallback implementation. 
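// Illustrative only (variable names and channel ids below are hypothetical):
// instances are obtained through the Create() factory above, e.g.
//
//   TF_ASSIGN_OR_RETURN(auto callback,
//                       PyHostSendAndRecvLoadedHostCallback::Create(
//                           ifrt_client, callable, operand_shapes,
//                           result_shapes, /*send_channel_ids=*/{1},
//                           /*recv_channel_ids=*/{2}, serializer));
//
// and, per the class comment above, the resulting reference is handed to the
// compiler together with the program when creating the
// xla::ifrt::LoadedExecutable.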
+ + ~PyHostSendAndRecvLoadedHostCallback() override; + + absl::StatusOr Serialize() const override; + + static char ID; // NOLINT + + private: + PyHostSendAndRecvLoadedHostCallback( + ifrt::Client* ifrt_client, + std::unique_ptr xla_host_callback, + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + template + friend tsl::RCReference tsl::MakeRef(Args&&... args); + + // Retained arguments for host callback serialization. + nanobind::callable callable_; + std::vector operand_shapes_; + std::vector result_shapes_; + std::vector send_channel_ids_; + std::vector recv_channel_ids_; + nanobind::callable serializer_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_HOST_CALLBACK_H_ diff --git a/jaxlib/xla/py_host_callback.proto b/jaxlib/xla/py_host_callback.proto new file mode 100644 index 000000000000..997fc7fe450c --- /dev/null +++ b/jaxlib/xla/py_host_callback.proto @@ -0,0 +1,25 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +// Represents a JAX host callback that is serialized using the 'cloudpickle' +// Python library. Typically used for +// `xla.ifrt.XlaHostCallbackProto.serialized_callback`. +message PyHostCallbackProto { + bytes callable = 1; +} diff --git a/jaxlib/xla/py_memory_space.cc b/jaxlib/xla/py_memory_space.cc new file mode 100644 index 000000000000..0409861dd3b9 --- /dev/null +++ b/jaxlib/xla/py_memory_space.cc @@ -0,0 +1,102 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_memory_space.h" + +#include + +#include + +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/ifrt/device.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyMemorySpace::PyMemorySpace(nb_class_ptr client, + ifrt::Memory* memory) + : client_(std::move(client)), memory_(memory) {} + +int PyMemorySpace::process_index() const { return client_->process_index(); } + +absl::string_view PyMemorySpace::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. 
We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyMemorySpace::kind() const { + return *memory_->Kind().memory_kind(); +} + +absl::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } + +absl::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } + +nb::list PyMemorySpace::AddressableByDevices() const { + nb::list devices; + for (ifrt::Device* device : memory_->Devices()) { + devices.append(client_->GetPyDevice(device)); + } + return devices; +} + +/* static */ int PyMemorySpace::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyMemorySpace* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyMemorySpace::tp_clear(PyObject* self) { + PyMemorySpace* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyMemorySpace::slots_[] = { + {Py_tp_traverse, (void*)PyMemorySpace::tp_traverse}, + {Py_tp_clear, (void*)PyMemorySpace::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyMemorySpace::RegisterPythonType(nb::module_& m) { + nb::class_ device(m, "Memory", + nb::type_slots(PyMemorySpace::slots_)); + device.def_prop_ro("process_index", &PyMemorySpace::process_index) + .def_prop_ro("platform", &PyMemorySpace::platform) + .def_prop_ro("kind", &PyMemorySpace::kind) + .def("__str__", &PyMemorySpace::Str) + .def("__repr__", &PyMemorySpace::Repr) + .def("addressable_by_devices", &PyMemorySpace::AddressableByDevices, + "Returns devices that can address this memory."); +} + +} // namespace xla diff --git a/jaxlib/xla/py_memory_space.h b/jaxlib/xla/py_memory_space.h new file mode 100644 index 000000000000..f38038af4870 --- /dev/null +++ b/jaxlib/xla/py_memory_space.h @@ -0,0 +1,65 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_MEMORY_SPACE_H_ +#define JAXLIB_XLA_PY_MEMORY_SPACE_H_ + +#include + +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/ifrt/memory.h" + +namespace xla { + +class PyMemorySpace { + public: + PyMemorySpace(nb_class_ptr client, ifrt::Memory* memory_space); + + // Memory spaces are compared using Python object identity, so we don't allow + // them to be copied or moved. 
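// Each PyMemorySpace also holds a strong reference to its PyClient; the
// Py_tp_traverse/Py_tp_clear slots defined in py_memory_space.cc let the
// Python cyclic garbage collector see and break that reference cycle.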
+ PyMemorySpace(const PyMemorySpace&) = delete; + PyMemorySpace(PyMemorySpace&&) = delete; + PyMemorySpace& operator=(const PyMemorySpace&) = delete; + PyMemorySpace& operator=(PyMemorySpace&&) = delete; + + const nb_class_ptr& client() const { return client_; } + ifrt::Memory* memory_space() const { return memory_; } + + int process_index() const; + absl::string_view platform() const; + absl::string_view kind() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + nanobind::list AddressableByDevices() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Memory* memory_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_MEMORY_SPACE_H_ diff --git a/jaxlib/xla/py_program.cc b/jaxlib/xla/py_program.cc new file mode 100644 index 000000000000..b3828f5372d9 --- /dev/null +++ b/jaxlib/xla/py_program.cc @@ -0,0 +1,291 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_program.h" + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/sharding.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/custom_call_program.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/plugin_program.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +namespace nb = ::nanobind; + +namespace { + +// Gets `ifrt::DeviceList` from a sequence of JAX devices. 
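// The sequence may be either a jax::PyDeviceList (whose ifrt::DeviceList is
// reused directly) or an ordinary Python sequence of PyDevice objects, which
// is assembled into a new device list; at least one device is required.
// A commented usage sketch (the caller-side name `py_devices` is
// hypothetical):
//
//   TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef device_list,
//                       GetDeviceList(py_devices));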
+absl::StatusOr GetDeviceList(nb::sequence devices) { + ifrt::DeviceListRef ifrt_device_list; + if (devices.type().is(jax::PyDeviceList::type())) { + return nb::cast(devices)->ifrt_device_list(); + } else { + auto py_devices = nb::cast>>(devices); + if (py_devices.empty()) { + return absl::InvalidArgumentError( + "Colocated Python program requires at least one device"); + } + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const nb_class_ptr& py_device : py_devices) { + ifrt_devices.push_back(py_device->device()); + } + return py_devices.front()->client()->ifrt_client()->MakeDeviceList( + ifrt_devices); + } +} + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding)->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList(nb::handle sharding) { + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list->ifrt_device_list(); + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else { + return nb::cast( + sharding.attr("_internal_device_list")) + ->ifrt_device_list(); + } +} + +// Gets `ifrt::MemoryKind` from a JAX Sharding. +ifrt::MemoryKind GetIfrtMemoryKind(nb::handle sharding) { + auto memory_kind = sharding.attr("memory_kind"); + if (memory_kind.is_none()) { + return ifrt::MemoryKind(); + } else { + return ifrt::MemoryKind(nb::cast(memory_kind)); + } +} + +// Makes `ifrt::Sharding` from a JAX Sharding. It requires the number of shape +// dimensions, which may become necessary when building an HLO sharding. +absl::StatusOr> GetIfrtSharding( + nb::handle sharding, int64_t num_dimensions) { + auto ifrt_memory_kind = GetIfrtMemoryKind(sharding); + std::shared_ptr ifrt_sharding; + if (sharding.type().is(jax::SingleDeviceSharding::type())) { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, + nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list()); + return ifrt::SingleDeviceSharding::Create( + ifrt_device_list->devices().front(), ifrt_memory_kind); + } else { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetIfrtDeviceList(sharding)); + auto xla_hlo_sharding = GetXlaHloSharding(sharding, num_dimensions); + return ifrt::HloSharding::Create(std::move(ifrt_device_list), + ifrt_memory_kind, + std::move(xla_hlo_sharding)); + } +} + +// Gets `ifrt::ArraySpec`s from a sequence of JAX avals (e.g., +// `jax.ShapeDtypeStruct`). 
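// Each aval only needs to expose `shape`, `dtype`, and `sharding` attributes
// (as jax.ShapeDtypeStruct does); the number of shape dimensions is forwarded
// to GetIfrtSharding() above because converting a JAX sharding to an
// xla::HloSharding via `_to_xla_hlo_sharding` requires the rank.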
+absl::StatusOr> GetIfrtArraySpecs( + nb::sequence avals) { + std::vector ifrt_array_specs; + ifrt_array_specs.reserve(nb::len(avals)); + for (nb::handle aval : avals) { + ifrt::Shape ifrt_shape(nb::cast>(aval.attr("shape"))); + TF_ASSIGN_OR_RETURN( + auto ifrt_dtype, + DtypeToIfRtDType(nb::cast(aval.attr("dtype")))); + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + GetIfrtSharding(aval.attr("sharding"), ifrt_shape.dims().size())); + ifrt_array_specs.push_back(ifrt::ArraySpec{ + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding)}); + } + return ifrt_array_specs; +} + +absl::StatusOr> MakePluginProgramFromString( + std::string data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::move(data); + return plugin_program; +} + +absl::StatusOr> MakePluginProgramFromBytes( + nb::bytes data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::string(data.c_str(), data.size()); + return plugin_program; +} + +absl::StatusOr> +MakeColocatedPythonCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> +MakePluginCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> MakeHloProgram( + absl::string_view mlir_module) { + auto context = std::make_unique(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, *context)); + return std::make_unique(std::move(context), + std::move(module)); +} + +absl::StatusOr> MakeHloProgramFromString( + std::string mlir_module) { + return MakeHloProgram(mlir_module); +} + +absl::StatusOr> MakeHloProgramFromBytes( + nb::bytes mlir_module) { + return MakeHloProgram( + absl::string_view(mlir_module.c_str(), mlir_module.size())); +} + +absl::StatusOr> MakeXlaCompileOptions( + CompileOptions options, std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. 
+ for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +} + +constexpr absl::string_view kColocatedPythonProgramType = + "jax_colocated_python_v0.0.1"; + +absl::StatusOr> MakeColocatedPythonProgram( + std::string name, nb::bytes picked_function, nb::sequence devices, + nb::sequence input_avals, nb::sequence output_avals) { + auto ifrt_serialized_program_text = absl::MakeCordFromExternal( + absl::string_view(reinterpret_cast(picked_function.data()), + picked_function.size()), + /*releaser=*/[picked_function](absl::string_view) mutable { + GlobalPyRefManager()->AddGarbage(std::move(picked_function)); + }); + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetDeviceList(devices)); + TF_ASSIGN_OR_RETURN(auto ifrt_input_specs, GetIfrtArraySpecs(input_avals)); + TF_ASSIGN_OR_RETURN(auto ifrt_output_specs, GetIfrtArraySpecs(output_avals)); + return std::make_unique( + std::string(kColocatedPythonProgramType), std::move(name), + std::move(ifrt_serialized_program_text), std::move(ifrt_device_list), + std::move(ifrt_input_specs), std::move(ifrt_output_specs)); +} + +} // namespace + +void BuildIfrtProgramsSubmodule(nanobind::module_& m) { + auto sub_module = m.def_submodule("ifrt_programs"); + nb::class_ ifrt_program_base_class(sub_module, "Program"); + nb::class_ ifrt_compile_options_base_class( + sub_module, "CompileOptions"); + sub_module + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromString), + nb::arg("mlir_module")) + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromBytes), + nb::arg("mlir_module")) + .def("make_colocated_python_program", + ValueOrThrowWrapper(MakeColocatedPythonProgram), nb::arg("name"), + nb::arg("pickled_function"), nb::arg("devices"), + nb::arg("input_avals"), nb::arg("output_avals")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromString), nb::arg("data")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromBytes), nb::arg("data")) + .def("make_xla_compile_options", + ValueOrThrowWrapper(MakeXlaCompileOptions), nb::arg("options"), + nb::arg("host_callbacks")) + .def("make_colocated_python_compile_options", + ValueOrThrowWrapper(MakeColocatedPythonCompileOptions)) + .def("make_plugin_compile_options", + ValueOrThrowWrapper(MakePluginCompileOptions)); +} + +} // namespace xla diff --git a/jaxlib/xla/py_program.h b/jaxlib/xla/py_program.h new file mode 100644 index 000000000000..9fd30eeeed2f --- /dev/null +++ b/jaxlib/xla/py_program.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_PROGRAM_H_ +#define JAXLIB_XLA_PY_PROGRAM_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildIfrtProgramsSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_PROGRAM_H_ diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc new file mode 100644 index 000000000000..4aa40cf66087 --- /dev/null +++ b/jaxlib/xla/py_socket_transfer.cc @@ -0,0 +1,412 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "jaxlib/xla/py_socket_transfer.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/array.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/traceback.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/transfer/event_loop.h" +#include "xla/python/transfer/socket-server.h" +#include "xla/python/transfer/socket_bulk_transport.h" +#include "xla/python/transfer/streaming.h" +#include "xla/python/transfer/streaming_ifrt.h" +#include "xla/python/transfer/transfer_socket.pb.h" +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/platform/casts.h" + +namespace aux { + +namespace nb = nanobind; + +absl::StatusOr MemorySpaceFromSharding( + const xla::ifrt::Sharding& sharding) { + if (sharding.devices()->devices().size() != 1) { + return xla::InvalidArgument( + "Can only convert SingleDeviceSharding to MemorySpace not %s", + sharding.DebugString()); + } + auto* device = sharding.devices()->devices()[0]; + if (sharding.memory_kind().memory_kind().has_value()) { + // Find `PjRtMemorySpace` that is associated with the sharding's device + // and matches the sharding's memory_kind. 
+ xla::ifrt::Memory* memory = nullptr; + for (xla::ifrt::Memory* ms : device->Memories()) { + if (ms->Kind() == sharding.memory_kind()) { + memory = ms; + break; + } + } + if (memory == nullptr) { + return xla::InvalidArgument( + "Invalid memory kind: %s; available memory kinds: %s", + *sharding.memory_kind().memory_kind(), + absl::StrJoin(sharding.devices()->devices().front()->Memories(), ", ", + [](std::string* out, xla::ifrt::Memory* ms) { + absl::StrAppend(out, *ms->Kind().memory_kind()); + })); + } + return tensorflow::down_cast(memory)->pjrt_memory(); + } else { + if (!device->IsAddressable()) { + return xla::InvalidArgument( + "Cannot copy array to non-addressable device %s", + device->DebugString()); + } + return tensorflow::down_cast(device) + ->pjrt_device() + ->default_memory_space(); + } +} + +class IfrtArrayEntry : public PullTable::Entry { + public: + struct BufferRef { + tsl::RCReference arr; + xla::PjRtBuffer* buffer; + size_t buf_size; + }; + explicit IfrtArrayEntry(std::vector arrs, + std::shared_ptr state, + size_t xfer_size) + : arrs_(std::move(arrs)), state_(state), xfer_size_(xfer_size) {} + bool Handle(tsl::RCReference state, + const SocketTransferPullRequest& req, + size_t base_req_id) override { + for (uint64_t bid : req.buffer_ids()) { + auto req_id = base_req_id; + ++base_req_id; + for (size_t i = 0; i * xfer_size_ < arrs_[bid].buf_size; ++i) { + DmaCopyChunk blob; + blob.arr = std::move(arrs_[bid].arr); + blob.buffer = arrs_[bid].buffer; + blob.buffer_id = bid; + blob.offset = i * xfer_size_; + blob.size = std::min(xfer_size_, arrs_[bid].buf_size - blob.offset); + bool is_largest = blob.size + blob.offset == arrs_[bid].buf_size; + state_->ScheduleCopy( + blob, [req_id, state, copier_state = state_, is_largest]( + PremappedCopierState* copier_state_ptr, void* buf, + const DmaCopyChunk& chunk) { + state->Send( + req_id, buf, chunk.offset, chunk.size, is_largest, + [copier_state, buf]() { copier_state->ReturnBuffer(buf); }); + }); + } + } + + num_consumed_bufs_ += req.buffer_ids().size(); + return num_consumed_bufs_ == arrs_.size(); + } + + private: + absl::Mutex mu_; + size_t num_consumed_bufs_ = 0; + std::vector arrs_; + std::shared_ptr state_; + size_t xfer_size_; +}; + +absl::StatusOr> CreatePullEntry( + const std::vector>& arrs, + std::shared_ptr state, size_t xfer_size) { + std::vector refs; + for (auto& arr : arrs) { + auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, pjrt_buf->GetOnDeviceSizeInBytes()); + refs.push_back({arr, pjrt_buf.get(), buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); +} + +class PyTransferServerConnection { + public: + explicit PyTransferServerConnection( + tsl::RCReference conn) + : conn_(std::move(conn)) {} + + void Pull(uint64_t uuid, std::vector buffer_ids, + std::vector> pull_dests) { + for (size_t i = 0; i < buffer_ids.size(); ++i) { + conn_->Pull(uuid, buffer_ids[i], std::move(pull_dests[i])); + } + } + + private: + tsl::RCReference conn_; +}; + +class PyTransferServer { + public: + PyTransferServer() = default; + absl::Status Start(xla::ifrt::Client* client, size_t max_num_parallel_copies, + size_t xfer_size, const SocketAddress& addr, + const std::vector& transport_addresses) { + std::shared_ptr factory; + if (transport_addresses.empty()) { + factory = BulkTransportFactory::CreateLocal(); + } 
else { + auto tmp = xla::ValueOrThrow( + AllocateAlignedMemory(xfer_size * max_num_parallel_copies)); + SlabAllocator uallocator(xla::ValueOrThrow(MapPjrtMemory( + client, tmp->data(), tmp->size(), tmp)), + xfer_size); + factory = xla::ValueOrThrow(CreateSocketBulkTransportFactory( + transport_addresses, std::nullopt, uallocator)); + } + + server_ = std::make_shared(); + + TF_ASSIGN_OR_RETURN(auto mem, + AllocateAndMapPjrtMemory( + client, max_num_parallel_copies * xfer_size * 2)); + premapped_copier_ = std::make_shared( + mem, max_num_parallel_copies, xfer_size); + xfer_size_ = xfer_size; + return server_->Start(addr, factory); + } + std::string address() { return server_->addr().ToString(); } + + PyTransferServerConnection Connect(const std::string& saddr) { + return PyTransferServerConnection( + server_->Connect(xla::ValueOrThrow(SocketAddress::Parse(saddr)))); + } + + void AwaitPull(uint64_t uuid, + const std::vector>& arrs) { + server_->AwaitPull(uuid, xla::ValueOrThrow(CreatePullEntry( + arrs, premapped_copier_, xfer_size_))); + } + + size_t xfer_size() { return xfer_size_; } + + std::shared_ptr premapped_copier() { + return premapped_copier_; + } + + private: + std::shared_ptr server_; + std::shared_ptr premapped_copier_; + size_t xfer_size_; +}; + +absl::StatusOr ArraySpecFromShapeDtypeStruct( + nb::handle aval) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DType dtype, + xla::DtypeToIfRtDType( + nb::borrow(aval.attr("dtype").ptr()))); + auto shape_dims = nb::cast>(aval.attr("shape")); + auto shape = xla::ifrt::Shape( + xla::ifrt::Shape::Dimensions(shape_dims.begin(), shape_dims.end())); + TF_ASSIGN_OR_RETURN(auto sharding, + xla::GetIfrtHloSharding(aval.attr("sharding"), shape)); + return xla::ifrt::ArraySpec{dtype, std::move(shape), std::move(sharding)}; +} + +struct BufferSource { + tsl::RCReference arr; + xla::PjRtBuffer* buffer; +}; + +struct CopyDests { + std::vector shape_specs; + xla::PjRtMemorySpace* memory_space; +}; + +void RegisterTransferServerTypes(nanobind::module_& m) { + nb::class_(m, "TransferConnection") + .def("_pull_flat", [](PyTransferServerConnection& self, uint64_t uuid, + xla::nb_class_ptr py_client, + std::vector py_avals) { + auto* ifrt_client = llvm::dyn_cast_or_null( + py_client->ifrt_client()); + if (ifrt_client == nullptr) { + xla::ThrowIfError(absl::InvalidArgumentError( + "_pull_flat only supported on pjrt-ifrt clients.")); + } + + std::vector avals; + std::vector shardings; + shardings.reserve(py_avals.size()); + avals.reserve(py_avals.size()); + for (const auto& py_aval : py_avals) { + avals.push_back( + xla::ValueOrThrow(ArraySpecFromShapeDtypeStruct(py_aval))); + shardings.push_back(py_aval.attr("sharding")); + } + + std::vector dests; + std::vector> fetch_idxs; + absl::flat_hash_map mapping; + std::vector>> buffer_list; + + for (auto& aval : avals) { + std::vector> buf_list; + auto prim_type = + xla::ValueOrThrow(xla::ifrt::ToPrimitiveType(aval.dtype)); + auto shards = xla::ValueOrThrow(aval.sharding->Disassemble( + aval.shape, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + buf_list.reserve(shards.size()); + for (auto& shard : shards) { + auto* mem_space = + xla::ValueOrThrow(MemorySpaceFromSharding(*shard.second)); + int dest_idx = + mapping.emplace(mem_space, static_cast(dests.size())) + .first->second; + if (dest_idx == dests.size()) { + dests.emplace_back(); + dests.back().memory_space = mem_space; + } + fetch_idxs.push_back( + {dest_idx, + static_cast(dests[dest_idx].shape_specs.size())}); + buf_list.push_back(fetch_idxs.back()); + 
dests[dest_idx].shape_specs.push_back( + {prim_type, xla::DimensionVector(shard.first.dims().begin(), + shard.first.dims().end())}); + } + buffer_list.push_back(std::move(buf_list)); + } + + std::vector< + std::shared_ptr> + atms; + atms.reserve(dests.size()); + + for (auto& dest : dests) { + atms.push_back(xla::ValueOrThrow( + py_client->pjrt_client()->CreateBuffersForAsyncHostToDevice( + dest.shape_specs, std::nullopt, dest.memory_space))); + } + + std::vector> pull_dests; + std::vector buffer_ids; + pull_dests.reserve(fetch_idxs.size()); + buffer_ids.reserve(fetch_idxs.size()); + for (auto& fetch_idx : fetch_idxs) { + auto& atm = atms[fetch_idx.first]; + pull_dests.push_back(MakeDmaDestination( + atm, fetch_idx.second, atm->buffer_size(fetch_idx.second))); + buffer_ids.push_back(static_cast(buffer_ids.size())); + } + + self.Pull(uuid, buffer_ids, std::move(pull_dests)); + + std::vector out; + auto traceback = xla::Traceback::Get(); + for (size_t i = 0; i < buffer_list.size(); ++i) { + xla::ifrt::PjRtArray::PjRtBuffers buffers; + buffers.reserve(buffer_list[i].size()); + for (auto& v : buffer_list[i]) { + buffers.push_back(atms[v.first]->RetrieveBuffer(v.second)); + } + auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create( + ifrt_client, avals[i].dtype, avals[i].shape, avals[i].sharding, + std::move(buffers), avals[i].layout)); + out.push_back(xla::PyArray::MakeFromIfrtArrayAndSharding( + py_client, traceback, std::move(arr), shardings[i], false, true, + /*skip_checks=*/false)); + } + + return out; + }); + + nb::class_(m, "TransferServer") + .def("address", [](PyTransferServer& self) { return self.address(); }) + .def("_await_pull_flat", + [](PyTransferServer& self, uint64_t uuid, + std::vector inputs) { + std::vector> arrs; + arrs.reserve(inputs.size()); + for (const xla::PyArray& input : inputs) { + arrs.push_back(tsl::FormRef(input.ifrt_array())); + } + self.AwaitPull(uuid, arrs); + }) + .def("connect", [](PyTransferServer& self, const std::string& address) { + return self.Connect(address); + }); + + m.def( + "start_transfer_server", + [](xla::nb_class_ptr py_client, std::string address, + std::vector transport_addresses_str, + size_t max_num_parallel_copies, + size_t transfer_size) -> PyTransferServer { + PyTransferServer result; + std::vector transport_addresses; + transport_addresses.reserve(transport_addresses_str.size()); + for (const std::string& addr : transport_addresses_str) { + transport_addresses.push_back( + xla::ValueOrThrow(SocketAddress::Parse(addr))); + } + xla::ThrowIfError(result.Start( + py_client->ifrt_client(), max_num_parallel_copies, transfer_size, + xla::ValueOrThrow(SocketAddress::Parse(address)), + transport_addresses)); + return result; + }, + nb::arg("client"), nb::arg("address") = SocketAddress().ToString(), + nb::arg("transport_addresses") = std::vector(), + nb::arg("max_num_parallel_copies") = 8, + nb::arg("transfer_size") = 256 * 1024 * 1024); +} + +} // namespace aux diff --git a/jaxlib/xla/py_socket_transfer.h b/jaxlib/xla/py_socket_transfer.h new file mode 100644 index 000000000000..fa477f24e3e5 --- /dev/null +++ b/jaxlib/xla/py_socket_transfer.h @@ -0,0 +1,26 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ +#define JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ + +#include "nanobind/nanobind.h" + +namespace aux { + +void RegisterTransferServerTypes(nanobind::module_& m); + +} // namespace aux + +#endif // JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc new file mode 100644 index 000000000000..90dd77209694 --- /dev/null +++ b/jaxlib/xla/py_values.cc @@ -0,0 +1,759 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_values.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/complex.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/sharding.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/profiler/lib/traceme.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +absl::StatusOr> StringDTypeArrayToCords( + PyArrayObject* py_array_obj) { + if (PyArray_SIZE(py_array_obj) == 0) { + return absl::InvalidArgumentError("empty numpy array"); + } + + std::vector cords; + cords.reserve(PyArray_SIZE(py_array_obj)); + + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(py_array_obj))); + while (PyArray_ITER_NOTDONE(iter.ptr())) { + auto* iter_data = 
PyArray_ITER_DATA(iter.ptr()); + auto* item = PyArray_GETITEM(py_array_obj, static_cast(iter_data)); + if (!item) { + return absl::InternalError( + "Failed to get elements out of the ndarray iter."); + } + Py_ssize_t len; + auto str = PyUnicode_AsUTF8AndSize(item, &len); + cords.push_back(absl::Cord(absl::string_view(str, len))); + PyArray_ITER_NEXT(iter.ptr()); + } + return cords; +} + +using DevicePutFunc = std::function( + nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind)>; + +template +absl::StatusOr HandlePythonScalar( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + T value; + try { + value = nb::cast(obj); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + + std::variant data; + Shape shape; + PrimitiveType type; + if (std::is_same() || !options.squash_64bit_types) { + data.template emplace<0>(value); + type = primitive_util::NativeToPrimitiveType(); + } else { + // TODO(phawkins): we should check for overflow here, e.g., because of bugs + // like https://github.com/google/jax/issues/2006 + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + + return [client, data, type, to_device, to_memory_kind, + options]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); + // TODO(yashkatariya): Plumb sharding or memory_kind here. + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + ptr, ifrt_dtype, /*shape=*/ifrt::Shape({}), /*byte_strides=*/{}, + ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/{}, options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true); + }; +} + +absl::StatusOr HandlePythonInt( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + PrimitiveType type; + std::variant data; + + if (options.squash_64bit_types) { + try { + data.emplace<1>(nb::cast(obj)); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S32; + } else { + try { + data.emplace<0>(nb::cast(obj)); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S64; + } + return [client, data, type, to_device, to_memory_kind, + options]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); + // TODO(yashkatariya): Plumb sharding or memory_kind here. 
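// The scalar lives in the `data` variant captured by value in this lambda,
// so kImmutableOnlyDuringCall is sufficient: the client must finish reading
// the host buffer before MakeArrayFromHostBuffer returns, and no
// on_done_with_host_buffer callback is needed.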
+ TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}), + /*byte_strides=*/{}, + ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/nullptr, options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true); + }; +} + +template +absl::StatusOr HandleNumpyScalar( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + std::variant data; + PrimitiveType type; + // For extension types, ScalarAsCtype returns a pointer to the data. + if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F4E2M1FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E3M4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3B11FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E8M0FNU; + } else if (std::is_same() || !options.squash_64bit_types) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); + type = primitive_util::NativeToPrimitiveType(); + } else { + T value; + PyArray_ScalarAsCtype(h.ptr(), &value); + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + std::shared_ptr py_buffer_ref; + if (data.index() == 2) { + py_buffer_ref = + GlobalPyRefManager()->ManageReference(nb::cast(h)); + } + return [client, data, py_buffer_ref, type, to_device, options, + to_memory_kind]() mutable -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) -> const void* { + if constexpr (std::is_same_v, void*>) { + return v; + } else { + return static_cast(&v); + } + }, + data); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); + // TODO(yashkatariya): Plumb sharding or memory_kind here. 
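HandleNumpyScalar adds a third, pointer-valued alternative to its variant for the extension scalar types (int4, bfloat16, the float8 variants, and so on), so the visitor just above must return that pointer as-is rather than the address of the variant slot. A standard-library-only sketch of the same visitor shape (the concrete alternatives here are placeholders):

// Illustrative sketch: visit a variant whose alternatives are either in-place
// scalar storage or an already-materialized pointer; only the non-pointer
// alternatives should have their address taken.
#include <cstdint>
#include <iostream>
#include <type_traits>
#include <variant>

int main() {
  std::variant<std::int64_t, std::int32_t, void*> data;
  std::int32_t scalar = 7;
  data.emplace<2>(static_cast<void*>(&scalar));  // extension-scalar style path

  const void* ptr = std::visit(
      [](const auto& v) -> const void* {
        if constexpr (std::is_same_v<std::decay_t<decltype(v)>, void*>) {
          return v;  // already points at the scalar's storage
        } else {
          return static_cast<const void*>(&v);  // point at the variant slot
        }
      },
      data);
  std::cout << (ptr == &scalar) << "\n";  // prints 1
  return 0;
}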
+ TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}), + /*byte_strides=*/{}, + ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/ + [py_buffer_ref = std::move( + py_buffer_ref)]() { /* keeps py_buffer_ref alive */ }, + options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + }; +} + +absl::StatusOr HandleStringNumpyArray( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + xla::nb_numpy_ndarray array = nb::cast(h); + auto py_array_obj = reinterpret_cast(array.ptr()); + TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj)); + + // Assemble all the parameters of MakeArrayFromHostBuffer + void* data = cords.data(); + + // Make an explicit copy of the shape elements so we won't run into complex + // endianness and precision issues that might arise if we reinterpret-casted + // from npy_intp, that can be just 32 bits-wide in some environments + // such as macos_arm64 to const int64_t* that must be 64 bits-wide. + ifrt::Shape::Dimensions dims; + dims.reserve(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims.push_back(array.shape(i)); + } + ifrt::Shape shape(std::move(dims)); + + std::shared_ptr sharding = + xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind); + + auto on_done_with_host_buffer = [cords = std::move(cords)] {}; + + return [client, data = data, shape = std::move(shape), + sharding = std::move(sharding), + on_done_with_host_buffer = std::move(on_done_with_host_buffer), + options]() mutable -> absl::StatusOr { + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + data, ifrt::DType(ifrt::DType::kString), std::move(shape), + /*byte_strides=*/std::nullopt, std::move(sharding), + ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes, + std::move(on_done_with_host_buffer), options.ifrt_user_context)); + + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + }; +} + +absl::StatusOr HandleNumpyArray( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + xla::nb_numpy_ndarray array = nb::cast(h); + + // String numpy arrays require substantially different processing. 
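HandleStringNumpyArray above copies each dimension into int64_t-backed ifrt::Shape storage rather than reinterpreting the ndarray's own shape buffer, since npy_intp is not guaranteed to have the same width as int64_t on every platform. A tiny standalone sketch of that widening copy (the `long` array below is just a stand-in for an npy_intp shape):

// Illustrative sketch: widen platform-sized dimensions into int64_t storage
// element by element; reinterpreting the original buffer would read garbage
// whenever sizeof(narrow type) != sizeof(int64_t).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  long narrow_dims[] = {2, 3, 4};  // stand-in for an npy_intp shape
  std::vector<std::int64_t> dims;
  dims.reserve(3);
  for (long d : narrow_dims) {
    dims.push_back(static_cast<std::int64_t>(d));
  }
  for (std::int64_t d : dims) {
    std::printf("%lld ", static_cast<long long>(d));
  }
  std::printf("\n");
  return 0;
}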
+ if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') { + return HandleStringNumpyArray(h, client, to_device, options, + to_memory_kind); + } + + TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype())); + + PrimitiveType squashed_type; + if (options.squash_64bit_types) { + squashed_type = Squash64BitTypes(type); + if (squashed_type != type) { + TF_ASSIGN_OR_RETURN(xla::nb_dtype squashed_dtype, + PrimitiveTypeToNbDtype(squashed_type)); + array = nb::steal(PyArray_CastToType( + reinterpret_cast(array.ptr()), + reinterpret_cast(squashed_dtype.release().ptr()), + /*fortran=*/0)); + } + } else { + squashed_type = type; + } + + absl::InlinedVector dims(array.ndim()); + absl::InlinedVector byte_strides(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims[i] = array.shape(i); + byte_strides[i] = array.strides(i); + } + const void* data = array.data(); + std::shared_ptr py_buffer_ref = + GlobalPyRefManager()->ManageReference(std::move(array)); + return [client, data, squashed_type, dims = std::move(dims), + byte_strides = std::move(byte_strides), + py_buffer_ref = std::move(py_buffer_ref), options, to_device, + to_memory_kind]() mutable -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(squashed_type)); + + ifrt::Client::HostBufferSemantics host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall; + std::function on_done_with_host_buffer; + if (options.allow_zero_copy) { + on_done_with_host_buffer = + [py_buffer_ref{ + std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ }; + host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; + } + + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + data, ifrt_dtype, ifrt::Shape(dims), byte_strides, + xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + host_buffer_semantics, std::move(on_done_with_host_buffer), + options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + }; +} + +absl::StatusOr HandlePyArray( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + auto py_array = nb::borrow(obj); + + // We only allow single device case for PyArray in device put. + if (py_array.num_shards() != 1) { + return InvalidArgument( + "device_put expects an array with exactly one shard, got an array with " + "with %d shards.", + py_array.num_shards()); + } + + ifrt::Array* ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return InvalidArgument("Array has been deleted."); + } + + // Fallback to python for non-matching clients or pmap sharding. 
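HandlePyArray below reuses the incoming ifrt array without a copy only when it already lives on the target device, zero-copy is allowed, and the source and destination memory kinds are compatible; otherwise it goes through CopyArrays. A standard-library-only sketch of that predicate (FakeArray, the integer device ids, and the string memory kinds are simplified stand-ins for the ifrt types, not the real interfaces):

// Illustrative sketch: the reuse-vs-copy decision, with plain values standing
// in for ifrt::Device and ifrt::MemoryKind.
#include <iostream>
#include <optional>
#include <string>

struct FakeArray {
  int device_id;
  std::optional<std::string> memory_kind;
};

bool CanReuseWithoutCopy(const FakeArray& src, int to_device_id,
                         const std::optional<std::string>& to_memory_kind,
                         bool allow_zero_copy) {
  // Reuse only if the array is already on the target device, zero-copy is
  // permitted, and either side leaves the memory kind unspecified or both
  // agree on it.
  return src.device_id == to_device_id && allow_zero_copy &&
         (!to_memory_kind.has_value() || !src.memory_kind.has_value() ||
          src.memory_kind == to_memory_kind);
}

int main() {
  FakeArray a{/*device_id=*/0, /*memory_kind=*/std::nullopt};
  std::cout << CanReuseWithoutCopy(a, /*to_device_id=*/0, std::nullopt, true)  // 1
            << CanReuseWithoutCopy(a, /*to_device_id=*/1, std::nullopt, true)  // 0
            << "\n";
  return 0;
}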
+ if (py_array.sharding().type().ptr() == jax::PmapSharding::type().ptr() || + ifrt_array->sharding().devices()->devices().front()->client() != + to_device->client()) { + return HandleNumpyArray(obj.attr("_value"), client, to_device, options, + to_memory_kind); + } + + if (ifrt_array->sharding().devices()->devices().front() == to_device && + options.allow_zero_copy && + (!to_memory_kind.memory_kind().has_value() || + !ifrt_array->sharding().memory_kind().memory_kind().has_value() || + ifrt_array->sharding().memory_kind() == to_memory_kind)) { + DevicePutResult result(tsl::FormRef(ifrt_array), py_array.weak_type(), + /*owning_pybuffer=*/nb::borrow(obj)); + return [result = std::move(result)]() mutable { return std::move(result); }; + } else { + return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind, + owning_pybuffer = py_array.weak_type(), + allow_zero_copy = options.allow_zero_copy]() mutable + -> absl::StatusOr { + auto* ifrt_client = ifrt_array->client(); + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays( + absl::MakeSpan(&ifrt_array, 1), + ifrt_client->MakeDeviceList({to_device}), to_memory_kind, + allow_zero_copy ? ifrt::ArrayCopySemantics::kReuseInput + : ifrt::ArrayCopySemantics::kAlwaysCopy)); + return DevicePutResult(std::move(copied_ifrt_arrays[0]), + std::move(owning_pybuffer)); + }; + } +} + +} // namespace + +absl::StatusOr DevicePut(nb::handle arg, + ifrt::Client* client, + ifrt::Device* to_device, + const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind) { + tsl::profiler::TraceMe traceme("DevicePut"); + static const absl::flat_hash_map* const handlers = + [] { + auto p = new absl::flat_hash_map(); + const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); + // Python scalar types. + static_assert(sizeof(bool) == 1, + "Conversion code assumes bool is 1 byte"); + (*p)[reinterpret_cast(&PyBool_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyLong_Type)] = HandlePythonInt; + (*p)[reinterpret_cast(&PyFloat_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyComplex_Type)] = + HandlePythonScalar; + + (*p)[reinterpret_cast(&PyArray_Type)] = HandleNumpyArray; + + // Numpy scalar types. For some of them, we share the handler with + // Python types (np_int64, np_float64, np_complex128). 
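DevicePut looks up a handler keyed on the argument's exact Python type and, if that misses, walks the type's __mro__ so that subclasses of registered types still dispatch correctly. A simplified standard-library-only sketch of that two-step lookup (the string keys stand in for Python type objects and the handlers just return their own names):

// Illustrative sketch: exact-type dispatch with a base-class (MRO-style)
// fallback over a handler table.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  using Handler = std::function<std::string()>;
  std::unordered_map<std::string, Handler> handlers;
  handlers["int"] = [] { return std::string("HandlePythonInt"); };
  handlers["ndarray"] = [] { return std::string("HandleNumpyArray"); };

  // "myarray" is a hypothetical ndarray subclass; its MRO ends at "object".
  std::vector<std::string> mro = {"myarray", "ndarray", "object"};
  for (const std::string& type_name : mro) {
    auto it = handlers.find(type_name);
    if (it != handlers.end()) {
      std::cout << it->second() << "\n";  // prints HandleNumpyArray
      break;
    }
  }
  return 0;
}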
+ (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar; + if (dtypes.np_int2.has_value()) { + (*p)[dtypes.np_int2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar; + if (dtypes.np_uint2.has_value()) { + (*p)[dtypes.np_uint2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = + HandleNumpyScalar; + } + (*p)[dtypes.np_float8_e4m3fn.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = + HandleNumpyScalar; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = + HandleNumpyScalar; + } + (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex128.ptr()] = + HandleNumpyScalar; + static_assert(sizeof(long long) == sizeof(int64_t), // NOLINT + "long long must be the same size as int64_t"); + (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar; + static_assert(sizeof(int) == sizeof(int32_t), + "int must be the same size as int32_t"); + (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; + + return p; + }(); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + return HandlePyArray(arg, client, to_device, options, to_memory_kind); + } + + auto res = handlers->find(arg.type().ptr()); + if (res == handlers->end()) { + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers->find(base_class.ptr()); + if (res != handlers->end()) { + return res->second(arg, client, to_device, options, to_memory_kind); + } + } + return InvalidArgument( + "%s", absl::StrCat( + "Not supported: The C++ jax jit execution path, only accepts " + "DeviceArray, Numpy arrays scalars of supported types " + "(see implementation), or Python scalars. 
Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, client, to_device, options, to_memory_kind); +} + +bool IsFloat0(xla::nb_numpy_ndarray arg) { + static const auto* dtypes_module = + new nb::module_(nb::module_::import_("jax.dtypes")); + static const auto* float0_dtype = + new nb::handle(dtypes_module->attr("float0")); + return float0_dtype->is(arg.attr("dtype")); +} + +std::string PyArgSignature::DebugString() const { + std::string result = ""; + if (weak_type) { + absl::StrAppend(&result, "weak_"); + } + absl::StrAppend(&result, xla::PrimitiveType_Name(dtype)); + absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]"); + return result; +} + +using ToPyArgSignatureHandler = + std::function(nb::handle, bool)>; + +absl::StatusOr PyArgSignatureOfValue(nb::handle arg, + bool jax_enable_x64) { + static const absl::flat_hash_map* const + handlers = [] { + auto p = new absl::flat_hash_map(); + + const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); + + // The 4 Python native types. + ToPyArgSignatureHandler bool_handler = + [](nb::handle, bool) -> absl::StatusOr { + return PyArgSignature(PrimitiveType::PRED, {}, true); + }; + ToPyArgSignatureHandler int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // TODO(phawkins): we should consider checking for integer overflow. + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::S64, {}, true); + } else { + return PyArgSignature(PrimitiveType::S32, {}, true); + } + }; + ToPyArgSignatureHandler float_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Only Python native types has a True weak_type. + bool weak_type = !nb::isinstance(h, dtypes.np_float64); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::F64, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::F32, {}, weak_type); + } + }; + ToPyArgSignatureHandler complex_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Note that this branch is also taken for np.complex128: + // isinstance(np.complex128(3), complex) returns True + // isinstance(np.complex64(3), complex) returns False + bool weak_type = !nb::isinstance(h, dtypes.np_complex128); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::C128, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::C64, {}, weak_type); + } + }; + + (*p)[reinterpret_cast(&PyBool_Type)] = bool_handler; + (*p)[reinterpret_cast(&PyLong_Type)] = int_handler; + (*p)[reinterpret_cast(&PyFloat_Type)] = float_handler; + (*p)[reinterpret_cast(&PyComplex_Type)] = complex_handler; + + ToPyArgSignatureHandler numpy_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + xla::nb_numpy_ndarray numpy_array = + nb::cast(h); + TF_ASSIGN_OR_RETURN(PrimitiveType dtype, + DtypeToPrimitiveType(numpy_array.dtype())); + if (!jax_enable_x64) { + dtype = Squash64BitTypes(dtype); + } + // We use reinterpret_cast<> to defend against environments where + // ssize_t may not be precisely the same type as int64_t, even if it + // is the same size (long vs long long). 
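PyArgSignatureOfValue builds (dtype, shape, weak_type) signatures whose operator== is a member-wise std::tie comparison and which hash via the AbslHashValue overload declared in py_values.h, so they can serve as hash-map keys for call signatures. A standalone stand-in showing the same comparison idiom (SignatureLike is illustrative, not the real PyArgSignature):

// Illustrative sketch: a signature-like key compared member-wise via std::tie,
// the same idiom PyArgSignature::operator== uses.
#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

struct SignatureLike {
  int dtype;  // stand-in for xla::PrimitiveType
  std::vector<std::int64_t> shape;
  bool weak_type;

  bool operator==(const SignatureLike& other) const {
    return std::tie(dtype, weak_type, shape) ==
           std::tie(other.dtype, other.weak_type, other.shape);
  }
};

int main() {
  SignatureLike a{11, {2, 3}, false};
  SignatureLike b{11, {2, 3}, false};
  SignatureLike c{11, {2, 3}, true};  // differs only in weak_type
  std::cout << (a == b) << " " << (a == c) << "\n";  // prints: 1 0
  return 0;
}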
+ static_assert(sizeof(int64_t) == sizeof(ssize_t), + "Code assumes ssize_t is the same as int64_t"); + return PyArgSignature( + dtype, + absl::MakeConstSpan( + reinterpret_cast(numpy_array.shape()), + numpy_array.ndim()), + /*weak_type=*/false); + }; + (*p)[reinterpret_cast(&PyArray_Type)] = numpy_handler; + + ToPyArgSignatureHandler np_uint64_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false); + } else { + return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler np_int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false); + } else { + return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler numpy_array_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // This block deals with all numpy scalar types, except for int64_dt, + // float64_dt and complex128_dt which are taken care of in previous if + // blocks. + TF_ASSIGN_OR_RETURN(auto dtype, + DtypeToPrimitiveType(h.attr("dtype"))); + return PyArgSignature(dtype, {}, /*weak_type=*/false); + }; + + // This block deals with all numpy scalar types, except for int64_dt, + // float64_dt and complex128_dt which are taken care of in previous if + // blocks. + (*p)[dtypes.np_bool.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int4.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int64.ptr()] = np_int_handler; + (*p)[dtypes.np_uint4.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; + // TODO(upwind): Explore if we can remove std::optional for these types + // in xla/python/types.h and xla/python/types.cc + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = numpy_array_handler; + } + (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = numpy_array_handler; + } + (*p)[dtypes.np_float16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float64.ptr()] = float_handler; + (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler; + (*p)[dtypes.np_complex128.ptr()] = complex_handler; + (*p)[dtypes.np_longlong.ptr()] = np_int_handler; + (*p)[dtypes.np_intc.ptr()] = numpy_array_handler; + + return p; + }(); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + ifrt::Array* ifrt_array = array.ifrt_array(); + if (ifrt_array == nullptr) { + return 
xla::InvalidArgument("Array has been deleted."); + } + TF_ASSIGN_OR_RETURN(auto primitive_type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + return PyArgSignature(primitive_type, array.shape(), array.weak_type()); + } + + auto res = handlers->find(arg.type().ptr()); + if (res == handlers->end()) { + // We attempt to look at the MRO classes + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers->find(base_class.ptr()); + if (res != handlers->end()) { + return res->second(arg, jax_enable_x64); + } + } + return InvalidArgument( + "%s", + absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts " + "Buffer/DeviceArray, Numpy " + "arrays scalars of supported types " + "(see implementation), or Python scalars. Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, jax_enable_x64); +} + +} // namespace xla diff --git a/jaxlib/xla/py_values.h b/jaxlib/xla/py_values.h new file mode 100644 index 000000000000..b64895100d8c --- /dev/null +++ b/jaxlib/xla/py_values.h @@ -0,0 +1,127 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for converting Python values into buffers. + +#ifndef JAXLIB_XLA_PY_VALUES_H_ +#define JAXLIB_XLA_PY_VALUES_H_ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +struct DevicePutResult { + explicit DevicePutResult( + tsl::RCReference ifrt_array, bool weak_type, + nanobind::object owning_pybuffer = nanobind::object()) + : ifrt_array(std::move(ifrt_array)), + weak_type(weak_type), + owning_pybuffer(owning_pybuffer) {} + + // Disallow copy since copying `DevicePutResult` without holding GIL may be + // dangerous due to `owning_pybuffer`. + DevicePutResult(const DevicePutResult&) = delete; + DevicePutResult& operator=(const DevicePutResult&) = delete; + DevicePutResult(DevicePutResult&&) noexcept = default; + DevicePutResult& operator=(DevicePutResult&&) noexcept = default; + + // Points to the on-device array. Not owned. + tsl::RCReference ifrt_array; + bool weak_type; + + nanobind::object owning_pybuffer; +}; + +// Copies a buffer-like object to be on device. +// +// If `arg` is not convertible to a `PjRtBuffer` from C++, an error will be +// returned; float0s are not supported yet. +// If the value is known to be a PyBuffer object, py_buffer can be passed as +// an optimization to avoid a Python->C++ cast. +// +// This function performs Python work inline but postpones C++ work until the +// returned function is called. 
The returned function must be called after +// releasing GIL. Useful for batching GIL release when there are many device_put +// to execute. +// +// May throw exceptions from nanobind in addition to failing via an error +// absl::Status. (We could catch these if needed, but there seems little point.) +struct DevicePutOptions { + bool squash_64bit_types = false; + bool allow_zero_copy = true; + tsl::RCReference ifrt_user_context; +}; +using DevicePutResultFn = + absl::AnyInvocable() &&>; +absl::StatusOr DevicePut(nanobind::handle arg, + ifrt::Client* client, + ifrt::Device* to_device, + const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind); + +// Returns `true` if `arg` is a JAX float0 array. +bool IsFloat0(xla::nb_numpy_ndarray arg); + +// Describes the abstract shape and dtype of an argument. +struct PyArgSignature { + PyArgSignature(PrimitiveType dtype, absl::Span shape, + bool weak_type) + : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {} + // This is the XLA dtype of the object. + const PrimitiveType dtype; + const absl::InlinedVector shape; + // JAX arguments can be of weak type, if and only if they are Python scalars + // or `DeviceArray` values such that `aval.weak_type` is true. + const bool weak_type; + bool operator==(const PyArgSignature& other) const { + return std::tie(dtype, weak_type, shape) == + std::tie(other.dtype, other.weak_type, other.shape); + } + bool operator!=(const PyArgSignature& other) const { + return !(*this == other); + } + std::string DebugString() const; +}; + +// Returns the PyArgSignature associated with an argument. Returns an error if +// the argument is not supported. +absl::StatusOr PyArgSignatureOfValue(nanobind::handle arg, + bool jax_enable_x64); + +template +H AbslHashValue(H h, const xla::PyArgSignature& s) { + h = H::combine(std::move(h), s.dtype); + h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size()); + return h; +} +} // namespace xla + +#endif // JAXLIB_XLA_PY_VALUES_H_ diff --git a/jaxlib/xla/python_ref_manager.cc b/jaxlib/xla/python_ref_manager.cc new file mode 100644 index 000000000000..5b85d2ab84cb --- /dev/null +++ b/jaxlib/xla/python_ref_manager.cc @@ -0,0 +1,106 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/python_ref_manager.h" + +#include + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" + +namespace xla { + +namespace nb = nanobind; + +PythonRefManager::ManagedPyObjects::ManagedPyObjects( + PythonRefManager* manager, absl::Span objects) + : manager_(manager) { + objects_.reserve(objects.size()); + for (nb::object& object : objects) { + objects_.push_back(std::move(object)); + } +} + +PythonRefManager::ManagedPyObjects::~ManagedPyObjects() { + if (manager_ && !objects_.empty()) { + manager_->AddGarbage(absl::MakeSpan(objects_)); + } +} + +std::shared_ptr +PythonRefManager::ManageReference(nb::object object) { + return std::make_shared(this, + absl::Span(&object, 1)); +} + +std::shared_ptr +PythonRefManager::ManageReferences(absl::Span objects) { + return std::make_shared(this, objects); +} + +void PythonRefManager::AddGarbage(nb::object garbage) { + absl::MutexLock lock(&mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + python_garbage_.push_back(std::move(garbage)); +} + +void PythonRefManager::AddGarbage(absl::Span garbage) { + absl::MutexLock lock(&mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + for (nb::object& o : garbage) { + python_garbage_.push_back(std::move(o)); + } +} + +void PythonRefManager::AddGarbage( + absl::Span const> garbage) { + absl::MutexLock lock(&mu_); + // We don't care about collecting stack frame objects often. We grab a lot of + // tracebacks and the code objects are most likely live for the entire + // process. + garbage_count_.fetch_add(1, std::memory_order_relaxed); + for (const auto& o : garbage) { + python_garbage_.push_back(nb::steal(reinterpret_cast(o.first))); + } +} + +void PythonRefManager::CollectGarbage() { + // TODO(phawkins): we should CHECK(PyGILState_Check()); + std::deque garbage; + { + absl::MutexLock lock(&mu_); + garbage_count_ = 0; + garbage.swap(python_garbage_); + } + // We defer deleting garbage until the lock is released. It's possible that + // deleting garbage will lead to more Python garbage being added; if we held + // the lock we would deadlock because absl::Mutex is not reentrant. +} + +PythonRefManager* GlobalPyRefManager() { + static PythonRefManager* static_ref_manager = new PythonRefManager(); + return static_ref_manager; +} + +} // namespace xla diff --git a/jaxlib/xla/python_ref_manager.h b/jaxlib/xla/python_ref_manager.h new file mode 100644 index 000000000000..c0630da2ebd5 --- /dev/null +++ b/jaxlib/xla/python_ref_manager.h @@ -0,0 +1,108 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_PYTHON_REF_MANAGER_H_ +#define JAXLIB_XLA_PYTHON_REF_MANAGER_H_ + +#include + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" + +namespace xla { + +// Class that manages destruction of Python objects. +// +// We must not destroy Python objects without holding the GIL. However, we +// frequently want to hold references to Python objects for the duration of +// an asynchronous transfer on a Stream, and release our reference when the +// transfer completes. +// +// This class holds references to Python objects outside a GIL scope, that can +// be collected later when the GIL is held by calling CollectGarbage(). +class PythonRefManager { + public: + PythonRefManager() = default; + + // Holds references to a set of nanobind::objects, adding the references to + // the PythonRefManager on destruction. + class ManagedPyObjects { + public: + ManagedPyObjects() = default; + ManagedPyObjects(PythonRefManager* manager, + absl::Span objects); + + ~ManagedPyObjects(); + + ManagedPyObjects(const ManagedPyObjects& other) = delete; + ManagedPyObjects(ManagedPyObjects&& other) = default; + ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete; + ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default; + + private: + PythonRefManager* manager_ = nullptr; + absl::InlinedVector objects_; + }; + + // Creates a managed std::shared_ptr to an object. When the shared_ptr is + // destroyed, the reference to 'object' will be added to python_garbage_, + // and collected next time CollectGarbage() is called. + std::shared_ptr ManageReference(nanobind::object object); + std::shared_ptr ManageReferences( + absl::Span objects); + + // Adds garbage objects to the manager. + void AddGarbage(nanobind::object garbage); + void AddGarbage(absl::Span garbage); + void AddGarbage(absl::Span const> garbage); + + // Releases the contents of python_garbage_. Requires that the GIL is held. + // The client calls this method during API entry points where the GIL is held + // to free any garbage that has accumulated. + void CollectGarbage(); + + // Cheaper version of CollectGarbage() with relaxed consistency and frequency. + // The purpose of this function is to amortize lock acquisition costs over + // a larger number of API calls. + void MaybeCollectGarbage() { + if (garbage_count_.load(std::memory_order_relaxed) >= 100) { + CollectGarbage(); + } + } + + private: + absl::Mutex mu_; + std::deque python_garbage_ ABSL_GUARDED_BY(mu_); + + // Writes to garbage_count_ are protected by mu_, reads are not protected. + std::atomic garbage_count_{0}; +}; + +// A global PythonRefManager. Unless `CollectGarbage()` is called before +// shutdown, this container will hold on to Python objects and thus cause a +// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`. +PythonRefManager* GlobalPyRefManager(); + +} // namespace xla + +#endif // JAXLIB_XLA_PYTHON_REF_MANAGER_H_ diff --git a/jaxlib/xla/pytree.cc b/jaxlib/xla/pytree.cc new file mode 100644 index 000000000000..9359165b19dd --- /dev/null +++ b/jaxlib/xla/pytree.cc @@ -0,0 +1,1831 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Caution: this code uses exceptions. The exception use is local to the +// binding code and the idiomatic way to emit Python exceptions. + +#include "jaxlib/xla/pytree.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/pytree.pb.h" +#include "xla/pjrt/exceptions.h" +#include "xla/tsl/platform/logging.h" + +namespace xla { + +namespace nb = nanobind; + +constexpr int kSequenceKeyHashSalt = 1; +constexpr int kFlattenedIndexKeyHashSalt = 42; + +PyTreeRegistry::PyTreeRegistry(bool enable_none, bool enable_tuple, + bool enable_namedtuple, bool enable_list, + bool enable_dict) { + auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) { + nb::object type = + nb::borrow(reinterpret_cast(type_obj)); + auto registration = std::make_unique(); + registration->kind = kind; + registration->type = type; + CHECK(registrations_.emplace(type, std::move(registration)).second); + }; + if (enable_none) { + add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone); + } + if (enable_tuple) { + add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple); + } + enable_namedtuple_ = enable_namedtuple; + if (enable_list) { + add_builtin_type(&PyList_Type, PyTreeKind::kList); + } + if (enable_dict) { + add_builtin_type(&PyDict_Type, PyTreeKind::kDict); + } +} + +void PyTreeRegistry::Register( + nb::object type, nb::callable to_iterable, nb::callable from_iterable, + std::optional to_iterable_with_keys) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kCustom; + registration->type = type; + registration->to_iterable = std::move(to_iterable); + registration->from_iterable = std::move(from_iterable); + registration->to_iterable_with_keys = std::move(to_iterable_with_keys); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument( + absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", + nb::cast(nb::repr(type)))); + } +} + +void PyTreeRegistry::RegisterDataclass(nb::object type, + std::vector data_fields, + std::vector meta_fields) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kDataclass; + 
registration->type = type; + registration->data_fields = std::move(data_fields); + registration->meta_fields = std::move(meta_fields); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument(absl::StrFormat( + "Duplicate custom dataclass PyTreeDef type registration for %s.", + nb::cast(nb::repr(std::move(type))))); + } +} + +std::pair +PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { + nb::object out = to_iterable(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable leaves; + if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple where 'children' is iterable, " + "got ", + nb::cast(nb::repr(out)))); + } + return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); +} + +std::pair>, nb::object> +PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { + // Backwards compatibility case: return dummy FlattenedIndexKey for each leaf. + std::vector> result; + if (!to_iterable_with_keys.has_value()) { + auto [leaves, aux_data] = ToIterable(o); + for (nb::handle leaf : leaves) { + result.push_back(std::make_pair( + make_nb_class(result.size()), nb::borrow(leaf))); + } + return std::make_pair(std::move(result), std::move(aux_data)); + } + nb::object out = to_iterable_with_keys.value()(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree " + "node should return a (key_leaf_pairs, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable key_leaf_pairs; + if (!nb::try_cast(leaves_and_aux_data[0], key_leaf_pairs)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'key_leaf_pairs' is " + "iterable, got ", + nb::cast(nb::repr(leaves_and_aux_data)))); + } + for (nb::handle key_leaf_pair : key_leaf_pairs) { + nb::tuple key_leaf_pair_tuple; + if (!nb::try_cast(key_leaf_pair, key_leaf_pair_tuple) || + key_leaf_pair_tuple.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'child", + nb::cast(nb::repr(key_leaf_pair)))); + } + result.push_back(std::make_pair(nb::borrow(key_leaf_pair_tuple[0]), + nb::borrow(key_leaf_pair_tuple[1]))); + } + return std::make_pair(std::move(result), nb::object(leaves_and_aux_data[1])); +} + +int PyTreeRegistry::Registration::tp_traverse(visitproc visit, void* arg) { + Py_VISIT(type.ptr()); + Py_VISIT(to_iterable.ptr()); + Py_VISIT(from_iterable.ptr()); + for (const auto& field : data_fields) { + Py_VISIT(field.ptr()); + } + for (const auto& field : meta_fields) { + Py_VISIT(field.ptr()); + } + return 0; +} + +// Computes the node kind of a given Python object. 
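KindOfObject below resolves a node's kind by checking the registry first, then falling back to the namedtuple heuristic (an instance of tuple that exposes _fields), and treating everything else as a leaf. A condensed standalone sketch of that precedence (the booleans and the optional registration stand in for the Python-level checks; the enum names are illustrative):

// Illustrative sketch: the classification precedence used when flattening.
#include <iostream>
#include <optional>
#include <string>

enum class Kind { kRegistered, kNamedTuple, kLeaf };

Kind KindOf(const std::optional<std::string>& registration, bool is_tuple,
            bool has_fields_attr) {
  if (registration.has_value()) return Kind::kRegistered;     // registry hit
  if (is_tuple && has_fields_attr) return Kind::kNamedTuple;  // heuristic
  return Kind::kLeaf;                                         // everything else
}

int main() {
  std::cout << static_cast<int>(KindOf(std::nullopt, true, true))    // 1 (namedtuple)
            << static_cast<int>(KindOf(std::nullopt, false, false))  // 2 (leaf)
            << "\n";
  return 0;
}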
+PyTreeKind PyTreeRegistry::KindOfObject( + nb::handle obj, PyTreeRegistry::Registration const** custom) const { + const PyTreeRegistry::Registration* registration = Lookup(obj.type()); + if (registration) { + if (registration->kind == PyTreeKind::kCustom || + registration->kind == PyTreeKind::kDataclass) { + *custom = registration; + } else { + *custom = nullptr; + } + return registration->kind; + } else if (nb::isinstance(obj) && nb::hasattr(obj, "_fields")) { + // We can only identify namedtuples heuristically, here by the presence of + // a _fields attribute. + return PyTreeKind::kNamedTuple; + } else { + return PyTreeKind::kLeaf; + } +} + +/*static*/ const PyTreeRegistry::Registration* PyTreeRegistry::Lookup( + nb::handle type) const { + nb::ft_lock_guard lock(mu_); + auto it = registrations_.find(type); + return it == registrations_.end() ? nullptr : it->second.get(); +} + +/*static*/ std::vector GetSortedPyDictKeys(PyObject* py_dict) { + std::vector keys; + keys.reserve(PyDict_Size(py_dict)); + PyObject* key; + Py_ssize_t pos = 0; + while (PyDict_Next(py_dict, &pos, &key, /*value=*/nullptr)) { + keys.push_back(nb::borrow(key)); + } + + try { + std::stable_sort( + keys.begin(), keys.end(), [](const nb::object& a, const nb::object& b) { + int cmp = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_LT); + if (cmp == -1) { + throw nb::python_error(); + } + return cmp; + }); + } catch (nb::python_error& e) { + nb::raise_from(e, PyExc_ValueError, + "Comparator raised exception while sorting pytree " + "dictionary keys."); + } + return keys; +} + +/*static*/ bool IsSortedPyDictKeysEqual(absl::Span lhs, + absl::Span rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (int i = 0; i < lhs.size(); ++i) { + if (lhs[i].not_equal(rhs[i])) { + return false; + } + } + return true; +} + +bool PyTreeDef::operator==(const PyTreeDef& other) const { + if (traversal_.size() != other.traversal_.size()) { + return false; + } + for (size_t i = 0; i < traversal_.size(); ++i) { + const Node& a = traversal_[i]; + const Node& b = other.traversal_[i]; + if (a.kind != b.kind || a.arity != b.arity || + (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) || + (a.sorted_dict_keys.size() != b.sorted_dict_keys.size()) || + a.custom != b.custom) { + return false; + } + if (a.node_data && a.node_data.not_equal(b.node_data)) { + return false; + } + if (!IsSortedPyDictKeysEqual(a.sorted_dict_keys, b.sorted_dict_keys)) { + return false; + } + // We don't need to test equality of num_leaves and num_nodes since they + // are derivable from the other node data. 
+ } + return true; +} + +nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/false); +} + +nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/true); +} + +nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, + bool with_keys) const { + PyTreeRegistry::Registration const* custom; + PyTreeKind kind = KindOfObject(x, &custom); + switch (kind) { + case PyTreeKind::kNone: + return nb::make_tuple(nb::make_tuple(), nb::none()); + case PyTreeKind::kTuple: { + if (with_keys) { + auto size = PyTuple_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyTuple_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kList: { + if (with_keys) { + auto size = PyList_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyList_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(x); + std::vector sorted_keys = GetSortedPyDictKeys(dict.ptr()); + nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); + nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); + for (size_t i = 0; i < sorted_keys.size(); ++i) { + nb::object& key = sorted_keys[i]; + nb::object value = nb::object(dict[key]); + if (with_keys) { + value = nb::make_tuple(make_nb_class(key), value); + } + PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr()); + PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); + } + return nb::make_tuple(std::move(values), std::move(keys)); + } + case PyTreeKind::kNamedTuple: { + nb::tuple in = nb::borrow(x); + nb::list out; + if (with_keys) { + // Get key names from NamedTuple fields. 
+ nb::tuple fields; + if (!nb::try_cast(nb::getattr(in, "_fields"), fields) || + in.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : in) { + out.append(nb::make_tuple( + make_nb_class(nb::str(*field_iter)), entry)); + } + return nb::make_tuple(std::move(out), x.type()); + } + for (size_t i = 0; i < in.size(); ++i) { + out.append(in[i]); + } + return nb::make_tuple(std::move(out), x.type()); + } + case PyTreeKind::kCustom: { + if (with_keys) { + auto [leaves, aux_data] = custom->ToIterableWithKeys(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + auto [leaves, aux_data] = custom->ToIterable(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + case PyTreeKind::kDataclass: { + auto data_size = custom->data_fields.size(); + nb::list leaves = nb::steal(PyList_New(data_size)); + for (int leaf = 0; leaf < data_size; ++leaf) { + nb::object value = nb::getattr(x, custom->data_fields[leaf]); + if (with_keys) { + value = nb::make_tuple( + make_nb_class(custom->data_fields[leaf]), value); + } + PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr()); + } + auto meta_size = custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(x, custom->meta_fields[meta_leaf]).release().ptr()); + } + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + default: + DCHECK(kind == PyTreeKind::kLeaf); + return nb::none(); + } +} + +/* static */ PyType_Slot PyTreeRegistry::slots_[] = { + {Py_tp_traverse, (void*)PyTreeRegistry::tp_traverse}, + {Py_tp_clear, (void*)PyTreeRegistry::tp_clear}, + {0, nullptr}, +}; + +/* static */ int PyTreeRegistry::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + for (const auto& [key, value] : registry->registrations_) { + Py_VISIT(key.ptr()); + int rval = value->tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + } + return 0; +} + +/* static */ int PyTreeRegistry::tp_clear(PyObject* self) { + PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + registry->registrations_.clear(); + return 0; +} + +/* static */ PyType_Slot DictKey::slots_[] = { + {Py_tp_traverse, (void*)DictKey::tp_traverse}, + {Py_tp_clear, (void*)DictKey::tp_clear}, + {0, nullptr}, +}; + +/* static */ int DictKey::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + DictKey* key = nb::inst_ptr(self); + Py_VISIT(key->key_.ptr()); + return 0; +} + +/* static */ int DictKey::tp_clear(PyObject* self) { + DictKey* dictkey = nb::inst_ptr(self); + nb::object tmp; + std::swap(tmp, dictkey->key_); + return 0; +} + +std::string SequenceKey::ToString() const { + return absl::StrFormat("[%d]", idx_); +} + +std::string SequenceKey::ToReprString() const { + return absl::StrFormat("SequenceKey(idx=%d)", idx_); +} + +std::string DictKey::ToString() const { + return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); +} + +std::string DictKey::ToReprString() const { + return absl::StrFormat("DictKey(key=%s)", + nb::cast(nb::repr(key_))); +} + +std::string GetAttrKey::ToString() const { + return absl::StrFormat(".%s", 
nb::cast(name_)); +} + +std::string GetAttrKey::ToReprString() const { + return absl::StrFormat("GetAttrKey(name='%s')", + nb::cast(name_)); +} + +std::string FlattenedIndexKey::ToString() const { + return absl::StrFormat("[]", key_); +} + +std::string FlattenedIndexKey::ToReprString() const { + return absl::StrFormat("FlattenedIndexKey(key=%d)", key_); +} + +bool SequenceKey::Equals(const nb::object& other) { + SequenceKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return idx_ == other_key.idx(); +} + +bool DictKey::Equals(const nb::object& other) { + DictKey other_key(nb::none()); + if (!nb::try_cast(other, other_key)) return false; + return key_.equal(other_key.key()); +} + +bool GetAttrKey::Equals(const nb::object& other) { + GetAttrKey other_key(nb::str("")); + if (!nb::try_cast(other, other_key)) return false; + return name_.equal(other_key.name()); +} + +bool FlattenedIndexKey::Equals(const nb::object& other) { + FlattenedIndexKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return key_ == other_key.key(); +} + +nanobind::tuple SequenceKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("idx"); +}; + +nanobind::tuple DictKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +nanobind::tuple GetAttrKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("name"); +}; + +nanobind::tuple FlattenedIndexKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +template +void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, + const std::optional& leaf_predicate, + std::optional>& keypath) { + Node node; + const int start_num_nodes = traversal_.size(); + const int start_num_leaves = leaves.size(); + bool is_known_leaf = false; + if (leaf_predicate) { + nb::object o = (*leaf_predicate)(handle); + // Historically we accepted "truthy" values from leaf predicates. Accept + // None here to keep existing clients happy. + if (o.is_none()) { + is_known_leaf = false; + } else if (!nb::try_cast(o, is_known_leaf)) { + throw std::invalid_argument(absl::StrCat( + "is_leaf predicate returned a non-boolean value ", + nb::cast(nb::repr(o)), "; expected a boolean")); + } + } + if (is_known_leaf) { + nb::object value = nb::borrow(handle); + if (keypath.has_value()) { + const std::vector& frozen_keypath = keypath.value(); + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } else { + node.kind = registry_->KindOfObject(handle, &node.custom); + auto recurse = [this, &leaf_predicate, &leaves]( + nb::handle child, + std::optional>& keypath) { + if (Py_EnterRecursiveCall( + " in flatten; PyTree may have cyclical node references.")) { + return; + } + FlattenImpl(child, leaves, leaf_predicate, keypath); + Py_LeaveRecursiveCall(); + }; + switch (node.kind) { + case PyTreeKind::kNone: + // Nothing to do. 
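Every FlattenImpl call, including the arity-0 kNone case here, pushes exactly one traversal node after visiting its children and records how many nodes and leaves its subtree produced. A standalone sketch of that post-order bookkeeping over a plain recursive structure (Tree and Node below are simplified stand-ins for the pytree types):

// Illustrative sketch: post-order flattening where each node records the size
// of its subtree in nodes and in leaves.
#include <iostream>
#include <vector>

struct Node {
  int arity = 0;
  int num_nodes = 0;
  int num_leaves = 0;
};

struct Tree {
  std::vector<Tree> children;  // empty => leaf
};

void Flatten(const Tree& t, std::vector<int>& leaves,
             std::vector<Node>& traversal) {
  Node node;
  const int start_nodes = static_cast<int>(traversal.size());
  const int start_leaves = static_cast<int>(leaves.size());
  if (t.children.empty()) {
    leaves.push_back(1);  // record the leaf payload
  } else {
    node.arity = static_cast<int>(t.children.size());
    for (const Tree& child : t.children) Flatten(child, leaves, traversal);
  }
  node.num_nodes = static_cast<int>(traversal.size()) - start_nodes + 1;
  node.num_leaves = static_cast<int>(leaves.size()) - start_leaves;
  traversal.push_back(node);  // children first, then this node (post-order)
}

int main() {
  Tree leaf;
  Tree root;
  root.children = {leaf, Tree{{leaf, leaf}}};
  std::vector<int> leaves;
  std::vector<Node> traversal;
  Flatten(root, leaves, traversal);
  std::cout << traversal.back().num_nodes << " nodes, "
            << traversal.back().num_leaves << " leaves\n";  // 5 nodes, 3 leaves
  return 0;
}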
+ break; + case PyTreeKind::kTuple: { + node.arity = PyTuple_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyTuple_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kList: { + node.arity = PyList_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyList_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(handle); + + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + for (nb::object& key : keys) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(key)); + } + recurse(dict[key], keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + node.arity = dict.size(); + node.sorted_dict_keys = std::move(keys); + break; + } + case PyTreeKind::kCustom: { + if (keypath.has_value()) { + auto [leaves, aux_data] = node.custom->ToIterableWithKeys(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (auto& [key, leaf] : leaves) { + keypath->push_back(key); + ++node.arity; + recurse(leaf, keypath); + keypath->pop_back(); + } + } else { + auto [leaves, aux_data] = node.custom->ToIterable(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (nb::handle entry : leaves) { + ++node.arity; + recurse(entry, keypath); + } + } + break; + } + case PyTreeKind::kDataclass: { + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(handle, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + node.node_data = std::move(aux_data); + auto data_size = node.custom->data_fields.size(); + node.arity = data_size; + for (int leaf = 0; leaf < data_size; ++leaf) { + if (keypath.has_value()) { + keypath->push_back( + make_nb_class(node.custom->data_fields[leaf])); + } + recurse(nb::getattr(handle, node.custom->data_fields[leaf]), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kNamedTuple: { + nb::tuple tuple = nb::borrow(handle); + node.arity = tuple.size(); + node.node_data = nb::borrow(tuple.type()); + if (keypath.has_value()) { + // Get key names from NamedTuple fields. 
+ nb::tuple fields; + if (!nb::try_cast(nb::getattr(tuple, "_fields"), fields) || + tuple.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : tuple) { + keypath->push_back(make_nb_class(nb::str(*field_iter))); + field_iter++; + recurse(entry, keypath); + keypath->pop_back(); + } + } else { + for (nb::handle entry : tuple) { + recurse(entry, keypath); + } + } + break; + } + default: + DCHECK(node.kind == PyTreeKind::kLeaf); + auto value = nb::borrow(handle); + if (keypath.has_value()) { + const std::vector& frozen_keypath = keypath.value(); + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } + } + node.num_nodes = traversal_.size() - start_num_nodes + 1; + node.num_leaves = leaves.size() - start_num_leaves; + traversal_.push_back(std::move(node)); +} + +void PyTreeDef::Flatten(nb::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +void PyTreeDef::Flatten(nb::handle handle, std::vector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +void PyTreeDef::Flatten(nb::handle handle, nb::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +/*static*/ std::pair, nb_class_ptr> +PyTreeDef::Flatten(nb::handle x, nb_class_ptr registry, + std::optional leaf_predicate) { + auto def = make_nb_class(registry); + std::vector leaves; + def->Flatten(x, leaves, leaf_predicate); + return std::make_pair(std::move(leaves), std::move(def)); +} + +void PyTreeDef::FlattenWithPath(nb::handle handle, nanobind::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::vector(); + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +/*static*/ bool PyTreeDef::AllLeaves(PyTreeRegistry* registry, + const nb::iterable& x) { + const PyTreeRegistry::Registration* custom; + for (const nb::handle& h : x) { + if (registry->KindOfObject(h, &custom) != PyTreeKind::kLeaf) return false; + } + return true; +} + +template +nb::object PyTreeDef::UnflattenImpl(T leaves) const { + absl::InlinedVector agenda; + auto it = leaves.begin(); + int leaf_count = 0; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for TreeDef node."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + if (it == leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(), + leaf_count)); + } + agenda.push_back(nb::borrow(*it)); + ++it; + ++leaf_count; + break; + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + const int size = agenda.size(); + absl::Span span; + if (node.arity > 0) { + span = absl::Span(&agenda[size - 
node.arity], node.arity); + } + nb::object o = MakeNode(node, span); + agenda.resize(size - node.arity); + agenda.push_back(o); + break; + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too many leaves for PyTreeDef; expected %d.", num_leaves())); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::Unflatten(nb::iterable leaves) const { + return UnflattenImpl(leaves); +} + +nb::object PyTreeDef::Unflatten(absl::Span leaves) const { + return UnflattenImpl(leaves); +} + +/*static*/ nb::object PyTreeDef::MakeNode(const PyTreeDef::Node& node, + absl::Span children) { + if (children.size() != node.arity) { + throw std::logic_error("Node arity mismatch."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + throw std::logic_error("MakeNode not implemented for leaves."); + + case PyTreeKind::kNone: + return nb::none(); + + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + if (node.kind == PyTreeKind::kNamedTuple) { + return node.node_data(*tuple); + } else { + return tuple; + } + } + + case PyTreeKind::kList: { + nb::object list = nb::steal(PyList_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyList_SET_ITEM(list.ptr(), i, children[i].release().ptr()); + } + return list; + } + + case PyTreeKind::kDict: { + nb::dict dict; + for (int i = 0; i < node.arity; ++i) { + dict[node.sorted_dict_keys[i]] = std::move(children[i]); + } + return std::move(dict); + break; + } + case PyTreeKind::kCustom: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + return node.custom->from_iterable(node.node_data, tuple); + } + + case PyTreeKind::kDataclass: { + nb::kwargs kwargs; + auto meta_size = node.custom->meta_fields.size(); + for (int i = 0; i < meta_size; ++i) { + kwargs[node.custom->meta_fields[i]] = + nb::borrow(nb::tuple(node.node_data)[i]); + } + auto data_size = node.custom->data_fields.size(); + for (int i = 0; i < data_size; ++i) { + kwargs[node.custom->data_fields[i]] = std::move(children[i]); + } + return node.custom->type(**kwargs); + } + } + throw std::logic_error("Unreachable code."); +} + +nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { + nb::list leaves = nb::steal(PyList_New(num_leaves())); + std::vector agenda; + agenda.push_back(nb::borrow(xs)); + auto it = traversal_.rbegin(); + int leaf = num_leaves() - 1; + while (!agenda.empty()) { + if (it == traversal_.rend()) { + throw std::invalid_argument(absl::StrFormat( + "Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + const Node& node = *it; + nb::object object = agenda.back(); + agenda.pop_back(); + ++it; + + switch (node.kind) { + case PyTreeKind::kLeaf: + if (leaf < 0) { + throw std::logic_error("Leaf count mismatch."); + } + PyList_SET_ITEM(leaves.ptr(), leaf, object.release().ptr()); + --leaf; + break; + + case PyTreeKind::kNone: + if (!object.is_none()) { + throw std::invalid_argument(absl::StrFormat( + "Expected None, got %s.\n\n" + "In previous releases of JAX, flatten-up-to used to " + "consider None to be a tree-prefix of non-None values. 
To obtain " + "the previous behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object)))); + } + break; + + case PyTreeKind::kTuple: { + if (!PyTuple_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kList: { + if (!PyList_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected list, got %s.", + nb::cast(nb::repr(object)))); + } + nb::list list = nb::borrow(object); + if (list.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "List arity mismatch: %d != %d; list: %s.", list.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : list) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kDict: { + if (!PyDict_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected dict, got %s.", + nb::cast(nb::repr(object)))); + } + nb::dict dict = nb::borrow(object); + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + if (!IsSortedPyDictKeysEqual(keys, node.sorted_dict_keys)) { + // Convert to a nb::list for nb::repr to avoid having to stringify a + // vector. This is error path so it is fine to pay conversion cost. + throw std::invalid_argument( + absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", + nb::cast( + nb::repr(nb::cast(node.sorted_dict_keys))), + nb::cast(nb::repr(object)))); + } + for (nb::handle key : keys) { + agenda.push_back(dict[key]); + } + break; + } + + case PyTreeKind::kNamedTuple: { + if (!nb::isinstance(object) || + !nb::hasattr(object, "_fields")) { + throw std::invalid_argument( + absl::StrFormat("Expected named tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + if (tuple.type().not_equal(node.node_data)) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple type mismatch: expected type: %s, tuple: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kCustom: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom node type mismatch: expected type: %s, value: %s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(object)))); + } + auto [leaves, aux_data] = node.custom->ToIterable(object); + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + int arity = 0; + for (nb::handle entry : leaves) { + ++arity; + agenda.push_back(nb::borrow(entry)); + } + if (arity != node.arity) { + throw 
std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", arity, + node.arity, nb::cast(nb::repr(object)))); + } + break; + } + + case PyTreeKind::kDataclass: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom dataclasss node type mismatch: expected type: %s, value: " + "%s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(std::move(object))))); + } + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(object, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom dataclass node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + auto data_size = node.custom->data_fields.size(); + if (data_size != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", data_size, + node.arity, nb::cast(nb::repr(object)))); + } + for (int leaf = 0; leaf < data_size; ++leaf) { + agenda.push_back(nb::borrow( + nb::getattr(object, node.custom->data_fields[leaf]))); + } + break; + } + } + } + if (it != traversal_.rend() || leaf != -1) { + throw std::invalid_argument( + absl::StrFormat("Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + return leaves; +} + +nb::object PyTreeDef::Walk(const nb::callable& f_node, nb::handle f_leaf, + nb::iterable leaves) const { + std::vector agenda; + auto it = leaves.begin(); + for (const Node& node : traversal_) { + switch (node.kind) { + case PyTreeKind::kLeaf: { + if (it == leaves.end()) { + throw std::invalid_argument("Too few leaves for PyTreeDef"); + } + + nb::object leaf = nb::borrow(*it); + agenda.push_back(f_leaf.is_none() ? std::move(leaf) + : f_leaf(std::move(leaf))); + ++it; + break; + } + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for custom type."); + } + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = node.arity - 1; i >= 0; --i) { + PyTuple_SET_ITEM(tuple.ptr(), i, agenda.back().release().ptr()); + agenda.pop_back(); + } + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for f_node invocation. + node_data = nb::cast(node.sorted_dict_keys); + } + agenda.push_back(f_node(tuple, node_data ? 
node_data : nb::none())); + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument("Too many leaves for PyTreeDef"); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::FromIterableTreeHelper( + nb::handle xs, + absl::InlinedVector::const_reverse_iterator* it) const { + if (*it == traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + const Node& node = **it; + ++*it; + if (node.kind == PyTreeKind::kLeaf) { + return nb::borrow(xs); + } + nb::iterable iterable = nb::borrow(xs); + std::vector ys; + ys.reserve(node.arity); + for (nb::handle x : iterable) { + ys.push_back(nb::borrow(x)); + } + if (ys.size() != node.arity) { + throw std::invalid_argument("Arity mismatch between trees"); + } + for (int j = node.arity - 1; j >= 0; --j) { + ys[j] = FromIterableTreeHelper(ys[j], it); + } + + return MakeNode(node, absl::MakeSpan(ys)); +} + +nb::object PyTreeDef::FromIterableTree(nb::handle xs) const { + auto it = traversal_.rbegin(); + nb::object out = FromIterableTreeHelper(xs, &it); + if (it != traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + return out; +} + +nb_class_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { + if (inner.registry_ != registry_) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Compose() must match."); + } + auto out = make_nb_class(registry_ref_); + out->traversal_.reserve(static_cast(num_leaves()) * + inner.num_nodes() + + num_nodes() - num_leaves()); + for (const Node& n : traversal_) { + if (n.kind == PyTreeKind::kLeaf) { + absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_)); + } else { + out->traversal_.push_back(n); + } + } + out->SetNumLeavesAndNumNodes(); + return out; +} + +/*static*/ nb_class_ptr PyTreeDef::Tuple( + nb_class_ptr registry, nb::list defs) { + auto out = make_nb_class(std::move(registry)); + int num_leaves = 0; + for (nb::handle def_handle : defs) { + const PyTreeDef* def = nb::cast(def_handle); + if (def->registry() != out->registry()) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Tuple() must match."); + } + absl::c_copy(def->traversal_, std::back_inserter(out->traversal_)); + num_leaves += def->num_leaves(); + } + Node node; + node.kind = PyTreeKind::kTuple; + node.arity = defs.size(); + node.num_leaves = num_leaves; + node.num_nodes = out->traversal_.size() + 1; + out->traversal_.push_back(node); + return out; +} + +std::vector> PyTreeDef::Children() const { + std::vector> children; + if (traversal_.empty()) { + return children; + } + Node const& root = traversal_.back(); + children.resize(root.arity); + int pos = traversal_.size() - 1; + for (int i = root.arity - 1; i >= 0; --i) { + children[i] = make_nb_class(registry_ref_); + const Node& node = traversal_.at(pos - 1); + if (pos < node.num_nodes) { + throw std::logic_error("children() walked off start of array"); + } + std::copy(traversal_.begin() + pos - node.num_nodes, + traversal_.begin() + pos, + std::back_inserter(children[i]->traversal_)); + pos -= node.num_nodes; + } + if (pos != 0) { + throw std::logic_error("pos != 0 at end of PyTreeDef::Children"); + } + return children; +} + +std::string PyTreeDef::ToString() const { + std::vector agenda; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for container."); + } + + 
std::string children = + absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", "); + std::string representation; + switch (node.kind) { + case PyTreeKind::kLeaf: + agenda.push_back("*"); + continue; + case PyTreeKind::kNone: + representation = "None"; + break; + case PyTreeKind::kTuple: + // Tuples with only one element must have a trailing comma. + if (node.arity == 1) children += ","; + representation = absl::StrCat("(", children, ")"); + break; + case PyTreeKind::kList: + representation = absl::StrCat("[", children, "]"); + break; + case PyTreeKind::kDict: { + if (node.sorted_dict_keys.size() != node.arity) { + throw std::logic_error("Number of keys and entries does not match."); + } + representation = "{"; + std::string separator; + auto child_iter = agenda.end() - node.arity; + for (const nb::handle& key : node.sorted_dict_keys) { + absl::StrAppendFormat(&representation, "%s%s: %s", separator, + nb::cast(nb::repr(key)), + *child_iter); + child_iter++; + separator = ", "; + } + representation += "}"; + break; + } + + case PyTreeKind::kNamedTuple: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + std::string kind; + std::string data; + if (node.kind == PyTreeKind::kNamedTuple) { + kind = "namedtuple"; + if (node.node_data) { + // Node data for named tuples is the type. + data = absl::StrFormat( + "[%s]", nb::cast( + nb::str(nb::getattr(node.node_data, "__name__")))); + } + } else { + kind = nb::cast( + nb::str(nb::getattr(node.custom->type, "__name__"))); + if (node.node_data) { + data = absl::StrFormat( + "[%s]", nb::cast(nb::str(node.node_data))); + } + } + + representation = + absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children); + break; + } + } + agenda.erase(agenda.end() - node.arity, agenda.end()); + agenda.push_back(std::move(representation)); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return absl::StrCat("PyTreeDef(", agenda.back(), ")"); +} + +nb::object PyTreeDef::ToPickle() const { + nb::list traversal; + for (const auto& node : traversal_) { + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for pickling to avoid having to pickle a vector. + // Pickle should be a rare operation so this conversion cost is hopefully + // on non-critical path. + node_data = nb::cast(node.sorted_dict_keys); + } + traversal.append( + nb::make_tuple(static_cast(node.kind), node.arity, + node_data ? node_data : nb::none(), + node.custom != nullptr ? node.custom->type : nb::none(), + node.num_leaves, node.num_nodes)); + } + return nb::make_tuple(nb::cast(registry_ref_), traversal); +} + +void PyTreeDef::FromPickle(nb::object pickle) { + for (const auto& item : nb::cast(pickle)) { + auto t = nb::cast(item); + if (t.size() != 6) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + Node& node = traversal_.emplace_back(); + node.kind = static_cast(nb::cast(t[0])); + node.arity = nb::cast(t[1]); + switch (node.kind) { + case PyTreeKind::kNamedTuple: + node.node_data = t[2]; + break; + case PyTreeKind::kDict: + node.sorted_dict_keys = nb::cast>(t[2]); + break; + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + node.node_data = t[2]; + break; + default: + if (!t[2].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + break; + } + if (node.kind == PyTreeKind::kCustom || + node.kind == PyTreeKind::kDataclass) { + node.custom = t[3].is_none() ? 
nullptr : registry()->Lookup(t[3]); + if (node.custom == nullptr) { + throw xla::XlaRuntimeError( + absl::StrCat("Unknown custom type in pickled PyTreeDef: ", + nb::cast(nb::repr(t[3])))); + } + } else { + if (!t[3].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + } + node.num_leaves = nb::cast(t[4]); + node.num_nodes = nb::cast(t[5]); + } +} + +void PyTreeDef::SetNumLeavesAndNumNodes() { + // num_leaves and num_nodes are fully determined by arity. + std::vector> starts; + int num_leaves = 0; + for (int i = 0; i < traversal_.size(); ++i) { + std::pair start = {num_leaves, i}; + if (traversal_[i].kind == PyTreeKind::kLeaf) { + num_leaves += 1; + } + if (traversal_[i].arity == 0) { + starts.push_back(start); + } else { + starts.resize(starts.size() - (traversal_[i].arity - 1)); + } + traversal_[i].num_leaves = num_leaves - starts.back().first; + traversal_[i].num_nodes = i + 1 - starts.back().second; + } +} + +void PyTreeDef::SerializeTo(jax::PyTreeDefProto& result) const { + absl::flat_hash_map interned_strings; + auto intern_str = [&](const std::string& key) { + auto [it, added] = + interned_strings.emplace(key, result.interned_strings_size()); + if (added) { + result.add_interned_strings(key); + } + return it->second; + }; + for (const auto& node : traversal_) { + auto* node_data = result.add_nodes(); + node_data->set_arity(node.arity); + switch (node.kind) { + case PyTreeKind::kLeaf: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LEAF); + break; + case PyTreeKind::kList: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LIST); + break; + case PyTreeKind::kNone: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_NONE); + break; + case PyTreeKind::kTuple: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_TUPLE); + break; + case PyTreeKind::kDict: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_DICT); + for (auto& key : node.sorted_dict_keys) { + if (!nb::isinstance(key)) { + throw std::invalid_argument( + "Only string keys are supported in proto pytree " + "serialization."); + } + node_data->mutable_dict_keys()->add_str_id( + intern_str(nb::cast(key))); + } + break; + default: + throw std::invalid_argument( + "User-defined nodes are not supported when serializing pytrees as " + "protocol buffers. 
You should either convert the user-defined " + "nodes to another type or use pickle instead."); + break; + } + } +} + +nb_class_ptr PyTreeDef::DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto& input) { + std::vector interned_strings; + interned_strings.reserve(input.interned_strings().size()); + for (auto& s : input.interned_strings()) { + interned_strings.push_back(nb::cast(s)); + } + nb_class_ptr result = + make_nb_class(std::move(registry)); + for (auto& node_proto : input.nodes()) { + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); + node.arity = node_proto.arity(); + node.custom = nullptr; + switch (node_proto.type()) { + case jax::PyTreeNodeType::PY_TREE_KIND_LEAF: + node.kind = PyTreeKind::kLeaf; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_LIST: + node.kind = PyTreeKind::kList; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_NONE: + node.kind = PyTreeKind::kNone; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_TUPLE: + node.kind = PyTreeKind::kTuple; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_DICT: + node.kind = PyTreeKind::kDict; + for (uint32_t str_id : node_proto.dict_keys().str_id()) { + if (str_id >= interned_strings.size()) { + throw std::invalid_argument( + "Malformed pytree proto (dict_key out of range)."); + } + node.sorted_dict_keys.push_back(interned_strings.at(str_id)); + } + break; + default: + throw std::invalid_argument( + "Malformed pytree proto (invalid node type)"); + break; + } + } + result->SetNumLeavesAndNumNodes(); + return result; +} + +std::optional> PyTreeDef::GetNodeData() + const { + if (traversal_.empty()) { + throw std::logic_error("empty PyTreeDef traversal."); + } + auto builtin_type = [](PyTypeObject* type_obj) { + return nb::borrow(reinterpret_cast(type_obj)); + }; + const auto& node = traversal_.back(); + switch (node.kind) { + case PyTreeKind::kLeaf: + return std::nullopt; + case PyTreeKind::kNone: + return std::make_pair(builtin_type(Py_TYPE(Py_None)), nb::none()); + case PyTreeKind::kTuple: + return std::make_pair(builtin_type(&PyTuple_Type), nb::none()); + case PyTreeKind::kList: + return std::make_pair(builtin_type(&PyList_Type), nb::none()); + case PyTreeKind::kDict: + return std::make_pair(builtin_type(&PyDict_Type), + nb::cast(node.sorted_dict_keys)); + case PyTreeKind::kNamedTuple: + return std::make_pair(node.node_data, nb::none()); + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + return std::make_pair(node.custom->type, node.node_data); + } +} + +nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nb::iterable children) { + nb_class_ptr result = + make_nb_class(std::move(registry)); + int num_leaves = 0; + int arity = 0; + for (nb::handle pchild : children) { + const PyTreeDef& child = nb::cast(pchild); + absl::c_copy(child.traversal_, std::back_inserter(result->traversal_)); + num_leaves += child.num_leaves(); + ++arity; + } + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); + node.arity = arity; + node.custom = nullptr; + node.num_leaves = num_leaves; + node.num_nodes = result->traversal_.size(); + if (node_data == std::nullopt) { + node.kind = PyTreeKind::kLeaf; + ++node.num_leaves; + return result; + } + int is_nt = PyObject_IsSubclass(node_data->first.ptr(), + reinterpret_cast(&PyTuple_Type)); + if (is_nt == -1) { + throw nb::python_error(); + } + if (is_nt != 0 && nb::hasattr(node_data->first, "_fields")) { + node.kind = PyTreeKind::kNamedTuple; + node.node_data = 
node_data->first; + return result; + } + auto* registration = result->registry()->Lookup(node_data->first); + if (registration == nullptr) { + throw std::logic_error(absl::StrFormat( + "Could not find type: %s.", + nb::cast(nb::repr(node_data->first)))); + } + node.kind = registration->kind; + if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { + node.custom = registration; + node.node_data = node_data->second; + } else if (node.kind == PyTreeKind::kNamedTuple) { + node.node_data = node_data->first; + } else if (node.kind == PyTreeKind::kDict) { + node.sorted_dict_keys = + nb::cast>(node_data->second); + } + return result; +} + +int PyTreeDef::Node::tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(node_data.ptr()); + for (const auto& key : sorted_dict_keys) { + Py_VISIT(key.ptr()); + } + return 0; +} + +/* static */ int PyTreeDef::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeDef* treedef = nb::inst_ptr(self); + Py_VISIT(treedef->registry_ref_.ptr()); + for (const auto& node : treedef->traversal_) { + node.tp_traverse(visit, arg); + } + return 0; +} + +/* static */ int PyTreeDef::tp_clear(PyObject* self) { + PyTreeDef* treedef = nb::inst_ptr(self); + treedef->registry_ref_.reset(); + treedef->traversal_.clear(); + return 0; +} + +/* static */ PyType_Slot PyTreeDef::slots_[] = { + {Py_tp_traverse, (void*)PyTreeDef::tp_traverse}, + {Py_tp_clear, (void*)PyTreeDef::tp_clear}, + {0, nullptr}, +}; + +void BuildPytreeSubmodule(nb::module_& m) { + nb::module_ pytree = m.def_submodule("pytree", "Python tree library"); + pytree.attr("version") = nb::int_(3); + + nb::class_ treedef(pytree, "PyTreeDef", + nb::type_slots(PyTreeDef::slots_)); + + nb::class_ registry(m, "PyTreeRegistry", nb::dynamic_attr(), + nb::type_slots(PyTreeRegistry::slots_)); + + registry.def(nb::init(), + nb::arg("enable_none") = true, nb::arg("enable_tuple") = true, + nb::arg("enable_namedtuple") = true, + nb::arg("enable_list") = true, nb::arg("enable_dict") = true); + registry.def( + "flatten", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->Flatten(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, + nb::arg("tree").none()); + registry.def("flatten_one_level_with_keys", + &PyTreeRegistry::FlattenOneLevelWithKeys, + nb::arg("tree").none()); + registry.def( + "flatten_with_path", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->FlattenWithPath(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("register_node", &PyTreeRegistry::Register, + nb::arg("type").none(), nb::arg("to_iterable").none(), + nb::arg("from_iterable").none(), + nb::arg("to_iterable_with_keys").none() = std::nullopt); + registry.def("register_dataclass_node", &PyTreeRegistry::RegisterDataclass); + registry.def("__reduce__", + [](nb::object self) { return self.attr("__name__"); }); + + pytree.attr("_default_registry") = make_nb_class( + /*enable_none=*/true, /*enable_tuple=*/true, 
/*enable_namedtuple=*/true, + /*enable_list=*/true, /*enable_dict*/ true); + pytree.def("default_registry", + [registry = nb::cast>( + pytree.attr("_default_registry"))]() { return registry; }); + + pytree.attr("PyTreeRegistry") = m.attr("PyTreeRegistry"); + pytree.def("tuple", &PyTreeDef::Tuple); + pytree.def("all_leaves", &PyTreeDef::AllLeaves); + + treedef.def("unflatten", + static_cast( + &PyTreeDef::Unflatten)); + treedef.def("flatten_up_to", &PyTreeDef::FlattenUpTo, nb::arg("tree").none()); + treedef.def("compose", &PyTreeDef::Compose); + treedef.def( + "walk", &PyTreeDef::Walk, + "Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf " + "at leaves", + nb::arg("f_node"), nb::arg("f_leaf"), nb::arg("leaves")); + treedef.def("from_iterable_tree", &PyTreeDef::FromIterableTree); + treedef.def("children", &PyTreeDef::Children); + treedef.def_prop_ro("num_leaves", &PyTreeDef::num_leaves); + treedef.def_prop_ro("num_nodes", &PyTreeDef::num_nodes); + treedef.def("__repr__", &PyTreeDef::ToString); + treedef.def("__eq__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; }); + treedef.def("__ne__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; }); + treedef.def("__hash__", [](const PyTreeDef& t) { return absl::HashOf(t); }); + treedef.def("serialize_using_proto", [](const PyTreeDef& a) { + jax::PyTreeDefProto result; + a.SerializeTo(result); + std::string serialized = result.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }); + treedef.def_static( + "deserialize_using_proto", + [](nb_class_ptr registry, nb::bytes data) { + jax::PyTreeDefProto input; + absl::string_view serialized(data.c_str(), data.size()); + if (serialized.size() > std::numeric_limits::max()) { + throw xla::XlaRuntimeError( + "Pytree serialization too large to deserialize."); + } + if (!input.ParseFromArray(serialized.data(), serialized.size())) { + throw xla::XlaRuntimeError("Could not deserialize PyTreeDefProto."); + } + return PyTreeDef::DeserializeFrom(std::move(registry), input); + }, + nb::arg("registry"), nb::arg("data")); + treedef.def("node_data", &PyTreeDef::GetNodeData, + "Returns None if a leaf-pytree, else (type, node_data)"); + treedef.def_static( + "make_from_node_data_and_children", + &PyTreeDef::MakeFromNodeDataAndChildren, nb::arg("registry"), + nb::arg("node_data").none(), nb::arg("children"), + "Reconstructs a pytree from `node_data()` and `children()`."); + treedef.def("__getstate__", &PyTreeDef::ToPickle); + treedef.def("__setstate__", [](PyTreeDef& t, nb::object o) { + nb::tuple pickle = nb::cast(o); + if (pickle.size() != 2) { + throw xla::XlaRuntimeError( + "Malformed pickled PyTreeDef, expected 2-tuple"); + } + auto registry = nb::cast>(pickle[0]); + new (&t) PyTreeDef(registry); + t.FromPickle(pickle[1]); + }); + + nb::class_ sequence_key(pytree, "SequenceKey"); + sequence_key.def(nb::init(), nb::arg("idx")); + sequence_key.def("__str__", &SequenceKey::ToString); + sequence_key.def("__repr__", &SequenceKey::ToReprString); + sequence_key.def("__eq__", &SequenceKey::Equals); + sequence_key.def("__hash__", [](const SequenceKey& key) { + return key.idx() + kSequenceKeyHashSalt; + }); + sequence_key.def_prop_ro("idx", &SequenceKey::idx); + sequence_key.def_prop_ro_static("__match_args__", &SequenceKey::MatchArgs); + sequence_key.def("__getstate__", + [](SequenceKey& key) { return nb::make_tuple(key.idx()); }); + sequence_key.def("__setstate__", + [](SequenceKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw 
xla::XlaRuntimeError( + "Malformed pickled SequenceKey, expected 1-tuple"); + } + new (&key) SequenceKey(nb::cast(state[0])); + }); + + nb::class_ dict_key(pytree, "DictKey", + nb::type_slots(DictKey::slots_)); + dict_key.def(nb::init(), nb::arg("key")); + dict_key.def("__str__", &DictKey::ToString); + dict_key.def("__repr__", &DictKey::ToReprString); + dict_key.def("__eq__", &DictKey::Equals); + dict_key.def("__hash__", + [](const DictKey& key) { return nanobind::hash(key.key()); }); + dict_key.def_prop_ro("key", &DictKey::key); + dict_key.def_prop_ro_static("__match_args__", &DictKey::MatchArgs); + dict_key.def("__getstate__", + [](DictKey& key) { return nb::make_tuple(key.key()); }); + dict_key.def("__setstate__", [](DictKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError("Malformed pickled DictKey, expected 1-tuple"); + } + new (&key) DictKey(nb::cast(state[0])); + }); + + nb::class_ get_attr_key(pytree, "GetAttrKey"); + get_attr_key.def(nb::init(), nb::arg("name")); + get_attr_key.def("__str__", &GetAttrKey::ToString); + get_attr_key.def("__repr__", &GetAttrKey::ToReprString); + get_attr_key.def("__eq__", &GetAttrKey::Equals); + get_attr_key.def("__hash__", + [](const GetAttrKey& key) { return nb::hash(key.name()); }); + get_attr_key.def_prop_ro("name", &GetAttrKey::name); + get_attr_key.def_prop_ro_static("__match_args__", &GetAttrKey::MatchArgs); + get_attr_key.def("__getstate__", + [](GetAttrKey& key) { return nb::make_tuple(key.name()); }); + get_attr_key.def("__setstate__", [](GetAttrKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled GetAttrKey, expected 1-tuple"); + } + new (&key) GetAttrKey(nb::str(state[0])); + }); + + nb::class_ flattened_index_key(pytree, + "FlattenedIndexKey"); + flattened_index_key.def(nb::init(), nb::arg("key")); + flattened_index_key.def("__str__", &FlattenedIndexKey::ToString); + flattened_index_key.def("__repr__", &FlattenedIndexKey::ToReprString); + flattened_index_key.def("__eq__", &FlattenedIndexKey::Equals); + flattened_index_key.def("__hash__", [](const FlattenedIndexKey& key) { + return key.key() + kFlattenedIndexKeyHashSalt; + }); + flattened_index_key.def_prop_ro("key", &FlattenedIndexKey::key); + flattened_index_key.def_prop_ro_static("__match_args__", + &FlattenedIndexKey::MatchArgs); + flattened_index_key.def("__getstate__", [](FlattenedIndexKey& key) { + return nb::make_tuple(key.key()); + }); + flattened_index_key.def( + "__setstate__", [](FlattenedIndexKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled FlattenedIndexKey, expected 1-tuple"); + } + new (&key) FlattenedIndexKey(nb::cast(state[0])); + }); +} + +} // namespace xla diff --git a/jaxlib/xla/pytree.h b/jaxlib/xla/pytree.h new file mode 100644 index 000000000000..c0cf284c6dbd --- /dev/null +++ b/jaxlib/xla/pytree.h @@ -0,0 +1,408 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_PYTREE_H_ +#define JAXLIB_XLA_PYTREE_H_ + +// See https://docs.jax.dev/en/latest/pytrees.html for the documentation +// about pytree. + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/pytree.pb.h" + +namespace xla { + +enum class PyTreeKind { + kLeaf, // An opaque leaf node + kNone, // None. + kTuple, // A tuple + kNamedTuple, // A collections.namedtuple + kList, // A list + kDict, // A dict + kCustom, // A custom type. + kDataclass, // A dataclass. +}; + +// Registry of custom node types. +class PyTreeRegistry { + public: + PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, + bool enable_list, bool enable_dict); + + PyTreeRegistry(const PyTreeRegistry&) = delete; + PyTreeRegistry(PyTreeRegistry&&) = delete; + PyTreeRegistry& operator=(const PyTreeRegistry&) = delete; + PyTreeRegistry& operator=(PyTreeRegistry&&) = delete; + + struct Registration { + PyTreeKind kind; + + // The following values are populated for custom types. + // The Python type object, used to identify the type. + nanobind::object type; + // A function with signature: object -> (iterable, aux_data) + nanobind::callable to_iterable; + // A function with signature: (aux_data, iterable) -> object + nanobind::callable from_iterable; + // A function with signature: (aux_data, iterable(keypath, leaf)) -> object + std::optional to_iterable_with_keys; + + // Helper that calls to_iterable and validates that it returns a pair + // of an iterable and an aux_data object + std::pair ToIterable( + nanobind::handle o) const; + // Helper that calls to_iterable_with_keys and validates that it returns a + // pair of an iterable of key-leaf pairs and an aux_data object. If + // to_iterable_with_keys is not available, return a dummy key for each leaf, + // similar to the current jax.tree_util.FlattenedIndexKey. + std::pair>, + nanobind::object> + ToIterableWithKeys(nanobind::handle o) const; + + // For dataclasses. + std::vector data_fields; + std::vector meta_fields; + + int tp_traverse(visitproc visit, void* arg); + }; + + // Registers a new custom type. Objects of `type` will be treated as container + // node types in PyTrees. + void Register( + nanobind::object type, nanobind::callable to_iterable, + nanobind::callable from_iterable, + std::optional to_iterable_with_keys = std::nullopt); + // Same, but for dataclasses. + void RegisterDataclass(nanobind::object type, + std::vector data_fields, + std::vector meta_fields); + + // Finds the custom type registration for `type`. Returns nullptr if none + // exists. + const Registration* Lookup(nanobind::handle type) const; + + PyTreeKind KindOfObject(nanobind::handle obj, + PyTreeRegistry::Registration const** custom) const; + + // Flattens a pytree one level, returning either a tuple of the leaves and + // the node data, or None, if the entry is a leaf. + nanobind::object FlattenOneLevel(nanobind::handle x) const; + // Similar to above but returns a key-leaf pair for each leaf. + nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; + // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. 
+ nanobind::object FlattenOneLevelImpl(nanobind::handle x, + bool with_keys) const; + + static PyType_Slot slots_[]; + + private: + struct TypeHash { + using is_transparent = void; + size_t operator()(const nanobind::object& t) const { + return absl::HashOf(t.ptr()); + } + size_t operator()(const nanobind::handle& t) const { + return absl::HashOf(t.ptr()); + } + }; + struct TypeEq { + using is_transparent = void; + bool operator()(const nanobind::object& a, + const nanobind::object& b) const { + return a.ptr() == b.ptr(); + } + bool operator()(const nanobind::object& a, + const nanobind::handle& b) const { + return a.ptr() == b.ptr(); + } + }; + mutable nanobind::ft_mutex mu_; + absl::flat_hash_map, TypeHash, + TypeEq> + registrations_; // Guarded by mu_ + bool enable_namedtuple_; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class SequenceKey { + public: + explicit SequenceKey(int idx) : idx_(idx) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int idx() const { return idx_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int idx_; +}; + +class DictKey { + public: + explicit DictKey(nanobind::object key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::object key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + static PyType_Slot slots_[]; + + private: + nanobind::object key_; + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class GetAttrKey { + public: + explicit GetAttrKey(nanobind::str name) : name_(name) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::str name() const { return name_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + nanobind::str name_; +}; + +class FlattenedIndexKey { + public: + explicit FlattenedIndexKey(int key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int key_; +}; + +// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of +// Python values, where the interior nodes are tuples, lists, dictionaries, or +// user-defined containers, and the leaves are other objects. +class PyTreeDef { + public: + // Unowned registry: the registry must remain live at least as long as the + // PyTreeDef. It is the caller's responsibility to enforce this. + explicit PyTreeDef(PyTreeRegistry* registry) : registry_(registry) {} + + explicit PyTreeDef(nb_class_ptr registry) + : registry_(registry.get()), registry_ref_(std::move(registry)) {} + + // Flattens a Pytree into a list of leaves and a PyTreeDef. + // Returns references to the flattened objects, which might be temporary + // objects in the case of custom pytype handlers. + static std::pair, nb_class_ptr> + Flatten(nanobind::handle x, nb_class_ptr registry, + std::optional leaf_predicate = std::nullopt); + + // Flattens a Pytree into a list of `leaves` and a PyTreeDef (this). + // `leaves` owns references to the flattened objects, which might be + // temporary objects in the case of custom pytype handlers. 
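+  // For example, fully flattening [(1, (2, 3)), {"foo": 4}] produces the
+  // leaves [1, 2, 3, 4]; dict entries are visited in sorted-key order.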
+ void Flatten(nanobind::handle handle, std::vector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + void FlattenWithPath( + nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + // Tests whether the given list is a flat list of leaves. + static bool AllLeaves(PyTreeRegistry* registry, const nanobind::iterable& x); + + // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of + // the tree-structure of 'x'. For example, if we flatten a value + // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the + // list of leaves [1, (2, 3), {"foo": 4}]. + nanobind::list FlattenUpTo(nanobind::handle x) const; + + // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. + nanobind::object Unflatten(nanobind::iterable leaves) const; + nanobind::object Unflatten(absl::Span leaves) const; + + // Composes two PyTreeDefs, replacing the leaves of this tree with copies of + // `inner`. The returned PyTreeDef holds a reference to its registry. + nb_class_ptr Compose(const PyTreeDef& inner) const; + + // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. + static nb_class_ptr Tuple(nb_class_ptr registry, + nanobind::list defs); + + // The returned PyTreeDefs hold a reference to the registry. + std::vector> Children() const; + + // Maps a function over a PyTree structure, applying f_leaf to each leaf, and + // f_node(node, node_data) to each container node. + nanobind::object Walk(const nanobind::callable& f_node, + nanobind::handle f_leaf, + nanobind::iterable leaves) const; + + // Given a tree of iterables with the same node/leaf structure as this PyTree, + // build the corresponding PyTree. + // TODO(phawkins): use flattening everywhere instead and delete this method. + nanobind::object FromIterableTree(nanobind::handle xs) const; + + int num_leaves() const { + if (traversal_.empty()) { + return 0; + } + return traversal_.back().num_leaves; + } + + int num_nodes() const { return traversal_.size(); } + + PyTreeRegistry* registry() const { return registry_; } + + size_t Hash() const; + + bool operator==(const PyTreeDef& other) const; + bool operator!=(const PyTreeDef& other) const { return !(*this == other); } + + std::string ToString() const; + + // Transforms the PyTreeDef into a pickleable object. Used to implement + // `PyTreeDef.__getstate__`. + nanobind::object ToPickle() const; + + // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used + // to implement `PyTreeDef.__setstate__`. + void FromPickle(nanobind::object pickleable); + + void SerializeTo(jax::PyTreeDefProto& result) const; + + static nb_class_ptr DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto& input); + + std::optional> GetNodeData() + const; + + static nb_class_ptr MakeFromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nanobind::iterable children); + + static PyType_Slot slots_[]; + + private: + void SetNumLeavesAndNumNodes(); + + struct Node { + PyTreeKind kind = PyTreeKind::kLeaf; + + // Arity for non-kLeaf types. + int arity = 0; + + // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type + // object. For a kDict, use `sorted_dict_keys` field below. 
For a kCustom + // type, contains the auxiliary data returned by the `to_iterable` function. + nanobind::object node_data; + + // Kind-specific auxiliary data specialized for kDict. Use a c++ vector + // to hold the sorted dict keys instead of a py::list to avoid creating + // a new python list object when flattening kDict. For deeply nested dict, + // using c++ vector instead of py::list avoids creating too many python + // objects that make python gc sweep slow. + std::vector sorted_dict_keys; + + // Custom type registration. Must be null for non-custom types. + const PyTreeRegistry::Registration* custom = nullptr; + + // Number of leaf nodes in the subtree rooted at this node. + int num_leaves = 0; + + // Number of leaf and interior nodes in the subtree rooted at this node. + int num_nodes = 0; + + int tp_traverse(visitproc visit, void* arg) const; + }; + template + friend H AbslHashValue(H h, const Node& n); + + template + friend H AbslHashValue(H h, const PyTreeDef& t); + + // Helper that manufactures an instance of a node given its children. + static nanobind::object MakeNode(const Node& node, + absl::Span children); + + // Recursive helper used to implement FromIterableTree() + nanobind::object FromIterableTreeHelper( + nanobind::handle xs, + absl::InlinedVector::const_reverse_iterator* it) + const; + + template + void FlattenImpl(nanobind::handle handle, T& leaves, + const std::optional& leaf_predicate, + std::optional>& keypath); + + template + nanobind::object UnflattenImpl(T leaves) const; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + + // Pytree registry. Not owned. + PyTreeRegistry* registry_; + // If this class holds a reference to `registry`, it is held by + // `registry_ref_`. + nb_class_ptr registry_ref_; + + // Nodes, in a post-order traversal. We use an ordered traversal to minimize + // allocations, and post-order corresponds to the order we need to rebuild the + // tree structure. + absl::InlinedVector traversal_; +}; + +template +H AbslHashValue(H h, const PyTreeDef::Node& n) { + h = H::combine(std::move(h), n.kind, n.arity, n.custom); + return h; +} + +template +H AbslHashValue(H h, const PyTreeDef& t) { + h = H::combine(std::move(h), t.traversal_); + return h; +} + +void BuildPytreeSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_PYTREE_H_ diff --git a/jaxlib/xla/pytree.proto b/jaxlib/xla/pytree.proto new file mode 100644 index 000000000000..73c087ef55ab --- /dev/null +++ b/jaxlib/xla/pytree.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package jax; + +enum PyTreeNodeType { + PY_TREE_KIND_INVALID = 0; + PY_TREE_KIND_LEAF = 1; + PY_TREE_KIND_LIST = 2; + PY_TREE_KIND_NONE = 3; + PY_TREE_KIND_TUPLE = 4; + PY_TREE_KIND_DICT = 5; +} + +message DictKeysProto { + repeated uint32 str_id = 1; +} + +message PyTreeNodeDefProto { + // Recovers the tree structure. + uint32 arity = 1; + // Node type. + PyTreeNodeType type = 2; + // Only set when type == DICT. + DictKeysProto dict_keys = 3; +} + +// A Pytree. +message PyTreeDefProto { + repeated PyTreeNodeDefProto nodes = 1; + // Extra strings. 
+ repeated string interned_strings = 2; +} diff --git a/jaxlib/xla/pytree_test.py b/jaxlib/xla/pytree_test.py new file mode 100644 index 000000000000..b5ac7dd5b4d2 --- /dev/null +++ b/jaxlib/xla/pytree_test.py @@ -0,0 +1,144 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import collections +import dataclasses +import gc + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + +pytree = xla_client._xla.pytree + + +ExampleType = collections.namedtuple("ExampleType", "field0 field1") + +registry = pytree.PyTreeRegistry() + + +class ExampleType2: + + def __init__(self, field0, field1): + self.field0 = field0 + self.field1 = field1 + + def to_iterable(self): + return [self.field0, self.field1], (None,) + + +def from_iterable(state, values): + del state + return ExampleType2(field0=values[0], field1=values[1]) + + +registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable) + + +@dataclasses.dataclass +class Custom: + a: int + b: str + + +registry.register_dataclass_node(Custom, ["a"], ["b"]) + + +class PyTreeTest(absltest.TestCase): + + def roundtrip(self, example): + original = registry.flatten(example)[1] + self.assertEqual( + pytree.PyTreeDef.deserialize_using_proto( + registry, original.serialize_using_proto() + ), + original, + ) + + def testSerializeDeserializeNoPickle(self): + o = object() + self.roundtrip(({"a": o, "b": o}, [o, (o, o), None])) + + def testSerializeWithFallback(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip({"a": ExampleType(field0=o, field1=o)}) + + def testRegisteredType(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip({"a": ExampleType2(field0=o, field1=o)}) + + def roundtrip_node_data(self, example): + original = registry.flatten(example)[1] + restored = pytree.PyTreeDef.make_from_node_data_and_children( + registry, original.node_data(), original.children() + ) + self.assertEqual(restored, original) + + def testRoundtripNodeData(self): + o = object() + self.roundtrip_node_data([o, o, o]) + self.roundtrip_node_data((o, o, o)) + self.roundtrip_node_data({"a": o, "b": o}) + self.roundtrip_node_data({22: o, 88: o}) + self.roundtrip_node_data(None) + self.roundtrip_node_data(o) + self.roundtrip_node_data(ExampleType(field0=o, field1=o)) + self.roundtrip_node_data(ExampleType2(field0=o, field1=o)) + + def testCompose(self): + x = registry.flatten(0)[1] + y = registry.flatten((0, 0))[1] + self.assertEqual((x.compose(y)).num_leaves, 2) + + def testDataclassMakeFromNodeData(self): + c = Custom(1, "a") + c_leafs, c_tree = registry.flatten(c) + c_tree2 = c_tree.make_from_node_data_and_children( + registry, c_tree.node_data(), c_tree.children() + ) + self.assertEqual(c_tree2.unflatten(c_leafs), c) + self.assertEqual(str(c_tree2), str(c_tree)) + + def testTpTraverse(self): + self.assertContainsSubset( + [ + pytree.PyTreeRegistry, + ExampleType2, + 
ExampleType2.to_iterable, + from_iterable, + ], + gc.get_referents(registry), + ) + k1 = "k1" + k2 = "k2" + + t = ExampleType("a", "b") + _, treedef = registry.flatten([1, {k1: 2, k2: t}, 5, t]) + + self.assertContainsSubset( + [ + pytree.PyTreeDef, + registry, + k1, + k2, + ExampleType, + ], + gc.get_referents(treedef), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/sdy.cc b/jaxlib/xla/sdy.cc new file mode 100644 index 000000000000..c6d1145517d8 --- /dev/null +++ b/jaxlib/xla/sdy.cc @@ -0,0 +1,143 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/sdy.h" + +#include +#include + +#include "mhlo/transforms/passes.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" +#include "xla/service/spmd/shardy/utils.h" +#include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +} // namespace + +void BuildSdySubmodule(nb::module_& m) { + nb::module_ mlir_module = m.def_submodule("sdy", "Shardy/XLA integration"); + + mlir_module + // TODO(b/707574930): define a C API for the XLA pipelines. 
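+      // The pipeline bindings below take an MLIR module serialized as
+      // bytecode and return the transformed module, again as bytecode; the
+      // query bindings (`lowered_with_shardy`, `get_mesh`) only inspect the
+      // module.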
+ .def( + "sdy_round_trip_export_pipeline", + [](const nb::bytes& bytecode) -> nb::bytes { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + mlir::PassManager pm(&context); + sdy::addSdyRoundTripExportPipeline(pm); + tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context); + ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get()))); + std::string module_str = + xla::ValueOrThrow(SerializeUsingBytecode(module.get())); + return nb::bytes(module_str.data(), module_str.size()); + }, + nb::arg("module")) + .def( + "sdy_round_trip_import_shardings", + [](const nb::bytes& bytecode) -> nb::bytes { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + mlir::PassManager pm(&context); + pm.addPass(xla::sdy::createSdyRoundTripImportShardyAttrsPass()); + tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context); + ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get()))); + std::string module_str = + xla::ValueOrThrow(SerializeUsingBytecode(module.get())); + return nb::bytes(module_str.data(), module_str.size()); + }, + nb::arg("module")) + .def("lowered_with_shardy", + [](const nb::bytes& bytecode) -> bool { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + return mlir::sdy::getMeshAttr(module.get(), "mesh") || + sdy::tryGetFrontendAttr( + module.get(), sdy::kMeshesRoundTripAttr) + .has_value(); + }) + // TODO(bartchr): delete this and all uses of it once I have JAX export + // support multiple meshes. + .def("get_mesh", [](const nb::bytes& bytecode) -> nb::list { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), context)); + auto mesh_op = + mlir::SymbolTable::lookupNearestSymbolFrom( + module.get(), mlir::StringAttr::get(&context, "mesh")); + if (!mesh_op) { + return {}; + } + nb::list mesh_shape; + for (auto axis : mesh_op.getMeshAttr().getAxes()) { + mesh_shape.append( + nb::make_tuple(axis.getName().str(), axis.getSize())); + } + return mesh_shape; + }); +} + +} // namespace xla diff --git a/jaxlib/xla/sdy.h b/jaxlib/xla/sdy.h new file mode 100644 index 000000000000..ef075855decd --- /dev/null +++ b/jaxlib/xla/sdy.h @@ -0,0 +1,28 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_SDY_H_ +#define JAXLIB_XLA_SDY_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildSdySubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_SDY_H_ diff --git a/jaxlib/xla/sharded_device_array.h b/jaxlib/xla/sharded_device_array.h new file mode 100644 index 000000000000..6e014789a289 --- /dev/null +++ b/jaxlib/xla/sharded_device_array.h @@ -0,0 +1,216 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_ +#define JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_ + +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "xla/python/types.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). + +namespace jax { + +// High level introduction. +// +// pmap and other parallel computation functions distribute some computation on +// several devices. On December 2020, the devices mesh (i.e. N-dimentional array +// of devices on which we map the computation) is defined by the user. +// +// We describe how to shard the inputs, and how to map it to the mesh of devices +// using `ShardingSpec`. It's mainly based on 2 components: +// - `sharding`, which specifies how to shard the inputs. +// - `mesh_mapping`, which specifies how to map shards to devices. +// +// The 3 following structs define how to shard one dimension of an ndarry. +// +// `NoSharding` (`None` in Python) means no sharding. +struct NoSharding { + bool operator==(const NoSharding& other) const { return true; } + bool operator!=(const NoSharding& other) const { return false; } +}; + +template +H AbslHashValue(H h, const NoSharding& key) { + return h; +} + +// `Chunked` means that the dimension is split into np.prod(chunks) chunks +// and the split dimension itself is preserved inside the map. +// Those chunks are distributed over `len(chunks)` ShardedAxes axes +// (major-to-minor). +// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])] (with +// p dividing N, let S = N // p) the tensor will be split into p chunks of +// shape [S], such sharded_t[k] = t[k * S: (k+1)*S] (left included, right +// excluded) for k in {0, ... p-1}. +struct Chunked { + public: + explicit Chunked(std::vector chunks_) : chunks(std::move(chunks_)) {} + // The number of chunks per axis. 
+  std::vector<int> chunks;
+
+  bool operator==(const Chunked& other) const { return chunks == other.chunks; }
+  bool operator!=(const Chunked& other) const { return chunks != other.chunks; }
+};
+
+template <typename H>
+H AbslHashValue(H h, const Chunked& key) {
+  h = H::combine(std::move(h), key.chunks);
+  return h;
+}
+
+// `Unstacked` means that the dimension is split into chunks of size 1, and
+// doesn't appear inside the map. `size` is always the dimension size.
+// For example, a Tensor t of shape [N] will be sharded into N tensors of shape
+// [], when using `Unstacked(N)`.
+struct Unstacked {
+ public:
+  explicit Unstacked(int sz) : size(sz) {}
+  int size;
+
+  bool operator==(const Unstacked& other) const { return size == other.size; }
+  bool operator!=(const Unstacked& other) const { return size != other.size; }
+};
+
+template <typename H>
+H AbslHashValue(H h, const Unstacked& key) {
+  h = H::combine(std::move(h), key.size);
+  return h;
+}
+
+using AvalDimSharding = std::variant<NoSharding, Chunked, Unstacked>;
+
+// Assigns sharded axes to mesh dimensions.
+//
+// Each mesh dimension that is assigned a sharded `AvalDimSharding` receives
+// the shards of that axis; when no axis is assigned, the data is replicated
+// along that mesh dimension.
+// As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually
+// sharded axis (i.e. counting as if the None dimensions of the sharding were
+// filtered out).
+// For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry
+// of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`.
+
+struct ShardedAxis {
+  int axis;
+  bool operator==(const ShardedAxis& other) const { return axis == other.axis; }
+  bool operator!=(const ShardedAxis& other) const { return axis != other.axis; }
+};
+
+template <typename H>
+H AbslHashValue(H h, const ShardedAxis& key) {
+  h = H::combine(std::move(h), key.axis);
+  return h;
+}
+
+struct Replicated {
+  int replicas;
+  bool operator==(const Replicated& other) const {
+    return replicas == other.replicas;
+  }
+  bool operator!=(const Replicated& other) const {
+    return replicas != other.replicas;
+  }
+};
+
+template <typename H>
+H AbslHashValue(H h, const Replicated& key) {
+  h = H::combine(std::move(h), key.replicas);
+  return h;
+}
+
+using MeshDimAssignment = std::variant<ShardedAxis, Replicated>;
+
+// Describes how each axis is sharded (if it is), and how it is mapped to the
+// device mesh. See JAX pxla.py for the documentation.
+//
+// ShardingSpec is shared across pmap, pjit and xmap. For pmap, an input
+// `sharding` is composed of `NoSharding` and at most one `Unstacked`.
+// If `axis_size=None`, at least one of the inputs has a dimension associated
+// to `Unstacked`.
+//
+// Examples:
+//
+// 1. For pmap, with a tensor of shape [8, 2, 2], to unstack along the first
+// dimension into [8] devices:
+//
+// sharding = [Unstacked(8), NoSharding, NoSharding]
+// mesh_mapping = [ShardedAxis(0)]
+//
+// 2. With an input array of shape [6] that we want to chunk into [2, 3],
+// assuming a device mesh [3, 4, 2] of devices, we will have:
+//
+// sharding = [Chunked([2, 3])]
+// mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)]
+//
+// In particular, in the above example, the ShardedAxis refers to indices
+// of the sharded shape [2, 3] (only the `Chunked` sharding can produce more
+// than one dimension).
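+//
+// As an informal illustration of example 2 (a sketch, not normative): the
+// device at mesh position (i, j, k) holds the chunk with index (k, i) of the
+// [2, 3] chunk grid, i.e. element data[k * 3 + i]; the j coordinate only
+// selects among the 4 replicas of that chunk.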
+class ShardingSpec { + public: + ShardingSpec(std::vector sharding, + std::vector mesh_mapping) + : sharding_(std::move(sharding)), + mesh_mapping_(std::move(mesh_mapping)) {} + ShardingSpec(nanobind::iterable py_sharding, + nanobind::iterable py_mesh_mapping) + : sharding_(xla::IterableToVector(py_sharding)), + mesh_mapping_( + xla::IterableToVector(py_mesh_mapping)) {} + + const std::vector& GetSharding() const { return sharding_; } + const std::vector& GetMeshMapping() const { + return mesh_mapping_; + } + + bool operator==(const ShardingSpec& other) const { + return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_; + } + + bool operator!=(const ShardingSpec& other) const { return !(*this == other); } + + template + friend H AbslHashValue(H h, const ShardingSpec& key); + + private: + // `sharding` specifies how the array is supposed to get partitioned into + // chunks. Its length matchs the rank of the array. See the docstring + // of `AvalDimSharding` for the supported partitioning schemes. + std::vector sharding_; + // `mesh_mapping` describes an assignments of the array chunks created by + // `sharding` to a logical device mesh. The length of the tuple is equal to + // the rank of the mesh. Each mesh dimension can either get partitions of + // data varying along one of the sharded dimensions, or the data can be + // replicated. + std::vector mesh_mapping_; +}; + +template +H AbslHashValue(H h, const ShardingSpec& key) { + h = H::combine(std::move(h), key.sharding_); + h = H::combine(std::move(h), key.mesh_mapping_); + return h; +} + +} // namespace jax + +#endif // JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_ diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc new file mode 100644 index 000000000000..b7c7e0a7de72 --- /dev/null +++ b/jaxlib/xla/sharding.cc @@ -0,0 +1,407 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/sharding.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/sharded_device_array.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace nb = nanobind; + +// Gets `jax::PyDeviceList` from a JAX Sharding. 
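+// The known C++ sharding types are handled without calling back into Python;
+// any other sharding (e.g. one implemented purely in Python) is expected to
+// expose an `_internal_device_list` attribute.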
+absl::StatusOr> GetPyDeviceList( + nb::handle sharding_py) { + nb::handle sharding(sharding_py.ptr()); + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list; + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding)->internal_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else { + return nb::cast>( + sharding.attr("_internal_device_list")); + } +} + +nb::object CheckAndCanonicalizeMemoryKind( + nb::object memory_kind, + const xla::nb_class_ptr& device_list) { + if (!memory_kind.is_none()) { + // If memory kind is not None, check if it's supported by the devices + // mentioned in the Sharding. + auto supported_memory_kinds = PyDeviceList::MemoryKinds(device_list); + if (!supported_memory_kinds.ok()) { + supported_memory_kinds = nb::tuple(); + } + for (nb::handle supported_memory_kind : *supported_memory_kinds) { + if (supported_memory_kind.equal(memory_kind)) { + return memory_kind; + } + } + auto addressable_device_list = + PyDeviceList::AddressableDeviceList(device_list); + if (addressable_device_list->Len() == 0) { + // If the device list is not addressable, we can't check if the memory + // kind is supported, so we assume it is. + return memory_kind; + } + nb::object device_kind = + addressable_device_list->GetItem(0).attr("device_kind"); + absl::string_view device_kind_str = + nb::cast(device_kind); + auto py_str_formatter = [](std::string* out, nb::handle h) { + *out += nb::cast(nb::str(h)); + }; + throw nb::value_error( + absl::StrCat( + "Could not find memory addressable by device ", device_kind_str, + ". Device ", device_kind_str, + " can address the following memory kinds: ", + absl::StrJoin(*supported_memory_kinds, ", ", py_str_formatter), + ". Got memory kind: ", nb::cast(memory_kind)) + .c_str()); + } + // If memory kind is None, canonicalize to default memory. 
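+  // For example (illustrative only), on a backend whose default memory kind
+  // is "device", a None memory_kind is canonicalized to "device"; if the
+  // backend does not report a default memory kind, None is returned
+  // unchanged.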
+ absl::StatusOr default_memory_kind = + PyDeviceList::DefaultMemoryKind(device_list); + if (!default_memory_kind.ok()) { + return nb::none(); + } + return *std::move(default_memory_kind); +} + +int Sharding::SafeNumDevices(nb::handle sharding) { + const jax::Sharding* cpp_sharding; + if (nb::try_cast(sharding, cpp_sharding)) { + if (cpp_sharding->num_devices_.has_value()) { + return (*cpp_sharding->num_devices_); + } + } + nb::set device_set = sharding.attr("device_set"); + return device_set.size(); +} + +size_t ShardingHash(nb::handle sharding) { + auto type = sharding.type(); + + if (type.is(NamedSharding::type())) { + const auto* named_sharding = nb::inst_ptr(sharding); + return absl::Hash()(named_sharding->mesh().ptr()); + } + + if (type.is(GSPMDSharding::type())) { + auto* gspmd_sharding = nb::inst_ptr(sharding); + return gspmd_sharding->Hash(); + } + + if (type.is(SingleDeviceSharding::type())) { + auto* single_device_sharding = nb::inst_ptr(sharding); + return absl::Hash()(single_device_sharding->device().ptr()); + } + + return nb::hash(sharding); +} + +bool ShardingEqual(nb::handle a, nb::handle b) { + if (a.ptr() == b.ptr()) return true; + + auto a_type = a.type(); + auto b_type = b.type(); + + if (!a_type.is(b_type)) return false; + + if (a_type.is(NamedSharding::type())) { + auto* a_named_sharding = nb::inst_ptr(a); + auto* b_named_sharding = nb::inst_ptr(b); + + return a_named_sharding->mesh().ptr() == b_named_sharding->mesh().ptr() && + a_named_sharding->spec().equal(b_named_sharding->spec()) && + a_named_sharding->memory_kind().equal( + b_named_sharding->memory_kind()) && + a_named_sharding->logical_device_ids().equal( + b_named_sharding->logical_device_ids()); + } + + if (a_type.is(GSPMDSharding::type())) { + auto* a_gspmd_sharding = nb::inst_ptr(a); + auto* b_gspmd_sharding = nb::inst_ptr(b); + + return a_gspmd_sharding == b_gspmd_sharding; + } + + if (a_type.is(SingleDeviceSharding::type())) { + auto* a_single_device_sharding = + nb::inst_ptr(a); + auto* b_single_device_sharding = + nb::inst_ptr(b); + + return a_single_device_sharding->device().ptr() == + b_single_device_sharding->device().ptr() && + a_single_device_sharding->memory_kind().equal( + b_single_device_sharding->memory_kind()); + } + + return a.equal(b); +} + +// This list is to check for valid memory kinds when an AbstractMesh is passed +// to NamedSharding. +static const std::array valid_memory_kinds = { + "device", + "pinned_host", + "unpinned_host", +}; + +NamedSharding::NamedSharding(nb::object mesh, nb::object spec, + nb::object memory_kind, + nb::object logical_device_ids) + : Sharding(/*num_devices=*/[&mesh]() { + return nb::cast(mesh.attr("size")); + }()), + mesh_(std::move(mesh)), + spec_(std::move(spec)), + memory_kind_(std::move(memory_kind)), + logical_device_ids_(std::move(logical_device_ids)) { + if (spec_.is_none()) { + throw nb::type_error( + "Unexpected None passed as spec for NamedSharding. Did you mean P()?"); + } + nb::object idl = nb::object(mesh_.attr("_internal_device_list")); + if (idl.is_none()) { + internal_device_list_ = std::nullopt; + } else { + internal_device_list_ = nb::cast>(idl); + } + if (internal_device_list_) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_); + } else { + if (!memory_kind_.is_none() && + (std::find(valid_memory_kinds.begin(), valid_memory_kinds.end(), + nb::cast(memory_kind_)) == + valid_memory_kinds.end())) { + throw nb::value_error( + absl::StrCat("Got invalid memory kind: ", + nb::cast(memory_kind_), + ". 
Valid memory kinds are: ", + absl::StrJoin(valid_memory_kinds, ", ")) + .c_str()); + } + } + + // TODO(phawkins): this leaks a reference to the check_pspec function. + // A better way to fix this would be to move PartitionSpec and this check into + // C++. + nb::object* check_pspec = []() { + static absl::Mutex mu; + static nb::object* output = nullptr; + { + absl::MutexLock lock(&mu); + if (output) { + return output; + } + } + nb::module_ si = nb::module_::import_("jax._src.named_sharding"); + nb::object attr = si.attr("check_pspec"); + absl::MutexLock lock(&mu); + if (!output) { + output = new nb::object(attr); + } + return output; + }(); + (*check_pspec)(mesh_, spec_); +} + +/*static*/ PyObject* NamedSharding::type_ = nullptr; + +/*static*/ void NamedSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +SingleDeviceSharding::SingleDeviceSharding(nb::object device, + nb::object memory_kind) + : Sharding(/*num_devices=*/1), + device_(device), + memory_kind_(std::move(memory_kind)), + internal_device_list_( + xla::make_nb_class(nb::make_tuple(std::move(device)))) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr; + +/*static*/ void SingleDeviceSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +SingleDeviceSharding::SingleDeviceSharding( + xla::nb_class_ptr client, + xla::ifrt::DeviceListRef device_list, nb::object memory_kind) + : Sharding(/*num_devices=*/1), + device_(client->GetPyDevice(device_list->devices().front())), + memory_kind_(std::move(memory_kind)), + internal_device_list_(xla::make_nb_class( + std::move(client), std::move(device_list))) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices, + ShardingSpec sharding_spec) + : Sharding(/*num_devices=*/devices.size()), + devices_(std::move(devices)), + sharding_spec_(std::move(sharding_spec)) { + nb::object flat_devices = devices_.attr("flat"); + internal_device_list_ = + xla::make_nb_class(nb::tuple(flat_devices)); +} + +/*static*/ PyObject* PmapSharding::type_ = nullptr; + +// /*static*/ nanobind::handle PmapSharding::type() { return type_; } + +/*static*/ void PmapSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding, + nb::object memory_kind, nb::object device_list) + : Sharding(/*num_devices=*/nb::len(devices.ptr())), + devices_(nb::tuple(devices)), + hlo_sharding_(std::move(op_sharding)), + memory_kind_(std::move(memory_kind)) { + if (device_list.is_none()) { + internal_device_list_ = xla::make_nb_class(devices_); + } else { + internal_device_list_ = + nb::cast>(std::move(device_list)); + } + // This checks in python if the memory kind is correct for the given + // devices. Currently in python this check is optimized but we want to + // move that check to C++ after which we can remove this call. + CHECK(devices_.size() != 0) + << "Devices given to GSPMDSharding must not be empty"; + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +/*static*/ PyObject* GSPMDSharding::type_ = nullptr; + +/*static*/ void GSPMDSharding::InitializeType() { + // Intentionally leaks a reference. 
+ type_ = nanobind::type().inc_ref().ptr(); +} + +void RegisterSharding(nb::module_& m) { + nb::class_(m, "Sharding").def(nb::init<>()); + + nb::class_(m, "NamedSharding", nb::dynamic_attr()) + .def(nb::init(), + nb::arg("mesh"), nb::arg("spec").none(), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_logical_device_ids").none() = nb::none()) + .def_prop_ro("mesh", &NamedSharding::mesh) + .def_prop_ro("spec", &NamedSharding::spec) + .def_prop_ro("_memory_kind", &NamedSharding::memory_kind) + .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids) + .def_prop_ro("_internal_device_list", [](const NamedSharding& s) { + return xla::ValueOrThrow(s.internal_device_list()); + }); + NamedSharding::InitializeType(); + + nb::class_(m, "SingleDeviceSharding", + nb::dynamic_attr()) + .def(nb::init(), nb::arg("device"), + nb::arg("memory_kind").none() = nb::none()) + .def_prop_ro("_device", &SingleDeviceSharding::device) + .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &SingleDeviceSharding::internal_device_list); + SingleDeviceSharding::InitializeType(); + + nb::class_(m, "PmapSharding", nb::dynamic_attr()) + .def( + "__init__", + [](PmapSharding* self, nb::object devices, + ShardingSpec sharding_spec) { + new (self) PmapSharding(xla::nb_numpy_ndarray::ensure(devices), + std::move(sharding_spec)); + }, + nb::arg("devices"), nb::arg("sharding_spec")) + .def_prop_ro("devices", &PmapSharding::devices) + .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec) + .def_prop_ro("_internal_device_list", + &PmapSharding::internal_device_list); + PmapSharding::InitializeType(); + + nb::class_(m, "GSPMDSharding", nb::dynamic_attr()) + .def(nb::init(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_device_list").none() = nb::none()) + .def(nb::init(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_device_list").none() = nb::none()) + .def_prop_ro("_devices", &GSPMDSharding::devices) + .def_prop_ro("_hlo_sharding", &GSPMDSharding::hlo_sharding) + .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &GSPMDSharding::internal_device_list); + GSPMDSharding::InitializeType(); +} + +} // namespace jax diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h new file mode 100644 index 000000000000..e0c54592259b --- /dev/null +++ b/jaxlib/xla/sharding.h @@ -0,0 +1,241 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_SHARDING_H_ +#define JAXLIB_XLA_SHARDING_H_ + +#include + +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/sharded_device_array.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +class Sharding { + public: + Sharding() = default; + + // This constructor is used in the fast path to retrieve the number of devices + // without falling back to python. This is only used in the cpp path. + explicit Sharding(int num_devices) : num_devices_(num_devices) {} + + virtual ~Sharding() = default; + + static int SafeNumDevices(nanobind::handle sharding); + + private: + std::optional num_devices_; +}; + +// Gets `jax::PyDeviceList` from a JAX Sharding. +absl::StatusOr> GetPyDeviceList( + nanobind::handle sharding_py); + +// Checks if the memory kind is valid, and canonicalizes the +// memory kind to default memory on backends that support memories. +nanobind::object CheckAndCanonicalizeMemoryKind( + nanobind::object memory_kind, + const xla::nb_class_ptr& device_list); + +// Returns a hash that may sometimes return different hashes for equal values. +// It is not a correct implementation of `__hash__` in python, but it's fine +// for jit/pjit dispatch since it only causes spurious cache misses. +size_t ShardingHash(nanobind::handle sharding); + +bool ShardingEqual(nanobind::handle a, nanobind::handle b); + +class NamedSharding : public Sharding { + public: + NamedSharding(nanobind::object mesh, nanobind::object spec, + nanobind::object memory_kind, + nanobind::object logical_device_ids); + + const nanobind::object& mesh() const { return mesh_; } + const nanobind::object& spec() const { return spec_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + const nanobind::object& logical_device_ids() const { + return logical_device_ids_; + } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + absl::StatusOr> internal_device_list() const { + if (internal_device_list_) { + return *internal_device_list_; + } + return xla::InvalidArgument( + "internal_device_list is not implemented for " + "`jax.sharding.AbstractMesh`"); + } + + private: + nanobind::object mesh_; + nanobind::object spec_; + nanobind::object memory_kind_; + nanobind::object logical_device_ids_; + std::optional> internal_device_list_; + static PyObject* type_; +}; + +class SingleDeviceSharding : public Sharding { + public: + explicit SingleDeviceSharding( + nanobind::object device, nanobind::object memory_kind = nanobind::none()); + + // Used only in C++ to accelerate `PyArray::MakeFromSingleDeviceArray()`. 
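+  // It lets C++ callers that already hold an ifrt::DeviceListRef construct
+  // the sharding directly, without going through a Python device argument.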
+ SingleDeviceSharding(xla::nb_class_ptr client, + xla::ifrt::DeviceListRef device_list, + nanobind::object memory_kind); + + const nanobind::object& device() const { return device_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + nanobind::object device_; + nanobind::object memory_kind_; + xla::nb_class_ptr internal_device_list_; + + static PyObject* type_; +}; + +// The C++ implementation of jax.PmapSharding in python. It contains a few key +// data members and methods that are performance-critical. +class PmapSharding : public Sharding { + public: + PmapSharding(xla::nb_numpy_ndarray devices, ShardingSpec sharding_spec); + + ~PmapSharding() override = default; + + xla::nb_numpy_ndarray devices() const { return devices_; } + + const ShardingSpec& sharding_spec() const { return sharding_spec_; } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + xla::nb_numpy_ndarray devices_; + ShardingSpec sharding_spec_; + xla::nb_class_ptr internal_device_list_; + static PyObject* type_; +}; + +class GSPMDSharding : public Sharding { + public: + GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list) + : GSPMDSharding( + std::move(devices), + xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)), + std::move(memory_kind), std::move(device_list)) {} + + GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list); + + const nanobind::tuple& devices() const { return devices_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + + size_t Hash() { + if (!hash_.has_value()) { + hash_ = CalculateHash(); + } + return *hash_; + } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; } + + bool operator==(const GSPMDSharding& other) const { + return AreOpShardingsEqual(*this, other) && + this->devices().equal(other.devices()) && + this->memory_kind().equal(other.memory_kind()); + } + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + size_t CalculateHash() const { + // We only hash `hlo_sharding_` here for performance. + return absl::Hash()(hlo_sharding_); + } + + static bool AreOpShardingsEqual(const GSPMDSharding& a, + const GSPMDSharding& b) { + // If the OpSharding object is the same, return true + if (&a.hlo_sharding() == &b.hlo_sharding()) { + return true; + } + // If both OpShardings are replicated, return true + if (a.IsOpShardingReplicated() && b.IsOpShardingReplicated()) { + return true; + } + return a.hlo_sharding() == b.hlo_sharding(); + } + + bool IsOpShardingReplicated() const { + // For JAX, shardings with 1 device are considered as replicated in its + // semantics so that downstream things continue to work. 
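+    // e.g. an HloSharding whose tile assignment contains a single device
+    // (such as a maximal sharding) compares equal to an explicitly replicated
+    // sharding in AreOpShardingsEqual() above.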
+ if (hlo_sharding_.tile_assignment().num_elements() == 1) { + return true; + } + return hlo_sharding().IsReplicated(); + } + + nanobind::tuple devices_; + xla::HloSharding hlo_sharding_; + nanobind::object memory_kind_; + std::optional hash_; + xla::nb_class_ptr internal_device_list_; + + static PyObject* type_; +}; + +void RegisterSharding(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_SHARDING_H_ diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/xla/to_ifrt_sharding.cc new file mode 100644 index 000000000000..52879cfa9fbe --- /dev/null +++ b/jaxlib/xla/to_ifrt_sharding.cc @@ -0,0 +1,141 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/to_ifrt_sharding.h" + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/sharding.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +namespace nb = ::nanobind; + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(nb::handle(jax::GSPMDSharding::type().ptr()))) { + return nb::cast(nb::handle(sharding.ptr())) + ->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nb::handle sharding_py) { + TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding_py)); + return py_device_list->ifrt_device_list(); +} + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nb::handle sharding) { + nb::object py_memory_kind = nb::none(); + + // sharding.attr("memory_kind") can crash if sharding was originally created + // from C++ and casted into a Python Sharding object. Thus, we cast sharding + // to a C++ type and use C++ `memory_kind()` method, which bypasses any Python + // attribute access. 
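+  // The final `else` branch below still reads the Python attribute, so
+  // shardings implemented purely in Python keep working through the slower
+  // path.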
+ nb::handle type = sharding.type(); + if (type.is(jax::NamedSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else if (type.is(jax::SingleDeviceSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else if (type.is(jax::GSPMDSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else { + py_memory_kind = sharding.attr("memory_kind"); + } + + if (py_memory_kind.is_none()) { + return xla::ifrt::MemoryKind(); + } + return xla::ifrt::MemoryKind(nb::cast(py_memory_kind)); +} + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr> GetIfrtHloSharding( + nb::handle sharding, const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + return xla::ifrt::HloSharding::Create( + std::move(device_list), std::move(memory_kind), std::move(hlo_sharding)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr> +GetIfrtConcreteEvenSharding(nb::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_primitive_type, + xla::ifrt::ToPrimitiveType(dtype)); + // The XLA shape's layout is irrelevant because we only need to know the + // tile shape, which is independent from the layout. + xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla_primitive_type, shape.dims()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + xla::Shape tile_shape = hlo_sharding.TileShape(xla_shape); + xla::ifrt::Shape shard_shape(xla::ifrt::Shape::Dimensions( + tile_shape.dimensions().begin(), tile_shape.dimensions().end())); + return xla::ifrt::ConcreteEvenSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shape=*/std::move(shard_shape)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr> +GetIfrtConcreteSharding(nb::handle sharding, const xla::ifrt::Shape& shape, + std::vector shard_shapes) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + return xla::ifrt::ConcreteSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shapes=*/std::move(shard_shapes)); +} + +} // namespace xla diff --git a/jaxlib/xla/to_ifrt_sharding.h b/jaxlib/xla/to_ifrt_sharding.h new file mode 100644 index 000000000000..ebc999888297 --- /dev/null +++ b/jaxlib/xla/to_ifrt_sharding.h @@ -0,0 +1,62 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_TO_IFRT_SHARDING_H_ +#define JAXLIB_XLA_TO_IFRT_SHARDING_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" + +namespace xla { + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nanobind::handle sharding, + int64_t num_dimensions); + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nanobind::handle sharding_py); + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nanobind::handle sharding); + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr> GetIfrtHloSharding( + nanobind::handle sharding, const xla::ifrt::Shape& shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr> +GetIfrtConcreteEvenSharding(nanobind::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr> +GetIfrtConcreteSharding(nanobind::handle sharding, + const xla::ifrt::Shape& shape, + std::vector shard_shapes); + +} // namespace xla + +#endif // JAXLIB_XLA_TO_IFRT_SHARDING_H_ diff --git a/jaxlib/xla/traceback.cc b/jaxlib/xla/traceback.cc new file mode 100644 index 000000000000..35085b3e32fa --- /dev/null +++ b/jaxlib/xla/traceback.cc @@ -0,0 +1,357 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/traceback.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "xla/pjrt/exceptions.h" +#include "tsl/platform/platform.h" + +#ifdef PLATFORM_GOOGLE +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#undef Py_BUILD_CORE +#endif // PLATFORM_GOOGLE + +namespace xla { + +namespace nb = nanobind; + +bool Traceback::enabled_ = true; + +Traceback::Traceback() { + DCHECK(PyGILState_Check()); + PyThreadState* thread_state = PyThreadState_GET(); + +#if PY_VERSION_HEX < 0x030b0000 + // The representation of frame->f_lasti changed from bytes to words in Python + // 3.10, see https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api + // This should match sizeof(_Py_CODEUNIT) which is unfortunately private. + constexpr int kLastiWordBytes = 2; + + for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr; + py_frame = py_frame->f_back) { + Py_INCREF(py_frame->f_code); + frames_.emplace_back(py_frame->f_code, py_frame->f_lasti * kLastiWordBytes); + } +#else // PY_VERSION_HEX < 0x030b0000 + +#ifdef PLATFORM_GOOGLE + // This code is equivalent to the version using public APIs, but it saves us + // an allocation of one object per stack frame. However, this is definitely + // violating the API contract of CPython, so we only use this where we can be + // confident we know exactly which CPython we are using (internal to Google). + // Feel free to turn this on if you like, but it might break at any time! + for (_PyInterpreterFrame* f = thread_state->cframe->current_frame; + f != nullptr; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_code); + frames_.emplace_back(f->f_code, + _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT)); + } +#else // PLATFORM_GOOGLE + PyFrameObject* next; + for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state); + py_frame != nullptr; py_frame = next) { + frames_.emplace_back(PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)); + next = PyFrame_GetBack(py_frame); + Py_XDECREF(py_frame); + } +#endif // PLATFORM_GOOGLE + +#endif // PY_VERSION_HEX < 0x030b0000 +} + +Traceback::~Traceback() { + for (auto& frame : frames_) { + DCHECK(PyGILState_Check()); + Py_DECREF(frame.first); + } +} + +Traceback::Traceback(Traceback&& other) noexcept + : frames_(std::move(other.frames_)) { + // absl::InlinedVector does not always clear itself if moved. Since we rely on + // its empty() method to destroy Traceback differently, we explicitly clear + // here. 
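+  // Otherwise the moved-from Traceback's destructor would Py_DECREF code
+  // objects whose references were just transferred to this object.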
+ other.frames_.clear(); +} + +std::string Traceback::Frame::ToString() const { + return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), + line_num, nb::cast(function_name)); +} + +std::string Traceback::ToString() const { + std::vector frame_strs; + frame_strs.reserve(frames_.size()); + for (const Frame& frame : Frames()) { + frame_strs.push_back(frame.ToString()); + } + return absl::StrJoin(frame_strs, "\n"); +} + +std::vector Traceback::Frames() const { + // We require the GIL because we manipulate Python strings. + CHECK(PyGILState_Check()); + std::vector frames; + frames.reserve(frames_.size()); + for (const auto& frame : frames_) { + frames.push_back(Frame{nb::borrow(frame.first->co_filename), + nb::borrow(frame.first->co_name), + frame.first->co_firstlineno, + PyCode_Addr2Line(frame.first, frame.second)}); + } + return frames; +} + +std::optional> Traceback::Get() { + DCHECK(PyGILState_Check()); + if (!enabled_) { + return std::nullopt; + } + return make_nb_class(); +} + +void Traceback::SetEnabled(bool enabled) { enabled_ = enabled; } + +nb::object Traceback::AsPythonTraceback() const { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type(reinterpret_cast(&PyTraceBack_Type)); + for (const std::pair& frame : frames_) { + int lineno = PyCode_Addr2Line(frame.first, frame.second); + // Under Python 3.11 we observed crashes when using a fake PyFrameObject + // with a real PyCodeObject (https://github.com/google/jax/issues/16027). + // because the frame does not have fields necessary to compute the locals, + // notably the closure object, leading to crashes in CPython in + // _PyFrame_FastToLocalsWithError + // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2 + // We therefore always build a fake code object to go along with our fake + // frame. + PyCodeObject* py_code = + PyCode_NewEmpty(PyUnicode_AsUTF8(frame.first->co_filename), + PyUnicode_AsUTF8(frame.first->co_name), lineno); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/nullptr); + Py_DECREF(py_code); + + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + PyCode_Addr2Line(frame.first, frame.second)); + } + return traceback; +} + +namespace { + +Py_hash_t traceback_tp_hash(PyObject* o) { + Traceback* tb; + if (!nb::try_cast(nb::handle(o), tb)) { + PyErr_SetString(PyExc_TypeError, "Expected a Traceback object"); + return -1; + } + size_t h = absl::HashOf(*tb); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. +} + +PyObject* traceback_tp_richcompare(PyObject* self, PyObject* other, int op) { + if (op != Py_EQ && op != Py_NE) { + return Py_NewRef(Py_NotImplemented); + } + + Traceback* x; + if (!nb::try_cast(nb::handle(self), x)) { + PyErr_SetString(PyExc_TypeError, "Expected a Traceback object"); + return nullptr; + } + + bool result; + Traceback* y; + if (nb::try_cast(nb::handle(other), y)) { + result = ((*x == *y) == (op == Py_EQ)); + } else { + result = (op == Py_NE); + } + return Py_NewRef(result ? Py_True : Py_False); +} + +// It turns out to be slightly faster to define a tp_hash slot rather than +// defining __hash__ and __eq__ on the class. 
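+// From Python this is what makes Traceback objects usable as dict keys and
+// comparable with ==/!=. A sketch (assuming traceback collection is enabled):
+//
+//   Traceback.enabled = True
+//   tb = Traceback.get_traceback()
+//   cache = {tb: ...}                 # goes through traceback_tp_hash
+//   tb == Traceback.get_traceback()   # goes through traceback_tp_richcompare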
+PyType_Slot traceback_slots_[] = { + {Py_tp_hash, (void*)traceback_tp_hash}, + {Py_tp_richcompare, (void*)traceback_tp_richcompare}, + {0, nullptr}, +}; + +} // namespace + +void BuildTracebackSubmodule(nb::module_& m) { + nb::class_(m, "Frame") + .def(nb::init()) + .def_ro("file_name", &Traceback::Frame::file_name) + .def_ro("function_name", &Traceback::Frame::function_name) + .def_ro("function_start_line", &Traceback::Frame::function_start_line) + .def_ro("line_num", &Traceback::Frame::line_num) + .def("__repr__", [](const Traceback::Frame& frame) { + return absl::StrFormat( + "%s;%s:%d", nb::cast(frame.function_name), + nb::cast(frame.file_name), frame.line_num); + }); + + nb::class_ traceback(m, "Traceback", + nb::type_slots(traceback_slots_), + "Represents a Python stack trace."); + traceback.def_prop_rw_static( + "enabled", [](nb::object /* cls */) { return Traceback::enabled(); }, + [](nb::object /* cls */, bool enabled) { + return Traceback::SetEnabled(enabled); + }); + traceback.def_static( + "get_traceback", []() { return Traceback::Get(); }, + R"doc( + Returns a :class:`Traceback` for the current thread. + + If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object + that describes the Python stack of the calling thread. Stack trace + collection has a small overhead, so it is disabled by default. If traceback + collection is disabled, returns ``None``. + )doc"); + traceback.def_prop_ro("frames", &Traceback::Frames); + traceback.def("raw_frames", [](const Traceback& tb) -> nb::tuple { + // We return a tuple of lists, rather than a list of tuples, because it + // is cheaper to allocate only three Python objects for everything rather + // than one per frame. + nb::list out_code = nb::steal(PyList_New(tb.raw_frames().size())); + nb::list out_lasti = + nb::steal(PyList_New(tb.raw_frames().size())); + for (size_t i = 0; i < tb.raw_frames().size(); ++i) { + const auto& frame = tb.raw_frames()[i]; + PyObject* code = reinterpret_cast(frame.first); + Py_INCREF(code); + PyList_SET_ITEM(out_code.ptr(), i, code); + PyList_SET_ITEM(out_lasti.ptr(), i, + nb::int_(frame.second).release().ptr()); + } + return nb::make_tuple(out_code, out_lasti); + }); + traceback.def("__str__", &Traceback::ToString); + traceback.def("as_python_traceback", &Traceback::AsPythonTraceback); + + traceback.def_static( + "traceback_from_frames", + [](std::vector frames) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type( + reinterpret_cast(&PyTraceBack_Type)); + for (const Traceback::Frame& frame : frames) { + PyCodeObject* py_code = + PyCode_NewEmpty(frame.file_name.c_str(), + frame.function_name.c_str(), frame.line_num); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/ + nullptr); + Py_DECREF(py_code); + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + frame.line_num); + } + return traceback; + }, + "Creates a traceback from a list of frames."); + + traceback.def_static( + "code_addr2line", + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + return PyCode_Addr2Line(reinterpret_cast(code.ptr()), + lasti); + }, + "Python wrapper around the Python C API function PyCode_Addr2Line"); + +#if PY_VERSION_HEX >= 0x030b0000 + traceback.def_static( + "code_addr2location", + [](nb::handle code, int lasti) { + if 
(!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + int start_line, start_column, end_line, end_column; + if (!PyCode_Addr2Location(reinterpret_cast(code.ptr()), + lasti, &start_line, &start_column, &end_line, + &end_column)) { + throw nb::python_error(); + } + return nb::make_tuple(start_line, start_column, end_line, end_column); + }, + "Python wrapper around the Python C API function PyCode_Addr2Location"); +#endif // PY_VERSION_HEX >= 0x030b0000 + +#if PY_VERSION_HEX < 0x030b0000 + // This function replaces the exception traceback associated with the current + // Python thread. + m.def( + "replace_thread_exc_traceback", + [](nb::object tb) { + if (!tb.is_none() && !PyTraceBack_Check(tb.ptr())) { + throw xla::XlaRuntimeError( + "argument must be a traceback object or None"); + } + PyThreadState* thread_state = PyThreadState_Get(); + if (!thread_state->exc_info->exc_traceback) { + throw xla::XlaRuntimeError( + "Current thread does not have an active " + "exception traceback"); + } + PyObject* old_exc_traceback = thread_state->exc_info->exc_traceback; + PyObject* new_tb = tb.is_none() ? nullptr : tb.release().ptr(); + thread_state->exc_info->exc_traceback = new_tb; + Py_XDECREF(old_exc_traceback); + }, + nb::arg("traceback").none()); +#endif // PY_VERSION_HEX < 0x30b0000 +} +} // namespace xla diff --git a/jaxlib/xla/traceback.h b/jaxlib/xla/traceback.h new file mode 100644 index 000000000000..685ecc5f8793 --- /dev/null +++ b/jaxlib/xla/traceback.h @@ -0,0 +1,109 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_TRACEBACK_H_ +#define JAXLIB_XLA_TRACEBACK_H_ + +#include + +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/inlined_vector.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" + +namespace xla { + +// Represents a Python traceback. This object is designed to be allocated on +// the Python heap; creating or destroying a traceback requires the GIL. +class Traceback { + public: + // Requires GIL. Creates a Traceback object that requires destructor to be + // invoked with GIL held as well. + static std::optional> Get(); + + // Requires GIL. + static bool enabled() { return enabled_; } + // Requires GIL. + static void SetEnabled(bool enabled); + + // Requires GIL. Don't call this directly, you're looking for Get(). + Traceback(); + // Requires GIL. + ~Traceback(); + + Traceback(const Traceback&) = delete; + Traceback(Traceback&& other) noexcept; + Traceback& operator=(const Traceback&) = delete; + Traceback& operator=(Traceback&&) = delete; + + // Requires the GIL be held. 
+ std::string ToString() const; + + struct Frame { + nanobind::str file_name; + nanobind::str function_name; + int function_start_line; + int line_num; + + std::string ToString() const; + }; + std::vector Frames() const; + + const absl::InlinedVector, 32>& raw_frames() + const { + return frames_; + } + + // Returns the traceback as a fake Python Traceback object, suitable for + // using as an exception traceback. + nanobind::object AsPythonTraceback() const; + + bool operator==(const Traceback& other) const { + return frames_ == other.frames_; + } + bool operator!=(const Traceback& other) const { + return frames_ != other.frames_; + } + + private: + // Each frame is a pair of a code object and a "lasti" instruction location + // in bytes. The size of _Py_CODEUNIT has changed across different Python + // versions; the lasti value here has already been multiplied by + // sizeof(_Py_CODEUNIT) if needed and is suitable for passing to functions + // like PyCode_Addr2Line(). + absl::InlinedVector, 32> frames_; + + // Protected by GIL. + static bool enabled_; +}; + +using nb_traceback = nb_class_ptr; + +template +H AbslHashValue(H h, const Traceback& traceback) { + h = H::combine(std::move(h), traceback.raw_frames()); + return h; +} + +void BuildTracebackSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_TRACEBACK_H_ diff --git a/jaxlib/xla/util.cc b/jaxlib/xla/util.cc new file mode 100644 index 000000000000..5fb3f352ba2c --- /dev/null +++ b/jaxlib/xla/util.cc @@ -0,0 +1,85 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/util.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/async_value.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future) { +#if JAX_IFRT_VERSION_NUMBER >= 5 + future.BlockUntilReady([](tsl::AsyncValue* value) { + auto state = std::make_shared(); + value->AndThen([state]() { state->Notify(); }); + while (true) { + if (state->WaitForNotificationWithTimeout(absl::Milliseconds(200))) { + break; + } + nanobind::gil_scoped_acquire gil_acquire; + if (PyErr_CheckSignals() != 0) { + throw nanobind::python_error(); + } + } + }); +#endif +} + +absl::Status AwaitBuffersReady(absl::Span ifrt_arrays) { + if (ifrt_arrays.empty()) { + return absl::OkStatus(); + } + + ifrt::Future<> future; + if (ifrt_arrays.size() == 1) { + future = ifrt_arrays[0]->GetReadyFuture(); + } else { + std::vector> values; + values.reserve(ifrt_arrays.size()); + for (ifrt::Array* const ifrt_array : ifrt_arrays) { + values.push_back(tsl::FormRef(ifrt_array)); + } + ifrt::Client* const client = ifrt_arrays.front()->client(); + future = client->GetReadyFuture(values); + } + BlockUntilReadyWithCancel(future); + absl::Status s = future.Await(); + if (!s.ok()) { + // Fix up error string because some clients rely on it. + if (s.message() == "GetReadyFuture() called on deleted or donated buffer") { + s = InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + } + return s; +} + +} // namespace xla diff --git a/jaxlib/xla/util.h b/jaxlib/xla/util.h new file mode 100644 index 000000000000..ed3b03d733dd --- /dev/null +++ b/jaxlib/xla/util.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_UTIL_H_ +#define JAXLIB_XLA_UTIL_H_ + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" + +namespace xla { + +// Waits until future is ready but will cancel if ctrl-c is pressed. +void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future); + +// Requests if given buffers are ready, awaits for results and returns OK if +// all of the buffers are ready or the last non-ok status. 
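+// A typical call site looks roughly like this (a sketch; the variable names
+// are illustrative):
+//
+//   std::vector<ifrt::Array*> arrays = ...;
+//   xla::ThrowIfError(xla::AwaitBuffersReady(arrays));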
+absl::Status AwaitBuffersReady(absl::Span ifrt_arrays); + +} // namespace xla + +#endif // JAXLIB_XLA_UTIL_H_ diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc new file mode 100644 index 000000000000..225d45f53b4b --- /dev/null +++ b/jaxlib/xla/xla.cc @@ -0,0 +1,955 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/nb_defs.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/ifrt_proxy.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_program.h" +#include "jaxlib/xla/sdy.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/protocol.pb.h" +#include "xla/pjrt/distributed/service.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/version.h" +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT + +#if defined(__linux__) +#include "gloo/transport/tcp/attr.h" +#include "gloo/transport/tcp/device.h" +#include "jaxlib/xla/py_socket_transfer.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT +#include "xla/backends/cpu/collectives/gloo_kv_store.h" // NOLINT +#endif // defined(__linux__) + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) +#include 
"xla/backends/cpu/collectives/mpi_collectives.h" +#endif // !_WIN32 && !PLATFORM_GOOGLE + +#include "jaxlib/xla/config.h" +#include "jaxlib/xla/custom_call_sharding.h" +#include "jaxlib/xla/dlpack.h" +#include "jaxlib/xla/guard_lib.h" +#include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/mlir.h" +#include "jaxlib/xla/nb_class_ptr.h" +#include "jaxlib/xla/pjit.h" +#include "jaxlib/xla/pmap_lib.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_compile_only_client.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/traceback.h" +#include "jaxlib/xla/xla_compiler.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/logging.h" // IWYU pragma: keep +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/ops.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/python/profiler.h" +#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" +#include "xla/tsl/platform/status.h" +#include "tsl/platform/platform.h" + +// TODO(phawkins): remove host_id properties after JAX is update to avoid them. + +namespace xla { +namespace { + +namespace nb = nanobind; + +bool IsOptimizedBuild() { +#if NDEBUG + return true; +#else + return false; +#endif // NDEBUG +} + +// Is*san reports whether the build is under that particular sanitizer. +bool IsAsan() { +#if defined(ADDRESS_SANITIZER) + return true; +#else // defined(ADDRESS_SANITIZER) + return false; +#endif +} + +bool IsMsan() { +#if defined(MEMORY_SANITIZER) + return true; +#else // defined(MEMORY_SANITIZER) + return false; +#endif +} + +bool IsTsan() { +#if defined(THREAD_SANITIZER) + return true; +#else // defined(THREAD_SANITIZER) + return false; +#endif +} + +// IsSanitized reports whether the build is under any sanitizer. +bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); } + +} // namespace + +NB_MODULE(xla_extension, m) { + // Initialize ABSL logging because code within XLA uses it. +#ifndef PLATFORM_GOOGLE + InitializeAbslLogging(); +#endif // PLATFORM_GOOGLE + + // We seem to get a fair number of leak warnings from nanobind. It's unclear + // whether these are false positives or not. + nb::set_leak_warnings(false); + + tsl::ImportNumpy(); + + // Exceptions + nb::exception xla_runtime_error(m, "XlaRuntimeError", + PyExc_RuntimeError); + xla_runtime_error.attr("__doc__") = nb::str( + "Runtime errors thrown by the JAX runtime. 
While the JAX runtime may " + "raise other exceptions as well, most exceptions thrown by the runtime " + "are instances of this class."); + + // Types + nb::enum_(m, "PrimitiveType", nb::is_arithmetic()) + .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) + .value("PRED", PRED) + .value("S4", S4) + .value("S8", S8) + .value("S16", S16) + .value("S32", S32) + .value("S64", S64) + .value("U4", U4) + .value("U8", U8) + .value("U16", U16) + .value("U32", U32) + .value("U64", U64) + .value("F16", F16) + .value("F4E2M1FN", F4E2M1FN) + .value("F8E3M4", F8E3M4) + .value("F8E4M3", F8E4M3) + .value("F8E4M3FN", F8E4M3FN) + .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) + .value("F8E4M3FNUZ", F8E4M3FNUZ) + .value("F8E5M2", F8E5M2) + .value("F8E5M2FNUZ", F8E5M2FNUZ) + .value("F8E8M0FNU", F8E8M0FNU) + .value("BF16", BF16) + .value("F32", F32) + .value("F64", F64) + .value("C64", C64) + .value("C128", C128) + .value("TUPLE", TUPLE) + .value("OPAQUE_TYPE", OPAQUE_TYPE) + .value("TOKEN", TOKEN); + + // Must be before PyClient.compile. + BuildXlaCompilerSubmodule(m); + + PyDevice::RegisterPythonType(m); + PyMemorySpace::RegisterPythonType(m); + PyClient::RegisterPythonTypes(m); + + nb::enum_(m, "ArrayCopySemantics", + nb::is_arithmetic()) + .value("ALWAYS_COPY", ifrt::ArrayCopySemantics::kAlwaysCopy) + .value("REUSE_INPUT", ifrt::ArrayCopySemantics::kReuseInput) + .value("DONATE_INPUT", ifrt::ArrayCopySemantics::kDonateInput); + + nb::class_(m, "PjRtLayout") + .def("__str__", &PjRtLayout::ToString) + .def("__eq__", [](const PjRtLayout& layout, + const PjRtLayout& other) { return layout == other; }) + .def("__hash__", + [](const PjRtLayout& layout) { return absl::HashOf(layout); }) + .def("_xla_layout", &PjRtLayout::xla_layout) + .def("__getstate__", + [](const PjRtLayout& layout) -> nb::tuple { + absl::StatusOr serialized = layout.Serialize(); + ThrowIfError(serialized.status()); + return nb::make_tuple( + nb::bytes(serialized->data(), serialized->size())); + }) + .def("__setstate__", [](PjRtLayout* self, nb::tuple t) { + nb::bytes serialized = nb::cast(t[0]); + absl::StatusOr> layout = + PjRtLayout::Deserialize( + absl::string_view(serialized.c_str(), serialized.size())); + ThrowIfError(layout.status()); + new (self) PjRtLayout((*layout)->xla_layout()); + }); + + nb::class_ cpu_collectives(m, "CpuCollectives"); + + m.def( + "make_gloo_tcp_collectives", + [](std::shared_ptr distributed_client, + + std::optional hostname, + std::optional interface) + -> std::shared_ptr { +#if defined(__linux__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto tcp_attrs = gloo::transport::tcp::attr(); + if (hostname) { + tcp_attrs.hostname = *hostname; + } + if (interface) { + tcp_attrs.iface = *interface; + } + auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(tcp_device)); +#elif defined(__APPLE__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto uv_attrs = gloo::transport::uv::attr(); + if (hostname) { + uv_attrs.hostname = *hostname; + } + if (interface) { + uv_attrs.iface = *interface; + } + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + return 
std::make_shared(std::move(gloo_kv_store), + std::move(uv_device)); +#else // defined(__linux__) + throw xla::XlaRuntimeError( + "make_gloo_tcp_collectives only implemented for linux and macos"); +#endif // defined(__linux__) + }, + nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, + nb::arg("interface").none() = std::nullopt); + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) + nb::class_ mpi_collectives(m, "MpiCollectives", + cpu_collectives); + mpi_collectives.def("Init", &cpu::MpiCollectives::Init); + mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize); + m.def("make_mpi_collectives", []() -> std::shared_ptr { + return std::make_shared(); + }); +#else // !_WIN32 && !PLATFORM_GOOGLE + m.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); +#endif // !_WIN32 && !PLATFORM_GOOGLE + + m.def( + "get_tfrt_cpu_client", + [](bool asynchronous, + std::shared_ptr distributed_client, + int node_id, int num_nodes, + std::shared_ptr collectives, + std::optional num_devices) -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + xla::CpuClientOptions options; + + options.asynchronous = asynchronous; + options.collectives = std::move(collectives); + options.process_id = node_id; + options.cpu_device_count = num_devices; + std::unique_ptr client = + xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options))); + ifrt::PjRtClient::CreateOptions ifrt_options; + ifrt_options.pjrt_client = + std::shared_ptr(std::move(client)); + if (distributed_client != nullptr) { + ifrt_options.kv_store = + GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + ifrt_options.process_id = node_id; + ifrt_options.num_processes = num_nodes; + } + ifrt_client = + ValueOrThrow(ifrt::PjRtClient::Create(std::move(ifrt_options))); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr, + nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, + nb::arg("collectives").none() = + std::shared_ptr(), + nb::arg("num_devices").none() = std::nullopt); + m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { + absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); + return pjrt_api.ok(); + }); + m.def( + "load_pjrt_plugin", + [](std::string platform_name, std::optional library_path, + std::optional c_api) -> nb::capsule { + if (library_path.has_value()) { + const PJRT_Api* api = xla::ValueOrThrow( + pjrt::LoadPjrtPlugin(platform_name, *library_path)); + return nb::capsule(absl::bit_cast(api), "pjrt_c_api"); + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw nb::value_error( + "c_api argument to load_pjrt_plugin is not a pjrt_c_api " + "capsule."); + } + xla::ThrowIfError(pjrt::SetPjrtApi( + platform_name, static_cast(c_api->data()))); + return *c_api; + }, + nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt, + nb::arg("c_api").none() = std::nullopt); + m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { + return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name)); + }); + m.def("initialize_pjrt_plugin", [](std::string platform_name) { + return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name)); + }); + + m.def( + "get_c_api_client", + [](std::string platform_name, + const absl::flat_hash_map& options, + std::shared_ptr distributed_client) + -> nb_class_ptr { + std::unique_ptr 
ifrt_client; + { + nb::gil_scoped_release gil_release; + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore( + distributed_client, + /*key_prefix=*/absl::StrCat(platform_name, ":")); + } + std::unique_ptr c_api_client = xla::ValueOrThrow( + GetCApiClient(platform_name, options, kv_store)); + ifrt_client = ifrt::PjRtClient::Create(std::move(c_api_client)); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("platform_name"), + nb::arg("options") = absl::flat_hash_map(), + nb::arg("distributed_client").none() = nullptr); + // TODO(b/322357665): Delete this method after TPU plugin changes to use the + // standard registration. + m.def("get_default_c_api_topology", + [](std::string platform_name, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(platform_name, topology_name, options))); + }); + m.def("get_c_api_topology", + [](nb::capsule c_api, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + if (absl::string_view(c_api.name()) != "pjrt_c_api") { + throw nb::value_error( + "Argument to get_c_api_topology was not a pjrt_c_api capsule."); + } + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(static_cast(c_api.data()), + topology_name, options))); + }); + m.def("get_topology_for_devices", + [](const std::vector>& py_devices) { + if (py_devices.empty()) { + throw nb::value_error( + "get_topology_for_devices requires >= 1 devices."); + } + auto client = py_devices[0]->client(); + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const auto& py_device : py_devices) { + if (py_device->client().get() != client.get()) { + throw nb::value_error( + "devices passed to get_topology_for_devices come from " + "different clients."); + } + ifrt_devices.push_back(py_device->device()); + } + ifrt::DeviceListRef device_list = + client->ifrt_client()->MakeDeviceList(ifrt_devices); + return xla::ValueOrThrow( + client->ifrt_client()->GetTopologyForDevices(device_list)); + }); + + TF_CHECK_OK(PyArray::RegisterTypes(m)); + jax::PyDeviceList::Register(m); + jax::RegisterSharding(m); + + nb::class_(m, "CompiledMemoryStats") + .def_rw("generated_code_size_in_bytes", + &CompiledMemoryStats::generated_code_size_in_bytes) + .def_rw("argument_size_in_bytes", + &CompiledMemoryStats::argument_size_in_bytes) + .def_rw("output_size_in_bytes", + &CompiledMemoryStats::output_size_in_bytes) + .def_rw("alias_size_in_bytes", &CompiledMemoryStats::alias_size_in_bytes) + .def_rw("temp_size_in_bytes", &CompiledMemoryStats::temp_size_in_bytes) + .def_rw("host_generated_code_size_in_bytes", + &CompiledMemoryStats::host_generated_code_size_in_bytes) + .def_rw("host_argument_size_in_bytes", + &CompiledMemoryStats::host_argument_size_in_bytes) + .def_rw("host_output_size_in_bytes", + &CompiledMemoryStats::host_output_size_in_bytes) + .def_rw("host_alias_size_in_bytes", + &CompiledMemoryStats::host_alias_size_in_bytes) + .def_rw("host_temp_size_in_bytes", + &CompiledMemoryStats::host_temp_size_in_bytes) + .def_prop_ro("serialized_hlo_proto", + [](const CompiledMemoryStats& cms) -> nb::bytes { + return nb::bytes(cms.serialized_hlo_proto.data(), + cms.serialized_hlo_proto.size()); + }) + .def("__str__", &CompiledMemoryStats::DebugString); + + nb::class_(m, "ExecuteResults") + .def("__len__", [](PyExecuteResults& results) { return results.Size(); }) + 
.def("disassemble_into_single_device_arrays", + &PyExecuteResults::DisassembleIntoSingleDeviceArrays) + .def("disassemble_prefix_into_single_device_arrays", + &PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays) + .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers) + .def("consume_token", &PyExecuteResults::ConsumeToken); + + nb::class_(m, "LoadedExecutable") + .def_prop_ro("client", &PyLoadedExecutable::client) + .def("local_devices", &PyLoadedExecutable::AddressableDevices) + .def("size_of_generated_code_in_bytes", + &PyLoadedExecutable::SizeOfGeneratedCodeInBytes) + .def( + "get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats)) + .def("delete", &PyLoadedExecutable::Delete) + .def("execute_sharded", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded), + nb::arg("arguments"), nb::arg("with_tokens") = false) + .def("hlo_modules", ValueOrThrowWrapper(&PyLoadedExecutable::HloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds)) + .def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings) + .def("get_parameter_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputLayouts)) + .def("get_parameter_shardings", + &PyLoadedExecutable::GetParameterShardings) + .def("keep_alive", &PyLoadedExecutable::KeepAlive) + .def("cost_analysis", + [](const PyLoadedExecutable& self) { + auto map = ValueOrThrow(self.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(map)); + }) + .def_prop_ro("traceback", &PyLoadedExecutable::traceback) + .def_prop_ro("fingerprint", [](PyLoadedExecutable* exec) -> nb::object { + if (exec->fingerprint().has_value()) { + return nb::bytes(exec->fingerprint()->data(), + exec->fingerprint()->size()); + } else { + return nb::none(); + } + }); + nb::class_ token(m, "Token"); + token.def("block_until_ready", + [](PyToken& self) { xla::ThrowIfError(self.Await()); }); + + nb::class_ sharded_token(m, "ShardedToken"); + sharded_token.def("block_until_ready", [](PyShardedToken& self) { + xla::ThrowIfError(self.Await()); + }); + sharded_token.def("get_token", &PyShardedToken::GetPyToken); + + m.def("buffer_to_dlpack_managed_tensor", + xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), + nb::arg("buffer"), nb::arg("stream").none() = nb::none()); + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, nb_class_ptr device, + std::optional stream) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, device->device(), device->client(), stream)); + }, + nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none()); + // Legacy overload + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, std::move(cpu_client), std::move(gpu_client))); + }, + nb::arg("dlpack"), nb::arg("cpu_backend").none() = nb::none(), + nb::arg("gpu_backend").none() = nb::none()); + m.def("cuda_array_interface_to_buffer", + xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), + nb::arg("gpu_backend").none() = nb::none(), + nb::arg("device_id").none() = nb::none()); + + jax::BuildConfigSubmodule(m); + BuildIfrtProgramsSubmodule(m); + BuildProfilerSubmodule(m); + BuildOpsSubmodule(m); + BuildPytreeSubmodule(m); + jax::BuildGuardSubmodule(m); + 
jax::BuildJaxjitSubmodule(m); + jax::BuildPmapSubmodule(m); + jax::BuildPjitSubmodule(m); + BuildTracebackSubmodule(m); + BuildMlirSubmodule(m); + BuildSdySubmodule(m); + BuildCustomCallShardingPybindAPI(m); +#if defined(__linux__) + aux::RegisterTransferServerTypes(m); +#endif // defined(__linux__) + + // The following uses python bindings for PyClient defined above using + // pybind11, and hence needs pybind11::module_ (not just nanobind::module_). + xla::ifrt::proxy::BuildIfrtProxySubmodule(m); + + nb::class_ preemption_sync_manager( + m, "PreemptionSyncManager"); + preemption_sync_manager + .def( + "initialize", + [](tsl::PreemptionSyncManager& manager, + DistributedRuntimeClient* client) { + tsl::CoordinationServiceAgent* agent = + xla::ValueOrThrow(client->GetCoordinationServiceAgent()); + xla::ThrowIfError(manager.Initialize(agent)); + }, + nb::arg("distributed_client")) + .def("reached_sync_point", + [](tsl::PreemptionSyncManager& manager, int step_counter) { + return manager.ReachedSyncPoint(step_counter); + }); + m.def("create_preemption_sync_manager", + []() { return tsl::CreatePreemptionSyncManager(); }); + + nb::class_ distributed_runtime_service( + m, "DistributedRuntimeService"); + distributed_runtime_service.def("shutdown", + &DistributedRuntimeService::Shutdown, + nb::call_guard()); + nb::class_ distributed_runtime_client( + m, "DistributedRuntimeClient"); + distributed_runtime_client + .def("connect", + [](DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Connect()); + }) + .def("shutdown", + [](DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Shutdown()); + }) + // This method assumes that the value is a Python string. Use + // `blocking_key_value_get_bytes()` if key_value_set() was called with a + // Python bytes object as its value. + .def( + "blocking_key_value_get", + [](DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + // Same as `blocking_key_value_get()`, but retrieves the raw Python byte + // values explicitly. 
+ .def( + "blocking_key_value_get_bytes", + [](DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + .def( + "key_value_try_get", + [](DistributedRuntimeClient& client, std::string key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueTryGet(key)); + }, + nb::arg("key")) + .def( + "key_value_try_get_bytes", + [](DistributedRuntimeClient& client, std::string key) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueTryGet(key)); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key")) + .def( + "wait_at_barrier", + [](DistributedRuntimeClient& client, std::string barrier_id, + int64_t timeout_in_ms, + std::optional> process_ids) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.WaitAtBarrier( + barrier_id, absl::Milliseconds(timeout_in_ms), process_ids)); + }, + nb::arg("barrier_id"), nb::arg("timeout_in_ms"), + nb::arg("process_ids") = std::nullopt) + .def( + "get_live_nodes", + [](DistributedRuntimeClient& client, + std::vector process_ids) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.GetLiveNodes(process_ids)); + }, + nb::arg("process_ids")) + // The key must be a string, but the value can either be a Python string + // or bytes object. + // With Python string values, use `key_value_set()` and + // `blocking_key_value_get()`. + // With Python byte object values, use `key_value_set()` and + // `blocking_key_value_get_bytes()`. + .def( + "key_value_set", + [](DistributedRuntimeClient& client, absl::string_view key, + absl::string_view value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // The key must be a string, but the value must a + // Python bytes object. + // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. + .def( + "key_value_set_bytes", + [](DistributedRuntimeClient& client, absl::string_view key, + nb::bytes value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet( + key, absl::string_view(value.c_str(), value.size()), + allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // Assumes that all values in the directory are Python strings. + .def( + "key_value_dir_get", + [](DistributedRuntimeClient& client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueDirGet(key)); + }, + nb::arg("key")) + // Assumes that all values in the directory are Python byte objects. + // Same as `key_value_dir_get()`, but retrieves Python byte values + // explicitly. + .def( + "key_value_dir_get_bytes", + [](DistributedRuntimeClient& client, absl::string_view key) + -> std::vector> { + std::vector> result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueDirGet(key)); + } + // Convert std::string values to nb::bytes. 
+ std::vector> kvs; + kvs.reserve(result.size()); + for (auto& kv : result) { + kvs.push_back( + std::pair(std::move(kv.first), + nb::bytes(kv.second.data(), kv.second.size()))); + } + return kvs; + }, + nb::arg("key")) + .def( + "key_value_delete", + [](DistributedRuntimeClient& client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ThrowIfError(client.KeyValueDelete(key)); + }, + nb::arg("key")); + + m.def( + "get_distributed_runtime_service", + [](std::string address, int num_nodes, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional cluster_register_timeout, + std::optional shutdown_timeout) + -> std::unique_ptr { + CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (cluster_register_timeout.has_value()) { + options.cluster_register_timeout = + absl::Seconds(*cluster_register_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + std::unique_ptr service = + xla::ValueOrThrow(GetDistributedRuntimeService(address, options)); + return service; + }, + nb::arg("address"), nb::arg("num_nodes"), + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("cluster_register_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt); + + m.def( + "get_distributed_runtime_client", + [](std::string address, int node_id, std::optional rpc_timeout, + std::optional init_timeout, std::optional shutdown_timeout, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional> + missed_heartbeat_callback, + std::optional shutdown_on_destruction, + std::optional use_compression) + -> std::shared_ptr { + bool compression = use_compression.value_or(false); + DistributedRuntimeClient::Options options; + options.node_id = node_id; + if (rpc_timeout.has_value()) { + options.rpc_timeout = absl::Seconds(*rpc_timeout); + } + if (init_timeout.has_value()) { + options.init_timeout = absl::Seconds(*init_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (missed_heartbeat_callback.has_value()) { + options.missed_heartbeat_callback = + std::move(*missed_heartbeat_callback); + } + if (shutdown_on_destruction.has_value()) { + options.shutdown_on_destruction = *shutdown_on_destruction; + } + return GetDistributedRuntimeClient(address, options, compression); + }, + nb::arg("address"), nb::arg("node_id"), + nb::arg("rpc_timeout").none() = std::nullopt, + nb::arg("init_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt, + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("missed_heartbeat_callback").none() = std::nullopt, + nb::arg("shutdown_on_destruction").none() = std::nullopt, + nb::arg("use_compression").none() = std::nullopt); + + m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); + + 
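[Editor's aside, not part of the patch: a hedged sketch of the distributed-runtime bindings above, as one of two cooperating processes might use them; the address, node_id, and timeout values are made up for illustration.]

from jaxlib import xla_extension as _xla

# Process 0 also hosts the coordination service.
service = _xla.get_distributed_runtime_service('127.0.0.1:12345', num_nodes=2)
client = _xla.get_distributed_runtime_client('127.0.0.1:12345', node_id=0,
                                             init_timeout=60)
client.connect()                            # join the coordination service
client.key_value_set('config/epoch', '7')   # string value; use *_bytes for bytes
print(client.blocking_key_value_get('config/epoch', timeout_in_ms=10_000))
client.wait_at_barrier('ready', timeout_in_ms=10_000)
client.shutdown()
service.shutdown()
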
m.def("is_optimized_build", &IsOptimizedBuild); + + m.def("json_to_pprof_profile", xla::ValueOrThrowWrapper(JsonToPprofProfile), + "Encodes the JSON representation of a pprof Profile into its binary " + "protocol buffer encoding."); + m.def("pprof_profile_to_json", xla::ValueOrThrowWrapper(PprofProfileToJson), + "Decodes an uncompressed pprof Profile protocol buffer into a JSON " + "representation"); + + RegisterCompileOnlyClient(m); + nb::class_(m, "DeviceTopology") + .def("_make_compile_only_devices", + [](std::shared_ptr topology) { + if (!llvm::isa(*topology)) { + throw xla::XlaRuntimeError("Only PjRtTopologies are supported."); + } + return MakeCompileOnlyClient( + std::dynamic_pointer_cast(topology)) + ->Devices(); + }) + .def_prop_ro( + "platform", + [](ifrt::Topology& topology) { return topology.platform_name(); }) + .def_prop_ro( + "platform_version", + [](ifrt::Topology& topology) { return topology.platform_version(); }) + .def("serialize", + [](ifrt::Topology& topology) -> nb::bytes { + std::string serialized = ValueOrThrow(topology.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("__getattr__", + [](ifrt::Topology& topology, absl::string_view name) -> nb::object { + const auto& attrs = topology.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); + + nb::class_(m, "Executable") + .def("hlo_modules", ValueOrThrowWrapper(&ifrt::Executable::GetHloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputMemoryKinds)) + .def("get_output_shardings", &ifrt::Executable::GetOutputShardings) + .def("get_parameter_layouts", + ValueOrThrowWrapper(&ifrt::Executable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputLayouts)) + .def("get_parameter_shardings", &ifrt::Executable::GetParameterShardings) + .def("get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetCompiledMemoryStats)) + .def("serialize", + [](const ifrt::Executable& exec) -> nb::bytes { + std::string serialized = ValueOrThrow(exec.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("cost_analysis", [](const ifrt::Executable& exec) { + auto attrs = ValueOrThrow(exec.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(attrs)); + }); + + m.def("is_asan", IsAsan); + m.def("is_msan", IsMsan); + m.def("is_tsan", IsTsan); + m.def("is_sanitized", IsSanitized); + + m.def( + "batched_device_put", + [](nb::object aval, nb::object sharding, std::vector xs, + std::vector dst_devices, bool committed, + bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) -> nb::object { + return ValueOrThrow(PyArray::BatchedDevicePut( + aval, sharding, std::move(xs), std::move(dst_devices), committed, + force_copy, host_buffer_semantics, jax::GetEnableX64())); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"), + nb::arg("committed") = true, nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + m.def( + "reorder_shards", + [](PyArray x, nb::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + return ValueOrThrow(PyArray::ReorderShards( + std::move(x), std::move(dst_sharding), array_copy_semantics)); + }, + nb::arg("x"), 
nb::arg("dst_sharding"), nb::arg("array_copy_semantics")); + + m.def("batched_block_until_ready", [](std::vector xs) { + ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs))); + }); + + m.def("check_and_canonicalize_memory_kind", + &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), + nb::arg("device_list")); + + m.attr("ifrt_version_number") = JAX_IFRT_VERSION_NUMBER; +} // NOLINT(readability/fn_size) + +} // namespace xla diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py new file mode 100644 index 000000000000..543664682c08 --- /dev/null +++ b/jaxlib/xla/xla_client.py @@ -0,0 +1,986 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An XLA client in Python.""" + +from __future__ import annotations + +import atexit +from collections.abc import Mapping, Sequence +import contextlib +import enum +import gzip +import inspect +import logging +import os +import threading +from typing import Any, Protocol, Union + +import ml_dtypes +import numpy as np + +from jaxlib import xla_extension as _xla + +# Note this module does *not* depend on any Python protocol buffers. The XLA +# Python bindings are currently packaged both as part of jaxlib and as part +# of TensorFlow. If we use protocol buffers here, then importing both jaxlib +# and TensorFlow may fail with duplicate protocol buffer message definitions. + +# Most functions are snake_case for consistency with other modules, some +# method names are CamelCase for consistency with XLA. +# pylint: disable=invalid-name + +# Pylint has false positives for type annotations. +# pylint: disable=invalid-sequence-index + +ifrt_programs = _xla.ifrt_programs +ops = _xla.ops +profiler = _xla.profiler + +# Just an internal arbitrary increasing number to help with backward-compatible +# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. +_version = 330 + +# An internal increasing version number for protecting jaxlib code against +# ifrt changes. +# lives in xla/python/version.h. +# In JAX, reference this via jax._src.lib.ifrt_version. +_ifrt_version = _xla.ifrt_version_number + +# Version number for MLIR:Python components. 
+mlir_api_version = 58 + +xla_platform_names = { + 'cpu': 'Host', + 'gpu': 'CUDA', +} + +logger = logging.getLogger(__name__) + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + + +def make_cpu_client( + asynchronous=True, + distributed_client=None, + node_id=0, + num_nodes=1, + collectives=None, + num_devices=None, +) -> ...: + register_custom_call_handler('cpu', _xla.register_custom_call_target) + register_custom_type_id_handler('cpu', _xla.register_custom_type_id) + return _xla.get_tfrt_cpu_client( + asynchronous=asynchronous, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, + collectives=collectives, + num_devices=num_devices, + ) + + +DeviceTopology = _xla.DeviceTopology +get_topology_for_devices = _xla.get_topology_for_devices + + +def make_tfrt_tpu_c_api_device_topology( + topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_default_c_api_topology('tpu', topology_name, dict(**kwargs)) + + +def make_c_api_device_topology( + c_api: Any, topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_c_api_topology(c_api, topology_name, dict(**kwargs)) + + +def pjrt_plugin_loaded(plugin_name: str) -> bool: + return _xla.pjrt_plugin_loaded(plugin_name) + + +def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: + return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None) + + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + return _xla.load_pjrt_plugin(plugin_name, None, c_api) + + +def pjrt_plugin_initialized(plugin_name: str) -> bool: + return _xla.pjrt_plugin_initialized(plugin_name) + + +def initialize_pjrt_plugin(plugin_name: str) -> None: + """Initializes a PJRT plugin. + + The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or + static linking) before this method is called. + Args: + plugin_name: the name of the PJRT plugin. + """ + _xla.initialize_pjrt_plugin(plugin_name) + + +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: _xla.DistributedRuntimeClient | None = None, +): + """Creates a PJRT C API client for a PJRT plugin. + + It is required that load_pjrt_plugin_dynamically is called once with the same + plugin_name before this method is called. + + Args: + plugin_name: the name of the PJRT plugin. + options: extra platform-specific options. + distributed_client: distributed client. + + Returns: + A PJRT C API client for plugin_name. + """ + if options is None: + options = {} + return _xla.get_c_api_client(plugin_name, options, distributed_client) + + +def make_tpu_client( + library_path: str | None = None, options: _NameValueMapping | None = None +): + """Returns a TPU client. Defaults to allowing 32 in-flight computations.""" + if not pjrt_plugin_loaded('tpu'): + c_api = load_pjrt_plugin_dynamically('tpu', library_path or 'libtpu.so') + profiler.register_plugin_profiler(c_api) + assert pjrt_plugin_loaded('tpu') + if not pjrt_plugin_initialized('tpu'): + initialize_pjrt_plugin('tpu') + if options is None: + options = {} + return _xla.get_c_api_client('tpu', options) + + +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: + """Generates the PjRt GPU plugin options. + + Returns: + A dictionary of plugin options. 
+ """ + + options = {} + options['platform_name'] = 'cuda' + allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() + memory_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '') + deprecated_memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') + if deprecated_memory_fraction: + if memory_fraction: + raise ValueError( + 'XLA_CLIENT_MEM_FRACTION is specified together ' + 'with XLA_PYTHON_CLIENT_MEM_FRACTION. ' + 'Remove the latter one, it is deprecated.' + ) + else: + memory_fraction = deprecated_memory_fraction + preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') + collective_memory_size = os.getenv( + 'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '' + ) + if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): + raise ValueError( + 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' + '"bfc", or "cuda_async", got "%s"' % allocator + ) + options['allocator'] = allocator + if memory_fraction: + options['memory_fraction'] = float(memory_fraction) + if preallocate: + options['preallocate'] = preallocate not in ('false', 'False', '0') + if collective_memory_size: + options['collective_memory_size'] = int(collective_memory_size) * (1 << 20) + return options + + +class OpMetadata: + """Python representation of a xla.OpMetadata protobuf.""" + + __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') + + def __init__(self, op_type='', op_name='', source_file='', source_line=0): + self.op_type = op_type + self.op_name = op_name + self.source_file = source_file + self.source_line = source_line + + +def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): + """Helper for use in source mapping that returns an OpMetadata object.""" + full_filename, lineno = inspect.stack()[skip_frames][1:3] + filename = os.path.basename(full_filename) + return OpMetadata( + op_type=op_type, op_name=op_name, source_file=filename, source_line=lineno + ) + + +PrimitiveType = _xla.PrimitiveType + +XLA_ELEMENT_TYPE_TO_DTYPE = { + PrimitiveType.PRED: np.dtype('bool'), + PrimitiveType.S4: np.dtype(ml_dtypes.int4), + PrimitiveType.S8: np.dtype('int8'), + PrimitiveType.S16: np.dtype('int16'), + PrimitiveType.S32: np.dtype('int32'), + PrimitiveType.S64: np.dtype('int64'), + PrimitiveType.U4: np.dtype(ml_dtypes.uint4), + PrimitiveType.U8: np.dtype('uint8'), + PrimitiveType.U16: np.dtype('uint16'), + PrimitiveType.U32: np.dtype('uint32'), + PrimitiveType.U64: np.dtype('uint64'), + PrimitiveType.F4E2M1FN: np.dtype(ml_dtypes.float4_e2m1fn), + PrimitiveType.F8E3M4: np.dtype(ml_dtypes.float8_e3m4), + PrimitiveType.F8E4M3: np.dtype(ml_dtypes.float8_e4m3), + PrimitiveType.F8E4M3FN: np.dtype(ml_dtypes.float8_e4m3fn), + PrimitiveType.F8E4M3B11FNUZ: np.dtype(ml_dtypes.float8_e4m3b11fnuz), + PrimitiveType.F8E4M3FNUZ: np.dtype(ml_dtypes.float8_e4m3fnuz), + PrimitiveType.F8E5M2: np.dtype(ml_dtypes.float8_e5m2), + PrimitiveType.F8E5M2FNUZ: np.dtype(ml_dtypes.float8_e5m2fnuz), + PrimitiveType.F8E8M0FNU: np.dtype(ml_dtypes.float8_e8m0fnu), + PrimitiveType.BF16: np.dtype(ml_dtypes.bfloat16), + PrimitiveType.F16: np.dtype('float16'), + PrimitiveType.F32: np.dtype('float32'), + PrimitiveType.F64: np.dtype('float64'), + PrimitiveType.C64: np.dtype('complex64'), + PrimitiveType.C128: np.dtype('complex128'), + PrimitiveType.TUPLE: np.dtype(np.object_), + PrimitiveType.TOKEN: np.dtype(np.object_), +} + +# Note the conversion on the key. Numpy has a known issue wherein dtype hashing +# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). 
Thus, +# when keying by dtype in this dict, we use the string form of dtypes. +DTYPE_TO_XLA_ELEMENT_TYPE = { + str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items() +} + + +def dtype_to_etype(dtype): + """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" + return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + + +Shape = _xla.Shape +Shape.__doc__ = """ +A Shape is an object defined in C++ that duck types like the following class: + +class Shape: + '''Represents an XLA shape. + + A shape is either an array shape, having rank-many integer + dimensions and an element type (represented by a Numpy dtype), or it + is a tuple shape, having a shape for every tuple component: + + type shape = + TupleShape of shape list + | ArrayShape of { dimensions: int list; element_type: dtype } + ''' + + @staticmethod + def tuple_shape(tuple_shapes) -> Shape: + "Construct a tuple shape." + + @staticmethod + def array_shape(element_type, dimensions, minor_to_major=None) -> Shape: + + @staticmethod + def from_pyval(pyval) -> Shape: + "Returns a Shape that describes a tuple-tree of Numpy arrays." + + def __init__(self, str) -> Shape: + "Parses a shape string." + def __eq__(self, other: Shape) -> bool: + def __ne__(self, other: Shape) -> bool: + def __hash__(self): + def __repr__(self): + def is_tuple(self) -> bool: + def is_array(self) -> bool: + def tuple_shapes(self) -> [Shape]: + def numpy_dtype(self) -> np.dtype: + "Like element_type(), but returns dtype('O') for a tuple shape." + def xla_element_type(self) -> PrimitiveType: + def element_type(self) -> np.dtype: + def dimensions(self) -> (int, int, ...): + def rank(self) -> int: + def with_major_to_minor_layout_if_absent(self) -> Shape: + "Returns a copy with missing layouts set to major-to-minor." + + def to_serialized_proto(self) -> bytes: + "Returns 'shape' as a serialized proto." +""" + +ProgramShape = _xla.ProgramShape +ProgramShape.__doc__ = """ +A ProgramShape is a C++ object that duck types like the following class. + +class ProgramShape: + def __init__(self, parameter_shapes, result_shape): + def parameter_shapes(self) -> [Shape]: + def result_shape(self) -> Shape: + def __repr__(self): +""" + +ShapeIndex = _xla.ShapeIndex +ShapeIndex.__doc__ = """ +A Shape is an object defined in C++ that duck types like the following class: + +class ShapeIndex: + '''Represents an XLA ShapeIndex. + + An index for specifying a particular nested subshape within a shape. Used in + ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through + the Shape tree where each element of ShapeIndex indexes into a tuple (or + nested tuple) within the shape. For a non-nested tuple, an index has a single + element. + ''' + + def __init__(self, List[int]) -> ShapeIndex: + def __eq__(self, other: Shape) -> bool: + def __ne__(self, other: Shape) -> bool: + def __hash__(self): + def __repr__(self): +""" + + +def shape_from_pyval(pyval, layout: Sequence[int] | None = None): + """Returns a Shape that describes a tuple-tree of Numpy arrays.""" + + def convert(pyval): + if isinstance(pyval, tuple): + if layout is not None: + raise NotImplementedError( + 'shape_from_pyval does not support layouts for tuple shapes' + ) + return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) + else: + return Shape.array_shape(pyval.dtype, np.shape(pyval), layout) + + return convert(pyval) + + +DeviceAssignment = _xla.DeviceAssignment +DeviceAssignment.__doc__ = """ +A DeviceAssignment is a C++ object with the following signature. 
+ +def create(assignment): + '''Builds a device assignment. + + Args: + assignment: a 2D numpy array of device ordinal integers, indexed by + [replica][computation_in_replica]. + Returns: + A device assignment. + ''' + +def replica_count(): + '''Returns the number of replicas.''' +def computation_count(): + '''Returns the number of computations per replica.''' +""" + +Device = _xla.Device +CompileOptions = _xla.CompileOptions + +HostBufferSemantics = _xla.HostBufferSemantics + +# An Executable is a C++ class that duck types with the following API: +# class Executable: +# def local_devices(self) -> [Device]: +# def execute(self, arguments : [Buffer]) -> Buffer: +# """Execute on one replica with Buffer arguments and return value.""" +# +# def size_of_generated_code_in_bytes(self) -> int: +# """Return generated binary size, or -1 if not known.""" +# +# def execute_sharded_on_local_devices(self, arguments: [[Buffer]]) +# -> [Buffer]: +# """Execute on many replicas with Buffer arguments and return value. +# +# Args: +# arguments: A sequence of sequences of Buffers. The i'th element of each +# sequence comprises the arguments for execution on the i'th local +# device. +# +# Returns: +# A list of the computation's outputs as a list of Buffers for each +# device. +# """ +# +# There are different implementations of Executable for different backends. + + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + + +def window_padding_type_to_pad_values( + padding_type, lhs_dims, rhs_dims, window_strides +): + """Maps PaddingType or string to pad values (list of pairs of ints).""" + if not isinstance(padding_type, (str, PaddingType)): + msg = 'padding_type must be str or PaddingType, got {}.' + raise TypeError(msg.format(type(padding_type))) + + if isinstance(padding_type, str): + if padding_type.upper() == 'VALID': + padding_type = PaddingType.VALID + elif padding_type.upper() == 'SAME': + padding_type = PaddingType.SAME + else: + msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.' 
+ raise ValueError(msg.format(padding_type)) + + if padding_type == PaddingType.VALID: + return [(0, 0)] * len(window_strides) + elif padding_type == PaddingType.SAME: + out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) + pad_sizes = [ + max((out_size - 1) * stride + filter_size - in_size, 0) + for out_size, stride, filter_size, in_size in zip( + out_shape, window_strides, rhs_dims, lhs_dims + ) + ] + return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] + else: + msg = 'Unexpected PaddingType value: {}' + raise ValueError(msg.format(padding_type)) + + +XlaBuilder = _xla.XlaBuilder +XlaComputation = _xla.XlaComputation +XlaOp = _xla.XlaOp +FftType = _xla.FftType +Client = _xla.Client +Memory = _xla.Memory +Array = _xla.Array +ArrayImpl = _xla.ArrayImpl +LoadedExecutable = _xla.LoadedExecutable +DeviceList = _xla.DeviceList +OpSharding = _xla.OpSharding +HloSharding = _xla.HloSharding +Sharding = _xla.Sharding +NamedSharding = _xla.NamedSharding +SingleDeviceSharding = _xla.SingleDeviceSharding +PmapSharding = _xla.PmapSharding +GSPMDSharding = _xla.GSPMDSharding +PjRtLayout = _xla.PjRtLayout +AutotuneCacheMode = _xla.AutotuneCacheMode +ResultAccuracyMode = _xla.ResultAccuracy_Mode + + +def LoadedExecutable_execute(self, arguments, device=None): + del device + results = self.execute_sharded(arguments) + return [x[0] for x in results.disassemble_into_single_device_arrays()] + + +def LoadedExecutable_execute_with_token(self, arguments, device=None): + del device + results = self.execute_sharded(arguments, with_tokens=True) + return ( + [x[0] for x in results.disassemble_into_single_device_arrays()], + results.consume_token().get_token(0), + ) + + +LoadedExecutable.execute = LoadedExecutable_execute +LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token + + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + # Calls to custom call are safe to trace into the command buffer. It means + # that calls to custom call always launch exactly the same device operations + # (can depend on attribute values) that can be captured and then replayed. + # + # Supported only for custom calls implemented with XLA FFI. + COMMAND_BUFFER_COMPATIBLE = 1 + + +class CustomCallHandler(Protocol): + + def __call__( + self, + name: str, + fn: Any, + platform: str, + /, + api_version: int = ..., + traits: CustomCallTargetTraits = ..., + ) -> None: + ... + + +_custom_callback_handler: dict[str, CustomCallHandler] = {} +# Key is xla_platform_name, value is (function_name, function, api_version) +_custom_callback: dict[ + str, list[tuple[str, Any, int, CustomCallTargetTraits]] +] = {} +_custom_callback_lock = threading.Lock() + + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = 'cpu', + api_version: int = 0, + traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT, +) -> None: + """Registers a custom call target. + + Args: + name: bytes containing the name of the function. + fn: a PyCapsule object containing the function pointer. + platform: the target platform. + api_version: the XLA FFI version to use. Supported versions are: 0 for the + untyped FFI and 1 for the typed FFI. + traits: custom call traits corresponding to XLA FFI handler traits. + """ + # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" + # Since that is hardcoded to CUDA, we are using the following as workaround. 
+ xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_callback_handler: + _custom_callback_handler[xla_platform_name]( + name, fn, xla_platform_name, api_version, traits + ) + else: + _custom_callback.setdefault(xla_platform_name, []).append( + (name, fn, api_version, traits) + ) + + +def register_custom_call_handler( + platform: str, handler: CustomCallHandler +) -> None: + """Registers a custom handler and use it to register existing custom calls. + + If a custom call handler for the platform already exist, calling this method + is a no-op and it will not register a new handler. + + Args: + platform: the target platform. + handler: the function to register a custom call. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_callback_handler: + logger.debug( + 'Custom call handler for %s is already register. Will not register a' + ' new one', + xla_platform_name, + ) + return + _custom_callback_handler[xla_platform_name] = handler + if xla_platform_name in _custom_callback: + for name, fn, api_version, traits in _custom_callback[xla_platform_name]: + handler(name, fn, xla_platform_name, api_version, traits) + del _custom_callback[xla_platform_name] + + +class CustomTypeIdHandler(Protocol): + + def __call__(self, name: str, capsule: Any) -> None: + ... + + +_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {} +_custom_type_id: dict[str, Any] = {} +_custom_type_id_lock = threading.Lock() + + +def register_custom_type_id( + type_name: str, + type_id: Any, + platform: str = 'cpu', +) -> None: + """Register a custom type id for use with the FFI. + + Args: + type_name: a unique name for the type. + type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``. + platform: the target platform. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_type_id_lock: + if xla_platform_name in _custom_type_id_handler: + _custom_type_id_handler[xla_platform_name](type_name, type_id) + else: + _custom_type_id.setdefault(xla_platform_name, []).append( + (type_name, type_id) + ) + + +def register_custom_type_id_handler( + platform: str, handler: CustomTypeIdHandler +) -> None: + """Register a custom type id handler and use it to register existing type ids. + + If a custom type id handler for the platform already exist, calling this + method is a no-op and it will not register a new handler. + + Args: + platform: the target platform. + handler: the function to register a custom type id. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_type_id_handler: + logger.debug( + 'Custom type id handler for %s is already register. 
Will not ' + 'register a new one', + xla_platform_name, + ) + return + _custom_type_id_handler[xla_platform_name] = handler + if xla_platform_name in _custom_type_id: + for name, capsule in _custom_type_id[xla_platform_name]: + handler(name, capsule) + del _custom_type_id[xla_platform_name] + + +register_custom_call_partitioner = _xla.register_custom_call_partitioner +encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback +hlo_sharding_util = _xla.hlo_sharding_util +register_custom_call_as_batch_partitionable = ( + _xla.register_custom_call_as_batch_partitionable +) + + +class PaddingConfigDimension: + """Python representation of a xla.PaddingConfigDimension protobuf.""" + + __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') + + edge_padding_low: int + edge_padding_high: int + interior_padding: int + + def __init__(self): + self.edge_padding_low = 0 + self.edge_padding_high = 0 + self.interior_padding = 0 + + +class PaddingConfig: + """Python representation of a xla.PaddingConfig protobuf.""" + + __slots__ = ('dimensions',) + + def __init__(self): + self.dimensions = [] + + +def make_padding_config( + padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]] +) -> PaddingConfig: + """Create PaddingConfig proto from list of triples of integers. + + Args: + padding_config: either a PaddingConfig or a list of integer triples + (edge_padding_low, edge_padding_high, interior_padding) representing the + configuration of the padding operation. + + Returns: + A `PaddingConfig` object. + """ + if not isinstance(padding_config, PaddingConfig): + triples = padding_config + padding_config = PaddingConfig() + for lo, hi, interior in triples: + dimension = PaddingConfigDimension() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + padding_config.dimensions.append(dimension) + return padding_config + + +class DotDimensionNumbers: + """Python representation of a xla.DotDimensionNumbers protobuf.""" + + __slots__ = ( + 'lhs_contracting_dimensions', + 'rhs_contracting_dimensions', + 'lhs_batch_dimensions', + 'rhs_batch_dimensions', + ) + + def __init__(self): + self.lhs_contracting_dimensions = [] + self.rhs_contracting_dimensions = [] + self.lhs_batch_dimensions = [] + self.rhs_batch_dimensions = [] + + +def make_dot_dimension_numbers( + dimension_numbers: Union[ + DotDimensionNumbers, + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], + ] +) -> DotDimensionNumbers: + """Builds a DotDimensionNumbers object from a specification. + + Args: + dimension_numbers: either a `DotDimensionNumbers` or a nested tuple + `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of + integers representing the dimensions to treat as contracting dimensions + and batch dimensions on each input operand. + + Returns: + A `DotDimensionNumbers` object. 
+ """ + if isinstance(dimension_numbers, (list, tuple)): + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + dot_dims_proto = DotDimensionNumbers() + dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) + dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) + dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) + dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) + return dot_dims_proto + else: + return dimension_numbers + + +class ConvolutionDimensionNumbers: + """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" + + __slots__ = ( + 'input_batch_dimension', + 'input_feature_dimension', + 'input_spatial_dimensions', + 'kernel_input_feature_dimension', + 'kernel_output_feature_dimension', + 'kernel_spatial_dimensions', + 'output_batch_dimension', + 'output_feature_dimension', + 'output_spatial_dimensions', + ) + + def __init__(self): + self.input_batch_dimension = 0 + self.input_feature_dimension = 0 + self.input_spatial_dimensions = [] + self.kernel_input_feature_dimension = 0 + self.kernel_output_feature_dimension = 0 + self.kernel_spatial_dimensions = [] + self.output_batch_dimension = 0 + self.output_feature_dimension = 0 + self.output_spatial_dimensions = [] + + +def make_convolution_dimension_numbers( + dimension_numbers: Union[ + None, ConvolutionDimensionNumbers, tuple[str, str, str] + ], + num_spatial_dimensions: int, +) -> ConvolutionDimensionNumbers: + """Builds a ConvolutionDimensionNumbers object from a specification. + + Args: + dimension_numbers: optional, either a ConvolutionDimensionNumbers object or + a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of length + N+2 identifying by position: (1) batch dimensions in lhs, rhs, and the + output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions in + rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers consistent + with the Conv operation with two spatial dimensions, one could use + ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension + numbers consistent with the TensorFlow Conv2D operation, one could use + ('NHWC', 'HWIO', 'NHWC'). When using the latter form of convolution + dimension specification, window strides are associated with spatial + dimension character labels according to the order in which the labels + appear in the rhs_spec string, so that window_strides[0] is matched with + the dimension corresponding to the first character appearing in rhs_spec + that is not 'I' or 'O'. By default, use the same dimension numbering as + Conv and ConvWithGeneralPadding. + num_spatial_dimensions: the number of spatial dimensions. + + Returns: + A `ConvolutionDimensionNumbers` object. 
+ """ + if dimension_numbers is None: + nd = num_spatial_dimensions + dimension_numbers = ConvolutionDimensionNumbers() + dimension_numbers.input_batch_dimension = 0 + dimension_numbers.input_feature_dimension = 1 + dimension_numbers.output_batch_dimension = 0 + dimension_numbers.output_feature_dimension = 1 + dimension_numbers.kernel_output_feature_dimension = 0 + dimension_numbers.kernel_input_feature_dimension = 1 + dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) + elif isinstance(dimension_numbers, tuple): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'} + ) + dimension_numbers.input_spatial_dimensions.extend( + sorted( + (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]), + ) + ) + dimension_numbers.output_spatial_dimensions.extend( + sorted( + (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]), + ) + ) + return dimension_numbers + + +class PrecisionConfig: + """Python representation of a xla.PrecisionConfig protobuf.""" + + __slots__ = ('operand_precision',) + + Precision = _xla.PrecisionConfig_Precision + + def __init__(self): + self.operand_precision = [] + + +class ResultAccuracy: + """Python representation of a xla.ResultAccuracy protobuf.""" + + __slots__ = ('mode', 'atol', 'rtol', 'ulps') + + def __init__(self): + self.mode = _xla.ResultAccuracy_Mode.DEFAULT + self.atol = 0.0 + self.rtol = 0.0 + self.ulps = 0 + + +class GatherDimensionNumbers: + """Python representation of a xla.GatherDimensionNumbers protobuf.""" + + __slots__ = ( + 'offset_dims', + 'collapsed_slice_dims', + 'start_index_map', + 'index_vector_dim', + ) + + def __init__(self): + self.offset_dims = [] + self.collapsed_slice_dims = [] + self.start_index_map = [] + self.index_vector_dim = 0 + + +class ScatterDimensionNumbers: + """Python representation of a xla.ScatterDimensionNumbers protobuf.""" + + __slots__ = ( + 'update_window_dims', + 'inserted_window_dims', + 'scatter_dims_to_operand_dims', + 'index_vector_dim', + ) + + def __init__(self): + self.update_window_dims = [] + self.inserted_window_dims = [] + self.scatter_dims_to_operand_dims = [] + self.index_vector_dim = 0 + + +class ReplicaGroup: + """Python representation of a xla.ReplicaGroup protobuf.""" + + __slots__ = ('replica_ids',) + + def __init__(self): + self.replica_ids = [] + + +def _make_replica_group_proto(replica_group): + replica_group_proto = ReplicaGroup() + replica_group_proto.replica_ids.extend(replica_group) + return replica_group_proto + + +def make_replica_groups(replica_groups): + if replica_groups is None: + replica_groups_protos = [] # special value for XLA API + else: + replica_groups = list(replica_groups) + replica_groups_protos = [ + _make_replica_group_proto(group) for group in replica_groups + ] 
+ return replica_groups_protos + + +Traceback = _xla.Traceback +Frame = _xla.Frame + + +@contextlib.contextmanager +def tracebacks(enabled=True): + """Context manager that enables or disables traceback collection.""" + saved = Traceback.enabled + Traceback.enabled = enabled + try: + yield + finally: + Traceback.enabled = saved + + +def heap_profile(client: Client) -> bytes: + """Returns a gzipped pprof protocol buffer containing a heap profile.""" + return gzip.compress(client.heap_profile()) + + +XlaRuntimeError = _xla.XlaRuntimeError + +# Perform one last garbage collection of deferred Python references. This is +# mostly to keep ASAN happy. +atexit.register(_xla.collect_garbage) + +array_result_handler = _xla.array_result_handler +batched_copy_array_to_devices_with_sharding = ( + _xla.batched_copy_array_to_devices_with_sharding +) +batched_device_put = _xla.batched_device_put +reorder_shards = _xla.reorder_shards +batched_block_until_ready = _xla.batched_block_until_ready +check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind +Layout = _xla.Layout +custom_call_targets = _xla.custom_call_targets +ArrayCopySemantics = _xla.ArrayCopySemantics diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi new file mode 100644 index 000000000000..f45b27c461e8 --- /dev/null +++ b/jaxlib/xla/xla_client.pyi @@ -0,0 +1,310 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import annotations + +from collections.abc import Callable, Mapping, Sequence +import enum +from typing import Any, Union + +import numpy + +from jaxlib import xla_extension as _xla +from jaxlib.xla_extension import ArrayImpl as ArrayImpl +from jaxlib.xla_extension import AutotuneCacheMode as AutotuneCacheMode +from jaxlib.xla_extension import Client as Client +from jaxlib.xla_extension import CompileOptions as CompileOptions +from jaxlib.xla_extension import Device as Device +from jaxlib.xla_extension import DeviceAssignment as DeviceAssignment +from jaxlib.xla_extension import DeviceList as DeviceList +from jaxlib.xla_extension import DeviceTopology as DeviceTopology +from jaxlib.xla_extension import DistributedRuntimeClient as DistributedRuntimeClient +from jaxlib.xla_extension import FftType as FftType +from jaxlib.xla_extension import Frame as Frame +from jaxlib.xla_extension import GSPMDSharding as GSPMDSharding +from jaxlib.xla_extension import HloSharding as HloSharding +from jaxlib.xla_extension import HostBufferSemantics as HostBufferSemantics +from jaxlib.xla_extension import ifrt_programs as ifrt_programs +from jaxlib.xla_extension import Layout as Layout +from jaxlib.xla_extension import LoadedExecutable as LoadedExecutable +from jaxlib.xla_extension import Memory as Memory +from jaxlib.xla_extension import NamedSharding as NamedSharding +from jaxlib.xla_extension import ops as ops +from jaxlib.xla_extension import OpSharding as OpSharding +from jaxlib.xla_extension import PjRtLayout as PjRtLayout +from jaxlib.xla_extension import PmapSharding as PmapSharding +from jaxlib.xla_extension import PrimitiveType as PrimitiveType +from jaxlib.xla_extension import ArrayCopySemantics as ArrayCopySemantics +from jaxlib.xla_extension import profiler as profiler +from jaxlib.xla_extension import Shape as Shape +from jaxlib.xla_extension import Sharding as Sharding +from jaxlib.xla_extension import SingleDeviceSharding as SingleDeviceSharding +from jaxlib.xla_extension import Traceback as Traceback +from jaxlib.xla_extension import XlaBuilder as XlaBuilder +from jaxlib.xla_extension import XlaComputation as XlaComputation +from jaxlib.xla_extension import XlaOp as XlaOp + +_version: int + +_ifrt_version: int + +mlir_api_version: int + +XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype] + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + +def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType: + ... + +def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ... + +def heap_profile(client: Client) -> bytes: + ... + +XlaRuntimeError = _xla.XlaRuntimeError + +def make_cpu_client( + asynchronous: bool = ..., + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: _xla.CpuCollectives | None = ..., + num_devices: int | None = ..., +) -> Client: + ... + +def make_gpu_client( + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + platform_name: str | None = ..., + allowed_devices: set[int] | None = ..., + mock: bool | None = ..., + mock_gpu_topology: str | None = ..., +) -> Client: + ... + +def make_tfrt_tpu_c_api_device_topology( + topology_name: str | None = None, **kwargs +) -> DeviceTopology: + ... + +def make_c_api_device_topology(c_api: Any, topology_name: str = '', **kwargs) -> DeviceTopology: + ... 
+ +def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: + ... + +def make_tpu_client( + library_path: str | None, options: _NameValueMapping | None = None +) -> Client: + ... + +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: DistributedRuntimeClient | None = None, +) -> Client: + ... + +def pjrt_plugin_loaded(plugin_name: str) -> bool: + ... + +def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: + ... + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + ... + +def pjrt_plugin_initialized(plugin_name: str) -> bool: + ... + +def initialize_pjrt_plugin(plugin_name: str) -> None: + ... + +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: + ... + +class OpMetadata: + + def __init__( + self, + op_type: str | None = ..., + op_name: str | None = ..., + source_file: str | None = ..., + source_line: int | None = ..., + ): + ... + op_type: str | None + op_name: str | None + source_file: str | None + source_line: int | None + +class PaddingConfigDimension: + edge_padding_low: int + edge_padding_high: int + interior_padding: int + +class PaddingConfig: + dimensions: list[PaddingConfigDimension] + +def make_padding_config( + padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]], +) -> PaddingConfig: + ... + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + +class DotDimensionNumbers: + lhs_contracting_dimensions: list[int] + rhs_contracting_dimensions: list[int] + lhs_batch_dimensions: list[int] + rhs_batch_dimensions: list[int] + +def make_dot_dimension_numbers( + dimension_numbers: Union[ + DotDimensionNumbers, + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], + ], +) -> DotDimensionNumbers: + ... + +class ConvolutionDimensionNumbers: + input_batch_dimension: int + input_feature_dimension: int + input_spatial_dimensions: list[int] + kernel_input_feature_dimension: int + kernel_output_feature_dimension: int + kernel_spatial_dimensions: list[int] + output_batch_dimension: int + output_feature_dimension: int + output_spatial_dimensions: list[int] + +def make_convolution_dimension_numbers( + dimension_numbers: Union[ + None, ConvolutionDimensionNumbers, tuple[str, str, str] + ], + num_spatial_dimensions: int, +) -> ConvolutionDimensionNumbers: + ... + +class PrecisionConfig: + Precision = _xla.PrecisionConfig_Precision + operand_precision: list[_xla.PrecisionConfig_Precision] + +class ResultAccuracy: + mode: _xla.ResultAccuracy_Mode + atol: float + rtol: float + ulps: int + +class GatherDimensionNumbers: + offset_dims: list[int] + collapsed_slice_dims: list[int] + start_index_map: list[int] + index_vector_dim: int + operand_batching_dims: list[int] + start_indices_batching_dims: list[int] + +class ScatterDimensionNumbers: + update_window_dims: list[int] + inserted_window_dims: list[int] + scatter_dims_to_operand_dims: list[int] + index_vector_dim: int + input_batching_dims: list[int] + scatter_indices_batching_dims: list[int] + +class ReplicaGroup: + replica_ids: list[int] + +def make_replica_groups( + replica_groups: Sequence[Sequence[int]] | None, +) -> list[ReplicaGroup]: + ... + +def weakref_lru_cache(cache_context_fn: Callable, call: Callable, maxsize=...) -> _xla.WeakrefLRUCache: + ... + +def batched_copy_array_to_devices_with_sharding( + arrays: Sequence[ArrayImpl], + devices: Sequence[list[Device]], + sharding: Sequence[Any], + array_copy_semantics: Sequence[ArrayCopySemantics], +) -> Sequence[ArrayImpl]: ... 
+ +def batched_device_put( + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: list[Device], + committed: bool = ..., + force_copy: bool = ..., + host_buffer_semantics: Any = ..., +) -> ArrayImpl: ... + +def reorder_shards( + x: ArrayImpl, + dst_sharding: Any, + array_copy_semantics: ArrayCopySemantics, +) -> ArrayImpl: ... + +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... + +def check_and_canonicalize_memory_kind( + memory_kind: str | None, device_list: DeviceList +) -> str | None: ... + +def array_result_handler( + aval: Any, + sharding: Any, + committed: bool, + _skip_checks: bool = ...) -> Callable: + ... + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + COMMAND_BUFFER_COMPATIBLE = 1 + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = ..., + api_version: int = ..., + traits: CustomCallTargetTraits = ..., +) -> None: ... + +def register_custom_call_handler( + xla_platform_name: str, handler: Any +) -> None: ... + +def custom_call_targets(platform: str) -> dict[str, Any]: ... + +def register_custom_type_id( + type_name: str, + type_id: Any, + platform: str = ..., +) -> None: ... + +def register_custom_type_id_handler(platform: str, handler: Any) -> None: ... + +def encode_inspect_sharding_callback(handler: Any) -> bytes: ... + +register_custom_call_partitioner = _xla.register_custom_call_partitioner +register_custom_call_as_batch_partitionable = ( + _xla.register_custom_call_as_batch_partitionable +) diff --git a/jaxlib/xla/xla_client_backend_independent_test.py b/jaxlib/xla/xla_client_backend_independent_test.py new file mode 100644 index 000000000000..ee1c33feb40c --- /dev/null +++ b/jaxlib/xla/xla_client_backend_independent_test.py @@ -0,0 +1,195 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Backend-independent tests for the Python XLA client.""" + +import unittest + +from absl.testing import absltest +import numpy as np + +from jax.jaxlib.xla import xla_client + +# pylint: disable=g-import-not-at-top +try: + import portpicker +except ImportError: + portpicker = None +# pylint: enable=g-import-not-at-top + +ops = xla_client.ops + + +class ShapeTest(absltest.TestCase): + + def testInvalidShapes(self): + with self.assertRaisesRegex(xla_client.XlaRuntimeError, "invalid shape"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field contains 1 element.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field has out-of-bounds value.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], + [1, -1]) + + +class ComputationPrinting(absltest.TestCase): + + def ExampleComputation(self): + builder = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + x = ops.Mul(p0, p1) + ops.Add(x, x) + return builder.build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_text() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.as_hlo_dot_graph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + def testHloModuleToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_module().to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testHloModuleFromText(self): + hlo_module_text = """HloModule test + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + ENTRY entry { + p0 = f32[2,3] parameter(0) + start = f32[2,3] all-reduce-start(p0), to_apply=add + ROOT done = f32[2,3] all-reduce-done(start) + }""" + hlo_module = xla_client._xla.hlo_module_from_text(hlo_module_text) + hlo_text = hlo_module.to_string() + self.assertTrue(hlo_text.startswith("HloModule test")) + + def testHloModuleToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( + computation.as_hlo_module()) + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + +class ComputationHashTest(absltest.TestCase): + + def testHash(self): + builder0 = xla_client.XlaBuilder("computation0") + p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder0, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation0 = builder0.build() + + builder1 = xla_client.XlaBuilder("computation1") + p0 = ops.Parameter(builder1, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder1, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation1 = builder1.build() + + self.assertEqual(computation0.hash(), computation1.hash()) + + +class AliasTest(absltest.TestCase): + + def testSetUpAlias(self): + c = xla_client.XlaBuilder(self.id()) + p1 = ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 1.0, 
np.float32)).with_major_to_minor_layout_if_absent()) + p2 = ops.Parameter( + c, 1, + xla_client.shape_from_pyval(np.array( + 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + out = ops.Add(p1, p2) + c.setup_alias([], 0, []) + c.build(out) + + +class ProfilerTest(absltest.TestCase): + + def testTraceMe(self): + # TODO(phawkins): These tests just check that the TraceMe context manager + # acts like a context manager and doesn't explode. Ideally we'd check that + # the profiler saw the traceme too. + with xla_client.profiler.TraceMe("test1"): + pass + with xla_client.profiler.TraceMe("test2", foo=123): + pass + with self.assertRaises(ValueError): + with xla_client.profiler.TraceMe("test3"): + raise ValueError("test") + + @unittest.skipIf(portpicker is None, "Test requires portpicker") + def testStartServer(self): + port = portpicker.pick_unused_port() + server = xla_client.profiler.start_server(port) + del server + + +class HloModuleGroupTest(absltest.TestCase): + + def testHloModuleGroup(self): + builder0 = xla_client.XlaBuilder("computation0") + p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder0, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + root = ops.Mul(p0, p1) + computation0 = builder0.build(root) + + m = computation0.get_hlo_module() + mg_name = "test_module_group" + mg = xla_client._xla.HloModuleGroup(mg_name, [m]) + self.assertEqual(mg.name, mg_name) + + modules = mg.to_modules() + self.assertLen(modules, 1) + self.assertEqual(m.to_string(), modules[0].to_string()) + + +class RunHloPassTest(absltest.TestCase): + + def testHloDCE(self): + b = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(b, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + root = ops.Mul(p0, p1) + + # Dead instructions + p2 = ops.Parameter(b, 2, xla_client.shape_from_pyval(np.float32(0))) + ops.Add(p2, p2) + + hlo_module = b.build(root).get_hlo_module() + self.assertTrue(xla_client._xla.HloDCE().run(hlo_module)) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py new file mode 100644 index 000000000000..15c307145b29 --- /dev/null +++ b/jaxlib/xla/xla_client_test.py @@ -0,0 +1,3739 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Backend-dependent tests for the Python XLA client.""" + +import collections +import functools +import itertools +import re +import threading +import traceback +from typing import Sequence +import unittest + +from absl import flags +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +import ml_dtypes +import numpy as np + +from jax.jaxlib.xla import xla_client +import jax +import jax._src.test_util + +# pylint: disable=g-import-not-at-top +try: + from jax.jaxlib.xla import custom_calls_testlib +except ImportError: + custom_calls_testlib = None + +xla_client._xla.jax_jit.set_thread_local_state_initialization_callback( + lambda: None +) + +bfloat16 = ml_dtypes.bfloat16 +float4_e2m1fn = ml_dtypes.float4_e2m1fn +float8_e3m4 = ml_dtypes.float8_e3m4 +float8_e4m3 = ml_dtypes.float8_e4m3 +float8_e4m3fn = ml_dtypes.float8_e4m3fn +float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz +float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz +float8_e5m2 = ml_dtypes.float8_e5m2 +float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +float8_e8m0fnu = ml_dtypes.float8_e8m0fnu +ops = xla_client.ops + +def xla_computation_to_mlir_module(c: xla_client.XlaComputation) -> bytes: + return xla_client._xla.mlir.hlo_to_stablehlo( + c.as_serialized_hlo_module_proto()) + + +def execute_with_python_values(executable, arguments, backend): # pylint: disable=invalid-name + """Execute on one replica with Python values as arguments and output.""" + + def put(arg): # pylint: disable=invalid-name + return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) + + arguments = [put(arg) for arg in arguments] + outputs = executable.execute(arguments) + return [np.asarray(x) for x in outputs] + + +# pylint: disable=invalid-name +def jax_array_convert_to_array(self, dtype=None, copy=None): + del copy + out, _ = self._single_device_array_to_np_array_did_copy() + if dtype is not None: + out = out.astype(dtype) + return out + + +def jax_array_device(self): + return self._sharding._device + + +def jax_array_copy_to_host_async(self): + self._copy_single_device_array_to_host_async() + + +Array = xla_client.ArrayImpl +Array.__array__ = jax_array_convert_to_array +Array.copy_to_host_async = jax_array_copy_to_host_async +Array.device = jax_array_device +xla_client.SingleDeviceSharding.device_set = property( + lambda self: {self._device} +) +# pylint: enable=invalid-name + + +FLAGS = flags.FLAGS + +# We choose to ignore pylint's complaints about complex comprehensions, which we +# use widely for parameterizing tests. +# pylint: disable=g-complex-comprehension + +_CUSTOM_CALLS_REGISTERED = False + + +# XLA's alignment is 16 bytes at the moment, but it should match what Eigen +# supports, and that can go up to 128 bytes on hardware with HVX. +_XLA_CPU_MAX_ALIGNMENT = 128 + + +# Minimum possible alignment for XLA. +_XLA_CPU_MIN_ALIGNMENT = 16 + + +# Return a copy of `x` with the given alignment. Does nothing if `x` is already +# aligned. We do this manually, because numpy doesn't support custom alignment +# values. +def _Aligned(x, alignment=_XLA_CPU_MAX_ALIGNMENT): + if (x.ctypes.data % alignment) == 0: + return x + + # Create temporary buffer with extra space for alignment. + assert alignment % x.itemsize == 0 + extra = alignment // x.itemsize + buf = np.empty(x.size + extra, dtype=x.dtype) + + # Create a view of the temporary buffer with an offset such that the result + # buffer is aligned.
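+ # As a purely illustrative example with hypothetical addresses: if the buffer + # starts at address 1040 with alignment=128 and x.itemsize=4, then + # (-1040) % 128 == 112 bytes, i.e. an offset of 28 elements, and the view + # starts at byte 1152 == 9 * 128.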
+ offset = (-buf.ctypes.data % alignment) // x.itemsize + result = buf[offset : offset + x.size].reshape(x.shape) + + # Copy the data to the result buffer and return it. + np.copyto(result, x) + return result + + +# Return an unaligned copy of `x`. The result buffer's memory address is +# guaranteed to not be aligned to `alignment`. This function is useful for +# testing failures. +def _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT): + if (x.ctypes.data % alignment) != 0: + return x + + # Create temporary buffer with extra space. + assert (x.itemsize % alignment) != 0 + offset = 1 + buf = np.empty(x.size + offset, dtype=x.dtype) + + if (buf.ctypes.data % alignment) != 0: + # If the temporary buffer is already unaligned, return it. + result = buf + else: + # Otherwise, create a view of the temporary buffer with an offset. + result = buf[offset : offset + x.size].reshape(x.shape) + assert (result.ctypes.data % alignment) != 0 + + # Copy the data to the result buffer and return it. + np.copyto(result, x) + return result + + +def TestFactory(xla_backend, + cloud_tpu=False, + tfrt_tpu=False, + pjrt_c_api=False, + pathways=False, + pathways_ifrt=False): + tests = [] + + int_dtypes = [np.int32, np.int64, np.uint32, np.uint64] + # TODO(phawkins): test np.float16, where supported. + float_dtypes = [bfloat16, np.float32, np.float64] + complex_dtypes = [np.complex64, np.complex128] + standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] + # TODO(zhangqiaorjc): test fp8 types when XLA support is complete. + # standard_dtypes is only used for BufferProtocolTest so we only test fp8 + # round trip tests. + fp8_dtypes = [ + float8_e3m4, + float8_e4m3, + float8_e4m3fn, + float8_e4m3b11fnuz, + float8_e5m2, + float8_e8m0fnu, + ] + standard_dtypes += fp8_dtypes + # TODO(upwind): testRoundTrip and testLiveBuffers fail for float4_e2m1fn type + # standard_dtypes += [float4_e2m1fn] + dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes + + class ComputationTest(parameterized.TestCase): + """Base class for running an XLA Computation through the local client.""" + + def setUp(self): + super(ComputationTest, self).setUp() + self.backend = xla_backend() + + global _CUSTOM_CALLS_REGISTERED + if self.backend.platform == "cpu" and not _CUSTOM_CALLS_REGISTERED: + for name, fn in custom_calls_testlib.registrations().items(): + xla_client.register_custom_call_target( + name, fn, platform="cpu", api_version=1 + ) + for name, val in custom_calls_testlib.type_ids().items(): + xla_client.register_custom_type_id(name, val, platform="cpu") + _CUSTOM_CALLS_REGISTERED = True + + def _NewComputation(self, name=None): + if name is None: + name = self.id() + return xla_client.XlaBuilder(name) + + def _Execute(self, c, arguments): + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + return execute_with_python_values( + compiled_c, arguments, backend=self.backend) + + def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): + assert expected is not None + results = self._Execute(c, arguments) + self.assertLen(results, len(expected)) + for result, e in zip(results, expected): + # Numpy's comparison methods are a bit too lenient by treating inputs as + # "array-like", meaning that scalar 4 will be happily compared equal to + # [[4]]. We'd like to be more strict, so we assert shapes as well.
+ self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape) + assert_func(result, e) + + def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, + expected) + + def _ExecuteAndCompareClose(self, + c, + arguments=(), + expected=None, + rtol=1e-4, + atol=0): + self._ExecuteAndAssertWith( + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), + c, arguments, expected) + + def NumpyArrayF32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" + return np.array(*args, dtype=np.float32, **kwargs) + + def NumpyArrayF64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" + return np.array(*args, dtype=np.float64, **kwargs) + + def NumpyArrayS32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" + return np.array(*args, dtype=np.int32, **kwargs) + + def NumpyArrayBool(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.bool_ dtype.""" + return np.array(*args, dtype=np.bool_, **kwargs) + + class ComputationPrinting(absltest.TestCase): + + def setUp(self): + super(ComputationPrinting, self).setUp() + self.backend = xla_backend() + + def ExampleComputation(self): + builder = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter( + builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + x = ops.Mul(p0, p1) + ops.Add(x, x) + return builder.build() + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testCompiledHloModuleToHloText(self): + computation = self.ExampleComputation() + executable = self.backend.compile( + xla_computation_to_mlir_module(computation)) + hlo_modules = executable.hlo_modules() + self.assertLen(hlo_modules, 1) + hlo_text = hlo_modules[0].to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + self.assertIn("fusion", hlo_text) + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testCompiledHloModuleAsSerializedProto(self): + computation = self.ExampleComputation() + executable = self.backend.compile( + xla_computation_to_mlir_module(computation)) + hlo_modules = executable.hlo_modules() + self.assertLen(hlo_modules, 1) + hlo_text = hlo_modules[0].to_string() + proto = hlo_modules[0].as_serialized_hlo_module_proto() + hlo_module_roundtrip = xla_client.XlaComputation(proto).get_hlo_module() + hlo_text_roundtrip = hlo_module_roundtrip.to_string() + self.assertEqual(hlo_text, hlo_text_roundtrip) + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testStableComputationSerialization(self): + # Ideally we would test identical computations produced in different + # processes. For now we have this limited smoke test. 
+ computation = self.ExampleComputation() + ref = computation.as_serialized_hlo_module_proto() + for _ in range(10): + self.assertEqual(computation.as_serialized_hlo_module_proto(), ref) + + # TODO(b/261771737): some version of this should work with pjrt_c_api=True + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testFlopEstimate(self): + computation = self.ExampleComputation() + properties = xla_client._xla.hlo_module_cost_analysis( + self.backend, computation.as_hlo_module()) + self.assertEqual(properties["flops"], 8.0) + + def testFingerprint(self): + computation = self.ExampleComputation() + executable = self.backend.compile( + xla_computation_to_mlir_module(computation)) + fingerprint = executable.fingerprint + if ( + self.backend.platform == "tpu" + or self.backend.platform == "gpu" + or self.backend.platform == "cpu" + ) and not (cloud_tpu or pathways or pathways_ifrt): + logging.info("fingerprint: %s", fingerprint) + self.assertNotEmpty(fingerprint) + else: + self.assertIsNone(fingerprint) + + tests.append(ComputationPrinting) + + class ComputationsWithConstantsTest(ComputationTest): + """Tests focusing on Constant ops.""" + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testConstantScalarSum(self, dtype): + c = self._NewComputation() + ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14))) + self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorMul(self, dtype): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype))) + self._ExecuteAndCompareClose( + c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorScalarDiv(self, dtype): + c = self._NewComputation() + ops.Div( + ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)), + ops.Constant(c, dtype(2.0))) + self._ExecuteAndCompareClose( + c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorScalarPow(self, dtype): + c = self._NewComputation() + ops.Pow( + ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)), + ops.Constant(c, dtype(2.))) + self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) + + def testIota(self): + c = self._NewComputation() + ops.Iota(c, xla_client.PrimitiveType.F32, 10) + self._ExecuteAndCompareExact( + c, expected=[np.arange(10, dtype=np.float32)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes) + def testBroadcastedIota(self, dtype): + c = self._NewComputation() + shape = xla_client.Shape.array_shape( + xla_client.dtype_to_etype(dtype), (2, 3)) + ops.Iota(c, shape, 1) + expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype) + self._ExecuteAndCompareExact(c, expected=[expected]) + + def testBooleanAnd(self): + c = self._NewComputation() + ops.And( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, 
False]))) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]]) + + def testBooleanOr(self): + c = self._NewComputation() + ops.Or( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]]) + + def testBooleanXor(self): + c = self._NewComputation() + ops.Xor( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2D(self, dtype): + c = self._NewComputation() + ops.Add( + ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)), + ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype))) + self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) + + def testShiftLeft(self): + c = self._NewComputation() + ops.ShiftLeft( + ops.Constant(c, NumpyArrayS32([3])), + ops.Constant(c, NumpyArrayS32([2]))) + self._ExecuteAndCompareClose(c, expected=[[12]]) + + def testShiftRightArithmetic(self): + c = self._NewComputation() + ops.ShiftRightArithmetic( + ops.Constant(c, NumpyArrayS32([-2])), + ops.Constant(c, NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[[-1]]) + + def testShiftRightLogical(self): + c = self._NewComputation() + ops.ShiftRightLogical( + ops.Constant(c, NumpyArrayS32([-1])), + ops.Constant(c, NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2DWith1DBroadcastDim0(self, dtype): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + ops.Add( + ops.Constant(c, + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=dtype)), + ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2DWith1DBroadcastDim1(self, dtype): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. 
+ c = self._NewComputation() + ops.Add( + ops.Constant(c, + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=dtype)), + ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantAxpy(self, dtype): + c = self._NewComputation() + ops.Add( + ops.Mul( + ops.Constant(c, dtype(2)), + ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))), + ops.Constant(c, np.array([100, -100, 200, -200], dtype))) + self._ExecuteAndCompareClose( + c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3) + + def testCustomCall(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + ops.CustomCallWithLayout( + c, + b"subtract_f32", + operands=[ + ops.Constant(c, np.float32(1.25)), + ops.Constant(c, np.float32(0.5)) + ], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), ()), + operand_shapes_with_layout=[ + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + ], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_TYPED_FFI) + self._ExecuteAndCompareClose(c, expected=[0.75]) + + def testCustomCallWithUnifiedApiUnknownTarget(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"not_existing", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING_UNIFIED, + ) + with self.assertRaisesRegex( + xla_client.XlaRuntimeError, expected_regex="NOT_FOUND" + ): + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiUnknownTarget(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"not_existing", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + with self.assertRaises(xla_client.XlaRuntimeError): + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiAlwaysFail(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"always_fail", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + + with self.assertRaisesRegex( + Exception, expected_regex="Failed intentionally" + ): + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiAlwaysSucceed(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"always_succeed", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + + self._Execute(c, arguments=()) + + def 
testCustomCallTypedFfiSubtract(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"subtract_f32_cst", + operands=[ops.Constant(c, np.float32(1.25))], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[ + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + ], + opaque=b"{cst = 3.0 : f32}", + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + self._ExecuteAndCompareClose(c, expected=[-1.75]) + + def testStatefulCustomCall(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + ops.CustomCallWithLayout( + c, + b"stateful", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.int32), (), ()), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_TYPED_FFI) + self._ExecuteAndCompareClose(c, expected=[42]) + + def testCustomCallLookup(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + + self.assertTrue(_CUSTOM_CALLS_REGISTERED) + xla_client.make_cpu_client() + self.assertContainsSubset( + list(custom_calls_testlib.registrations().keys()), + xla_client.custom_call_targets("Host").keys(), + ) + + tests.append(ComputationsWithConstantsTest) + + class ComputationFromProtoTest(absltest.TestCase): + """Test computation execution from HLO proto.""" + + def setUp(self): + super(ComputationFromProtoTest, self).setUp() + self.backend = xla_backend() + + def testExecuteFromProto(self): + # Build the HLO proto + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + serialized_proto = b.build().as_serialized_hlo_module_proto() + + # Load and execute the proto + c = xla_client.XlaComputation(serialized_proto) + m = xla_computation_to_mlir_module(c) + ans, = execute_with_python_values( + self.backend.compile(m), (), backend=self.backend) + np.testing.assert_equal(ans, np.int32(3)) + + tests.append(ComputationFromProtoTest) + + class ParametersTest(ComputationTest): + """Tests focusing on Parameter ops and argument-passing.""" + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes) + def testScalarTimesVector(self, dtype): + c = self._NewComputation() + arg0 = np.array(3, dtype=dtype) + if np.issubdtype(dtype, np.unsignedinteger): + arg1 = np.array([10, 15, 2, 7], dtype=dtype) + else: + arg1 = np.array([10, 15, -2, 7], dtype=dtype) + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + ops.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, arguments=[arg0, arg1], expected=[arg0 * arg1]) + + # TODO(phawkins): test comparison harness doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testScalarMinusVectorExplicitNumbering(self, dtype): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. 
+ c = self._NewComputation() + arg0 = np.array(2.0, dtype=dtype) + arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + ops.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, arguments=[arg0, arg1], expected=[arg1 - arg0]) + + tests.append(ParametersTest) + + class LayoutsTest(ComputationTest): + """Tests related to getting and setting on-device memory layouts.""" + + def _minor_to_major(self, layout: xla_client.PjRtLayout): # pylint: disable=invalid-name + m2m_str = re.search("{([0-9,]*)", str(layout)).group(1) + if not m2m_str: + return () + return tuple(int(x) for x in m2m_str.split(",")) + + @unittest.skipIf(pathways, "not implemented") + def testGetArgumentLayouts(self): + # Create computation with a few parameters. + c = self._NewComputation() + param_count = 0 + + def MakeArg(shape, dtype): + nonlocal param_count + shape = xla_client.Shape.array_shape(np.dtype(dtype), shape) + param = ops.Parameter(c, param_count, shape) + param_count += 1 + return param + + p0 = MakeArg((2, 3, 4), np.float32) + MakeArg((3, 2), np.int32) + MakeArg((), np.float64) + + ops.Add(p0, ops.Constant(c, np.ones((2, 3, 4), np.float32))) + executable = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + + # Test that compiled executable returns plausible layouts. + layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts() + self.assertLen(layouts, 3) + self.assertLen(self._minor_to_major(layouts[0]), 3) + self.assertLen(self._minor_to_major(layouts[1]), 2) + self.assertEmpty(self._minor_to_major(layouts[2])) + + @unittest.skipIf(pathways, "not implemented") + def testGetArgumentLayoutsTupled(self): + # Generated with: + # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)), + # np.int32(42), + # np.ones(10)) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor {mhlo.sharding = "{replicated}"}, + %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"}, + tensor {jax.result_info = "[1]"}, + tensor<10xf32> {jax.result_info = "[2]"}) { + return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor, tensor<10xf32> + } +} +""" + options = xla_client.CompileOptions() + # 'parameter_is_tupled_arguments' causes MLIR untupled arguments to get + # turned into HLO tupled arguments. + options.parameter_is_tupled_arguments = True + executable = self.backend.compile(module_str, compile_options=options) + + # Test that compiled executable returns plausible layouts. 
+ layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts() + self.assertLen(layouts, 3) + self.assertLen(self._minor_to_major(layouts[0]), 3) + self.assertEmpty(self._minor_to_major(layouts[1])) + self.assertLen(self._minor_to_major(layouts[2]), 1) + + @unittest.skipIf(pathways, "not implemented") + def testGetOutputLayouts(self): + # Generated with jax.jit(lambda: (np.ones((1024, 128)), np.int32(42), + # np.ones(10)))() + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<1024x128xf32> {jax.result_info = "[0]"}, + tensor {jax.result_info = "[1]"}, + tensor<10xf32> {jax.result_info = "[2]"}) { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x128xf32> + %1 = stablehlo.constant dense<1.000000e+00> : tensor<10xf32> + %2 = stablehlo.constant dense<42> : tensor + return %0, %2, %1 : tensor<1024x128xf32>, tensor, tensor<10xf32> + } +} +""" + executable = self.backend.compile(module_str) + + # Test that compiled executable returns plausible layouts. + layouts: Sequence[xla_client.Layout] = executable.get_output_layouts() + self.assertLen(layouts, 3) + self.assertLen(self._minor_to_major(layouts[0]), 2) + self.assertEmpty(self._minor_to_major(layouts[1])) + self.assertLen(self._minor_to_major(layouts[2]), 1) + + @unittest.skipIf(pathways, "not implemented") + def testSetArgumentLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)), + # np.int32(42), + # np.ones(10)) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "{0,1,2}"}, + %arg1: tensor {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "{}"}, + %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "{0}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"}, + tensor {jax.result_info = "[1]"}, + tensor<10xf32> {jax.result_info = "[2]"}) { + return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor, tensor<10xf32> + } +} + """ + executable = self.backend.compile(module_str) + + # Check input layouts. + input_layouts = executable.get_parameter_layouts() + self.assertLen(input_layouts, 3) + self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1, 2)) + self.assertEqual(self._minor_to_major(input_layouts[1]), ()) + self.assertEqual(self._minor_to_major(input_layouts[2]), (0,)) + + # Compile a version with default arg0 layout so we can make sure we + # actually set it above. + default_executable = self.backend.compile( + module_str.replace('"{0,1,2}"', '"default"') + ) + self.assertNotEqual( + self._minor_to_major(input_layouts[0]), + self._minor_to_major(default_executable.get_parameter_layouts()[0]), + ) + + @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + def testSetArgumentLayoutsLegacy(self): + """Tests setting the arg layouts with compile_options (deprecated). + + New code should use the mhlo.layout_mode string attr on parameters. + """ + # Create computation with custom input layouts. 
+ c = self._NewComputation() + param_count = 0 + + def MakeArg(shape, dtype, layout): + nonlocal param_count + arr = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + param = ops.Parameter(c, param_count, + xla_client.shape_from_pyval(arr, layout)) + param_count += 1 + shape = xla_client.Shape.array_shape(np.dtype(dtype), shape, layout) + return arr, param, shape + + arg0, p0, shape0 = MakeArg((2, 3, 4), np.float32, (1, 2, 0)) + arg1, p1, shape1 = MakeArg((3, 2), np.int32, (0, 1)) + arg2, p2, shape2 = MakeArg((), np.float64, ()) + + ops.Tuple(c, [ + ops.Add(p0, ops.Constant(c, np.ones(arg0.shape, arg0.dtype))), + ops.Add(p1, ops.Constant(c, np.ones(arg1.shape, arg1.dtype))), + ops.Add(p2, ops.Constant(c, np.ones(arg2.shape, arg2.dtype))), + ]) + + # We also need to set the input layouts in the compile options. + options = xla_client.CompileOptions() + options.argument_layouts = [shape0, shape1, shape2] + executable = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + + # Test that compiled executable has expected layouts. + expected_layouts: Sequence[xla_client.Shape] = [shape0, shape1, shape2] + actual_layouts: Sequence[xla_client.Layout] = ( + executable.get_parameter_layouts()) + self.assertEqual(len(actual_layouts), len(expected_layouts)) + for actual, expected in zip(actual_layouts, expected_layouts): + self.assertEqual( + self._minor_to_major(actual), + expected.layout().minor_to_major(), + ) + + @unittest.skipIf(pathways, "not implemented") + def testSetOutputLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)), + # np.int32(42), + # np.ones(10)) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor {mhlo.sharding = "{replicated}"}, + %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "[0]", + mhlo.layout_mode = "{0,1,2}"}, + tensor {jax.result_info = "[1]", + mhlo.layout_mode = "{}"}, + tensor<10xf32> {jax.result_info = "[2]", + mhlo.layout_mode = "{0}"}) { + return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor, tensor<10xf32> + } +} + """ + executable = self.backend.compile(module_str) + + # Check output layouts. + output_layouts = executable.get_output_layouts() + self.assertLen(output_layouts, 3) + self.assertEqual(self._minor_to_major(output_layouts[0]), (0, 1, 2)) + self.assertEqual(self._minor_to_major(output_layouts[1]), ()) + self.assertEqual(self._minor_to_major(output_layouts[2]), (0,)) + + # Compile a version with default first output layout so we can make sure + # we actually set it above. 
+ default_executable = self.backend.compile( + module_str.replace('"{0,1,2}"', '"default"') + ) + self.assertNotEqual( + self._minor_to_major(output_layouts[0]), + self._minor_to_major(default_executable.get_output_layouts()[0]), + ) + + @unittest.skipIf(pathways, "not implemented") + def SetLayoutsSharded(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) + # x = jax.device_put(np.ones((1024, 128)), sharding.reshape(4, 2)) + # jax.jit(lambda x, y: x + y, out_shardings=sharding)(x, 1.) + # + # This also lightly tests mixed default + user-specified input layouts. + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 8 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x128xf32> {mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}", + mhlo.layout_mode = "{0,1}"}, + %arg1: tensor {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x128xf32> {jax.result_info = "", + mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}", + mhlo.layout_mode = "{0,1}"}) { + %0 = stablehlo.convert %arg1 : tensor + %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor) -> tensor<1024x128xf32> + %2 = stablehlo.add %arg0, %1 : tensor<1024x128xf32> + return %2 : tensor<1024x128xf32> + } +} + """ + executable = self.backend.compile(module_str) + + # Check input layouts. + input_layouts = executable.get_parameter_layouts() + self.assertLen(input_layouts, 2) + self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1)) + self.assertEqual(self._minor_to_major(input_layouts[1]), ()) + + # Check output layout. + output_layouts = executable.get_output_layouts() + self.assertLen(output_layouts, 1) + self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1)) + + # Compile a version with default layouts so we can make sure we actually + # set it above. + default_executable = self.backend.compile( + module_str.replace('"{0,1}"', '"default"') + ) + self.assertNotEqual( + self._minor_to_major(input_layouts[0]), + self._minor_to_major(default_executable.get_parameter_layouts()[0]), + ) + self.assertNotEqual( + self._minor_to_major(output_layouts[0]), + self._minor_to_major(default_executable.get_output_layouts()[0]), + ) + + @unittest.skipIf(pathways, "not implemented") + def testAutoArgumentLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # jax.numpy.einsum("...a,ahd->...hd", ...) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "auto"}, + %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "auto"}) + -> (tensor<1024x8x128xf32> {jax.result_info = ""}) { + %0 = stablehlo.dot_general %arg0, %arg1, + contracting_dims = [1] x [0], + precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>, + tensor<1024x8x128xf32>) + -> tensor<1024x8x128xf32> + return %0 : tensor<1024x8x128xf32> + } +} +""" + executable = self.backend.compile(module_str) + + # Check input layouts. 
+ input_layouts = executable.get_parameter_layouts() + self.assertEqual(self._minor_to_major(input_layouts[0]), (1, 0)) + self.assertEqual(self._minor_to_major(input_layouts[1]), (2, 0, 1)) + + # Compile a version with default layouts so we can make sure the compiler + # is actually choosing above. + default_executable = self.backend.compile( + module_str.replace('"auto"', '"default"') + ) + # We expect the compiler to choose a non-default layout for the second + # (1024,8,128) argument. + self.assertNotEqual( + self._minor_to_major(input_layouts[1]), + self._minor_to_major(default_executable.get_parameter_layouts()[1]), + ) + + @unittest.skipIf(pathways, "not implemented") + def testAutoOutputLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Generated with jax.numpy.einsum("...a,ahd->...hd", ...) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "", + mhlo.layout_mode = "auto"}) { + %0 = stablehlo.dot_general %arg0, %arg1, + contracting_dims = [1] x [0], + precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>, + tensor<1024x8x128xf32>) + -> tensor<1024x8x128xf32> + return %0 : tensor<1024x8x128xf32> + } +} +""" + executable = self.backend.compile(module_str) + + # Check output layout + output_layout, = executable.get_output_layouts() + self.assertEqual(self._minor_to_major(output_layout), (2, 0, 1)) + + # Compile a version with default layouts so we can make sure the compiler + # is actually choosing above. + default_executable = self.backend.compile( + module_str.replace('"auto"', '"default"') + ) + # We expect the compiler to choose a non-default output layout. 
+ self.assertNotEqual( + self._minor_to_major(output_layout), + self._minor_to_major(default_executable.get_output_layouts()[0]), + ) + + tests.append(LayoutsTest) + + class BufferTest(ComputationTest): + """Tests focusing on execution with Buffers.""" + + def testConstantSum(self): + c = self._NewComputation() + ops.Add( + ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14))) + self._ExecuteAndCompareClose(c, expected=[4.25]) + + def testOneParameterSum(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Constant(c, np.float32(3.14))) + self._ExecuteAndCompareClose( + c, arguments=[NumpyArrayF32(1.11)], expected=[4.25]) + + def testTwoParameterSum(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.)))) + self._ExecuteAndCompareClose( + c, + arguments=[NumpyArrayF32(1.11), + NumpyArrayF32(3.14)], + expected=[4.25]) + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testCannotCallWithDeletedBuffers(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Constant(c, np.float32(3.14))) + arg = NumpyArrayF32(1.11) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + arg_buffer = self.backend.buffer_from_pyval(arg) + arg_buffer.delete() + with self.assertRaises(xla_client.XlaRuntimeError): + compiled_c.execute([arg_buffer]) + + def testXlaShapeIndex(self): + a = xla_client.ShapeIndex((1, 2)) + b = xla_client.ShapeIndex((1, 2)) + c = xla_client.ShapeIndex((2, 3)) + self.assertEqual(a, b) + self.assertNotEqual(b, c) + + def testLayout(self): + f32 = xla_client.PrimitiveType.F32 + a = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout() + b = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout() + c = xla_client.Shape.array_shape(f32, (2, 3), (1, 0)).layout() + self.assertEqual(a.minor_to_major(), (0, 1)) + self.assertEqual(b.minor_to_major(), (0, 1)) + self.assertEqual(c.minor_to_major(), (1, 0)) + self.assertEqual(a, b) + self.assertNotEqual(a, c) + self.assertNotEqual(b, c) + self.assertEqual(hash(a), hash(b)) + self.assertNotEqual(hash(a), hash(c)) + self.assertNotEqual(hash(b), hash(c)) + + def testBlockUntilReadyWorks(self): + arg = np.array([[1., 2.]], np.float32) + arg_buffer = self.backend.buffer_from_pyval(arg) + arg_buffer.block_until_ready() + # This test merely checks that nothing goes awry when we call + # block_until_ready(); it's difficult to test anything else. 
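The BufferTest cases above all exercise the same host/device round trip: buffer_from_pyval copies a NumPy array onto a device, block_until_ready waits for that transfer, np.asarray copies the data back to the host, and delete invalidates the buffer. A minimal sketch of that flow, assuming backend is an xla_client.Client obtained the same way the test harness provides self.backend:

    import numpy as np

    x = np.array([[1., 2.]], np.float32)
    buf = backend.buffer_from_pyval(x)       # host -> device copy
    buf.block_until_ready()                  # wait for the transfer to finish
    np.testing.assert_array_equal(np.asarray(buf), x)   # device -> host copy
    buf.delete()                             # later use raises xla_client.XlaRuntimeError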
+ + def testBlockUntilReadyRaisesOnDeletedBuffer(self): + arg = np.array([[1., 2.]], np.float32) + buffer = self.backend.buffer_from_pyval(arg) + buffer.delete() + with self.assertRaisesRegex( + RuntimeError, + re.escape( + "BlockHostUntilReady() called on deleted or donated buffer")): + buffer.block_until_ready() + + @unittest.skipIf(pathways_ifrt, "not implemented") + def testOnDeviceSizeInBytes(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.") + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0) + # OnDeviceSizeInBytes varies depending on the platform. Confirm there's + # a reasonable value. + self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0) + self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0) + + def testLiveBuffers(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support LiveBuffers().") + self.assertEmpty(self.backend.live_buffers()) + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertLen(self.backend.live_buffers(), 3) + self.assertIs(self.backend.live_buffers()[0], arg2_buffer) + self.assertIs(self.backend.live_buffers()[1], arg1_buffer) + self.assertIs(self.backend.live_buffers()[2], arg0_buffer) + + arg1_buffer.delete() + self.assertLen(self.backend.live_buffers(), 2) + self.assertIs(self.backend.live_buffers()[0], arg2_buffer) + self.assertIs(self.backend.live_buffers()[1], arg0_buffer) + + arg0_buffer.delete() + arg2_buffer.delete() + self.assertEmpty(self.backend.live_buffers()) + + def testCopyToHost(self): + arg0 = np.array([[1., 2.]], np.float32) + arg1 = np.array([[3., 4.]], np.float32) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + # Prefetch two buffers using copy_to_host_async, and then retrieve their + # values using np.asarray(). + arg0_buffer.copy_to_host_async() + arg0_buffer.copy_to_host_async() # Duplicate calls don't do anything. + arg1_buffer.copy_to_host_async() + np.testing.assert_equal(arg0, np.asarray(arg0_buffer)) + np.testing.assert_equal(arg1, np.asarray(arg1_buffer)) + # copy_to_host_async does nothing after np.asarray() is called. + arg0_buffer.copy_to_host_async() + np.testing.assert_equal(arg0, np.asarray(arg0_buffer)) + + def testDevice(self): + x = np.arange(8, dtype=np.int32) + for device in self.backend.local_devices(): + buf = self.backend.buffer_from_pyval(x, device=device) + self.assertEqual(buf.device(), device) + np.testing.assert_equal(x, np.asarray(buf)) + + def testStandardTypes(self): + for dtype in standard_dtypes: + if dtype == np.complex128: + continue + # float8_e8m0fnu is not supported on TPU. + if dtype == float8_e8m0fnu and self.backend.platform == "tpu": + continue + # float8_e4m3b11fnuz not supported on some TPU backends. 
+ if ( + dtype + in [ + float8_e3m4, + float8_e4m3, + float8_e4m3fnuz, + float8_e4m3b11fnuz, + float8_e5m2fnuz, + ] + and self.backend.platform == "tpu" + ): + if self.backend.platform_version.find("TPU") == -1: + continue + arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype)) + arr = np.asarray(arr) + self.assertEqual(dtype, type(arr[0])) + + @unittest.skipIf(pathways_ifrt, "not implemented") + def testUnsafeBufferPointer(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support UnsafeBufferPointer().") + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertGreaterEqual(arg0_buffer.unsafe_buffer_pointer(), 0) + self.assertGreaterEqual(arg1_buffer.unsafe_buffer_pointer(), 0) + self.assertGreaterEqual(arg2_buffer.unsafe_buffer_pointer(), 0) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt, "not implemented") + def testClone(self): + x = np.array([[3., 4., 5.]], np.float32) + y = self.backend.buffer_from_pyval(x) + z = y.clone() + self.assertNotEqual(id(x), id(y)) + np.testing.assert_array_equal(np.asarray(y), np.asarray(z)) + self.assertEqual(y.unsafe_buffer_pointer(), z.unsafe_buffer_pointer()) + + tests.append(BufferTest) + + class SingleOpTest(ComputationTest): + """Tests for single ops. + + The goal here is smoke testing - to exercise the most basic functionality of + single XLA ops. As minimal as possible number of additional ops are added + around the op being tested. + """ + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConcatenate(self, dtype): + c = self._NewComputation() + args = ( + ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)), + ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)), + ) + ops.ConcatInDim(c, args, dimension=0) + self._ExecuteAndCompareExact( + c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)]) + + # pyformat: disable + @parameterized.named_parameters({ + "testcase_name": "_{}_{}".format(src_dtype.__name__, + dst_dtype.__name__), + "src_dtype": src_dtype, + "dst_dtype": dst_dtype, + } for src_dtype, dst_dtype in itertools.permutations( + [np.bool_, np.int32, np.int64, np.float32, np.float64], 2)) + # pyformat: enable + def testConvertElementType(self, src_dtype, dst_dtype): + if ((src_dtype in [np.int64, np.float64] or + dst_dtype in [np.int64, np.float64]) and + self.backend.platform == "tpu"): + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) + ops.ConvertElementType( + ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) + + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 1) + expected = np.array(x, dtype=dst_dtype) + + self.assertEqual(result[0].shape, expected.shape) + self.assertEqual(result[0].dtype, expected.dtype) + np.testing.assert_equal(result[0], expected) + + # pyformat: disable + @parameterized.named_parameters( + { + "testcase_name": "_{}_{}".format(src_dtype.__name__, + dst_dtype.__name__), + "src_dtype": src_dtype, + "dst_dtype": dst_dtype, + } + for dtypes in [[np.int32, np.float32], [np.int64, np.float64]] + for src_dtype, dst_dtype in 
itertools.permutations(dtypes, 2)) + # pyformat: enable + def testBitcastConvertType(self, src_dtype, dst_dtype): + if (np.float64 in (src_dtype, dst_dtype) and + self.backend.platform == "tpu"): + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) + ops.BitcastConvertType( + ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) + + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 1) + expected = x.view(dst_dtype) + + self.assertEqual(result[0].shape, expected.shape) + self.assertEqual(result[0].dtype, expected.dtype) + np.testing.assert_equal(result[0], expected) + + # TODO(b/123523486) implement AllToAll on CPU + def DISABLED_testAllToAllOneReplica(self): + samples = [ + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples[:1]: + c = self._NewComputation() + ops.AllToAll(ops.Constant(c, lhs), 0, 0) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + def testCrossReplicaSumOneReplica(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + ops.CrossReplicaSum(ops.Constant(c, lhs)) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + def testReplicaId(self): + c = self._NewComputation() + _ = ops.ReplicaId(c) + self._ExecuteAndCompareExact(c, expected=[0]) + + def testCrossReplicaSumOneReplicaWithSingletonGroup(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + ops.CrossReplicaSum( + ops.Constant(c, lhs), xla_client.make_replica_groups([[0]])) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + # TODO(phawkins): np.dot implementation doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDotMatrixVector(self, dtype): + c = self._NewComputation() + lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) + rhs = np.array([[10.0], [20.0]], dtype=dtype) + ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) + + # TODO(phawkins): np.dot implementation doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDotMatrixMatrix(self, dtype): + c = self._NewComputation() + lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) + rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype) + ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) + + def testDotGeneral(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = xla_client.make_dot_dimension_numbers( + (([2], [1]), ([0], [0]))) + ops.DotGeneral( + ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6) + + def testDotGeneralWithDotDimensionNumbersProto(self): + c = self._NewComputation() + rng = 
np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + + dimension_numbers = xla_client.DotDimensionNumbers() + dimension_numbers.lhs_contracting_dimensions.append(2) + dimension_numbers.rhs_contracting_dimensions.append(1) + dimension_numbers.lhs_batch_dimensions.append(0) + dimension_numbers.rhs_batch_dimensions.append(0) + + ops.DotGeneral( + ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6) + + def testDotGeneralWithPrecisionConfig(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = xla_client.make_dot_dimension_numbers( + (([2], [1]), ([0], [0]))) + config = xla_client.PrecisionConfig() + config.operand_precision.append(config.Precision.HIGH) + config.operand_precision.append(config.Precision.HIGHEST) + ops.DotGeneral( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + dimension_numbers, + precision_config=config) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6) + + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, + lhs_dilation, rhs_dilation, dimension_numbers) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedF32WithPrecisionConfig(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + config = xla_client.PrecisionConfig() + config.operand_precision.append(config.Precision.HIGHEST) + config.operand_precision.append(config.Precision.DEFAULT) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + strides, + pads, + lhs_dilation, + rhs_dilation, + dimension_numbers, + precision_config=config) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NHWC", "OIHW", "CWNH"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, np.transpose(lhs, + (0, 2, 3, 1))), ops.Constant(c, rhs), + strides, pads, lhs_dilation, rhs_dilation, dimension_numbers) + result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose( + c, expected=[np.transpose(result, (1, 3, 0, 2))]) + + def testConvGeneralDilatedGroupedConvolutionF32(self): + c = 
self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 2, 3) + rhs = a(2, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + feature_group_count = 2 + ops.ConvGeneralDilated( + ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, + lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ], [ + [0., 0., 0.], + [330., 380., 160.], + [0., 0., 0.], + [480., 530., 220.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedWindowReversalF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + window_reversal = [False, True] + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + strides, + pads, + lhs_dilation, + rhs_dilation, + dimension_numbers, + window_reversal=window_reversal) + result = np.array([[[ + [0., 0., 0.], + [0., 10., 20.], + [0., 0., 0.], + [30., 40., 50.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testBooleanNot(self): + c = self._NewComputation() + arr = NumpyArrayBool([True, False, True]) + ops.Not(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[~arr]) + + def testPopulationCount(self): + c = self._NewComputation() + arr = NumpyArrayS32([3, 0, 1]) + ops.PopulationCount(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])]) + + def testCountLeadingZeros(self): + c = self._NewComputation() + arr = NumpyArrayS32([0x7FFF, 0x12345678]) + ops.Clz(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[[17, 3]]) + + def testExp(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Exp(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) + + def testExpWithResultAccuracy(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + accuracy = xla_client.ResultAccuracy() + accuracy.mode = xla_client.ResultAccuracyMode.DEFAULT + ops.Exp(ops.Constant(c, arr), accuracy) + self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) + + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Expm1(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) + + def testExpm1WithResultAccuracy(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + accuracy = xla_client.ResultAccuracy() + accuracy.mode = xla_client.ResultAccuracyMode.DEFAULT + ops.Expm1(ops.Constant(c, arr), accuracy) + self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) + + def testRound(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Round(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.round(arr)]) + + def testLog(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Log(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.log(arr)]) + + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + 
ops.Log1p(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)]) + + def testNeg(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Neg(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[-arr]) + + def testFloor(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Floor(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.floor(arr)]) + + def testCeil(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Ceil(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)]) + + def testAbs(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) + ops.Abs(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) + + def testTanF32(self): + c = self._NewComputation() + arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tan(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tan(arr)]) + + def testTanhF32(self): + c = self._NewComputation() + arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tanh(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) + + def testTanhF64(self): + if self.backend.platform == "tpu": + self.skipTest("TPU doesn't support 64bit tanh") + c = self._NewComputation() + arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tanh(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12) + + def testTranspose(self): + + def _TransposeAndTest(array, permutation): + c = self._NewComputation() + ops.Transpose(ops.Constant(c, array), permutation) + expected = np.transpose(array, permutation) + self._ExecuteAndCompareClose(c, expected=[expected]) + + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) + + arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) + for permutation in itertools.permutations(range(arr.ndim)): + _TransposeAndTest(arr, permutation) + _TransposeAndTest(np.asfortranarray(arr), permutation) + + def testEq(self): + c = self._NewComputation() + ops.Eq( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), + ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) + + def testNe(self): + c = self._NewComputation() + ops.Ne( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), + ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]]) + + ops.Ne( + ops.Constant(c, NumpyArrayF32([-2.0, 0.0, + float("nan"), + float("nan")])), + ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0, + float("nan")]))) + self._ExecuteAndAssertWith( + np.testing.assert_allclose, + c, (), + expected=[[True, False, True, True]]) + + def testGt(self): + c = self._NewComputation() + ops.Gt( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[False, True, True, False, False]]) + + def testGe(self): + c = self._NewComputation() + ops.Ge( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[True, True, True, False, False]]) + + def testLt(self): + c = self._NewComputation() + ops.Lt( + 
ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[False, False, False, True, True]]) + + def testLe(self): + c = self._NewComputation() + ops.Le( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[True, False, False, True, True]]) + + def testMax(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]]) + + def testMaxExplicitBroadcastDim0(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareExact( + c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]]) + + def testMaxExplicitBroadcastDim1(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareExact( + c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]]) + + def testMin(self): + c = self._NewComputation() + ops.Min( + ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]]) + + def testPad(self): + c = self._NewComputation() + ops.Pad( + ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + ops.Constant(c, NumpyArrayF32(0.0)), + xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)])) + self._ExecuteAndCompareClose( + c, + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) + + def testPadWithPaddingConfig(self): + c = self._NewComputation() + padding_config = xla_client.PaddingConfig() + for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: + dimension = xla_client.PaddingConfigDimension() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + padding_config.dimensions.append(dimension) + ops.Pad( + ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + ops.Constant(c, NumpyArrayF32(0.0)), padding_config) + self._ExecuteAndCompareClose( + c, + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) + + def testReshape(self): + c = self._NewComputation() + ops.Reshape( + ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), + new_sizes=[2, 3]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]]) + + def testCollapse(self): + c = self._NewComputation() + ops.Collapse( + ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[1, 2]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]]) + + def testRev(self): + c = self._NewComputation() + ops.Rev( + ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[0, 2]) + self._ExecuteAndCompareExact( + c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]]) + + def testReducePrecision(self): + c = self._NewComputation() + ops.ReducePrecision( + ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])), + exponent_bits=8, + mantissa_bits=7) + self._ExecuteAndCompareClose(c, 
expected=[[float.fromhex("0x1.32p-3")]]) + + def testClampF32(self): + c = self._NewComputation() + ops.Clamp( + ops.Constant(c, NumpyArrayF32(-1)), + ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])), + ops.Constant(c, NumpyArrayF32(2))) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) + + def testClampS32(self): + c = self._NewComputation() + ops.Clamp( + ops.Constant(c, NumpyArrayS32(-1)), + ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])), + ops.Constant(c, NumpyArrayS32(2))) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) + + def testSelect(self): + c = self._NewComputation() + ops.Select( + ops.Constant(c, NumpyArrayBool([True, False, False, True, False])), + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])), + ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5]))) + self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]]) + + def testSlice(self): + c = self._NewComputation() + ops.Slice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + [1, 0], [3, 2], [1, 1]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) + + def testSliceInDim(self): + c = self._NewComputation() + ops.SliceInDim( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=1, + limit_index=2, + stride=1, + dimno=1) + self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]]) + ops.SliceInDim( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=0, + limit_index=3, + stride=2, + dimno=0) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]]) + + def testDynamicSlice(self): + c = self._NewComputation() + ops.DynamicSlice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [ + ops.Constant(c, NumpyArrayS32(1)), + ops.Constant(c, NumpyArrayS32(0)) + ], [2, 2]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) + + def testDynamicUpdateSlice(self): + c = self._NewComputation() + ops.DynamicUpdateSlice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])), [ + ops.Constant(c, NumpyArrayS32(1)), + ops.Constant(c, NumpyArrayS32(1)) + ]) + self._ExecuteAndCompareExact( + c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]]) + + def testTuple(self): + c = self._NewComputation() + ops.Tuple(c, [ + ops.Constant(c, np.int32(42)), + ops.Constant(c, NumpyArrayF32([1.0, 2.0])), + ops.Constant(c, NumpyArrayBool([True, False, False, True])) + ]) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 3) + np.testing.assert_equal(result[0], 42) + np.testing.assert_allclose(result[1], [1.0, 2.0]) + np.testing.assert_equal(result[2], [True, False, False, True]) + + def testGetTupleElement(self): + c = self._NewComputation() + ops.GetTupleElement( + ops.Tuple(c, [ + ops.Constant(c, np.int32(42)), + ops.Constant(c, NumpyArrayF32([1.0, 2.0])), + ops.Constant(c, NumpyArrayBool([True, False, False, True])) + ]), 1) + self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]]) + + def testBroadcast(self): + c = self._NewComputation() + ops.Broadcast( + ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) + self._ExecuteAndCompareExact( + c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]]) + + def testBroadcastInDim(self): + c = self._NewComputation() + ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0]) + self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 
2]]]) + ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]]) + + def testRngNormal(self): + shape = (2, 3) + c = self._NewComputation() + ops.RngNormal( + ops.Constant(c, NumpyArrayF32(0.)), + ops.Constant(c, NumpyArrayF32(1.)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, + shape)) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + # since the result is random, we just check shape and uniqueness + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) + + def testRngUniformF32(self): + lo, hi = 2., 4. + shape = (2, 3) + c = self._NewComputation() + ops.RngUniform( + ops.Constant(c, NumpyArrayF32(lo)), + ops.Constant(c, NumpyArrayF32(hi)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, + shape)) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + # since the result is random, we just check shape, uniqueness, and range + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) + + def testRngUniformS32(self): + lo, hi = 2, 4 + shape = (2, 3) + c = self._NewComputation() + ops.RngUniform( + ops.Constant(c, NumpyArrayS32(lo)), + ops.Constant(c, NumpyArrayS32(hi)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, + shape)) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + # since the result is random, we just check shape, integrality, and range + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertEqual(result[0].dtype, np.int32) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) + + def testCholesky(self): + l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], + dtype=np.float32) + c = self._NewComputation() + ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T)))) + self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4) + + def testSort(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + c = self._NewComputation() + ops.Sort(c, [ops.Constant(c, keys)], is_stable=True) + self._ExecuteAndCompareClose( + c, + expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)]) + + def testSortKeyVal(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) + np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) + + def testSortCustomComparator(self): + b = self._NewComputation("comparator") + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))) + q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))) + p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) + q1 = ops.Parameter(b, 3, 
xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1))) + comparator = b.build() + + keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + ops.Sort( + c, (ops.Constant(c, keys), ops.Constant(c, values)), + dimension=1, + comparator=comparator) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) + np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) + + def testQR(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True)) + q, r = self._Execute(c, ()) + np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + + def testEigh(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + a = (a + a.T) / 2 + + c = self._NewComputation() + ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True)) + # TODO(b/129396575): Turn this test back on when it passes without + # fastmath. + # v, w = self._Execute(c, ()) + # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) + + def testSVD(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + ops.Tuple(c, ops.SVD(ops.Constant(c, a))) + u, d, v = self._Execute(c, ()) + self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) + + def testTriangularSolve(self): + a_vals = np.array( + [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], + dtype=np.float32) + b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=np.float32) + + c = self._NewComputation() + ops.TriangularSolve( + ops.Constant(c, a_vals), + ops.Constant(c, b_vals), + left_side=False, + lower=True, + transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE, + unit_diagonal=False) + self._ExecuteAndCompareClose( + c, + expected=[ + np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], + dtype=np.float32) + ], + rtol=1e-4) + + def testApproxTopK(self): + if self.backend.platform != "tpu": + self.skipTest("ApproxTopK is only supported on TPU") + k = 10 + qy_size = 256 + db_size = 3000 + feature = 128 + recall_target = 0.95 + b = self._NewComputation() + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))) + q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))) + ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Gt(p0, q0) + comparator = b.build() + qy_shape = [qy_size, feature] + db_shape = [feature, db_size] + rng = np.random.RandomState(0) + qy_arg = rng.randn(*qy_shape).astype(np.float32) + db_arg = rng.randn(*db_shape).astype(np.float32) + b = self._NewComputation() + qy = ops.Parameter(b, 0, xla_client.shape_from_pyval(qy_arg)) + db = ops.Parameter(b, 1, xla_client.shape_from_pyval(db_arg)) + scores = ops.Dot(qy, db) + iota = ops.Iota( + b, + xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, + (qy_size, db_size)), 1) + init_val = ops.Constant(b, np.float32(-1)) + init_arg = 
ops.Constant(b, np.int32(-1)) + ground_truth = ops.TopK(scores, k=k) + approx_topk = ops.ApproxTopK( + b, [scores, iota], [init_val, init_arg], + top_k=k, + reduction_dim=1, + comparator=comparator, + recall_target=recall_target) + ops.Tuple(b, [ + ops.GetTupleElement(ground_truth, 1), + ops.GetTupleElement(approx_topk, 1) + ]) + results = self._Execute(b, [qy_arg, db_arg]) + ground_truth_docids = [set(x) for x in results[0]] + hits = sum( + len([x for x in approx_topk_per_q if x in ground_truth_docids[q]]) + for q, approx_topk_per_q in enumerate(results[1]) + ) + self.assertGreater(hits / (qy_size * k), recall_target) + + def testIsConstant(self): + c = self._NewComputation() + a = ops.Constant(c, np.int32(3)) + b = ops.Constant(c, np.int32(1)) + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0))) + const_expr = ops.Sub(b, a) + non_const_expr = ops.Mul(const_expr, x) + self.assertTrue(c.is_constant(const_expr)) + self.assertFalse(c.is_constant(non_const_expr)) + + def testGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) + dnums = xla_client.GatherDimensionNumbers() + dnums.offset_dims.append(1) + dnums.offset_dims.append(2) + dnums.start_index_map.append(0) + dnums.start_index_map.append(1) + dnums.index_vector_dim = 2 + c = self._NewComputation() + ops.Gather( + ops.Constant(c, a), + ops.Constant(c, indices), + dnums, + slice_sizes=[1, 1]) + g, = self._Execute(c, ()) + expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) + np.testing.assert_allclose(g, expected, rtol=1e-4) + + def testAllGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + c = self._NewComputation() + ops.AllGather( + operand=ops.Constant(c, a), + all_gather_dimension=0, + shard_count=1, + replica_groups=xla_client.make_replica_groups([[0]]), + use_global_device_ids=False) + [g] = self._Execute(c, ()) + np.testing.assert_equal(g, a) + + def testFft(self): + if self.backend.platform == "tpu": + self.skipTest("TPU only supports 1D FFT") + shape = [2, 3, 4, 5] + rng = np.random.RandomState(0) + a = rng.randn(*shape) + 1.0j * rng.randn(*shape) + a = a.astype(np.complex64) + # FFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4) + # IFFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4) + # RFFT + b = rng.randn(*shape).astype(np.float32) + c = self._NewComputation() + ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4) + # IRFFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=2e-4 + ) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes + fp8_dtypes) + def testNextAfter(self, dtype): + if dtype == float8_e8m0fnu: + # TODO(b/409114865): Test fails with Mismatched elements error. 
+ self.skipTest("b/409114865: Test fails with Mismatched elements error") + if dtype in [float8_e3m4, float8_e4m3] and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float8_e3m4 or float8_e4m3") + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + if dtype == bfloat16 and self.backend.platform == "tpu": + self.skipTest("b/371119032: Test fails on TPUs with bfloat16") + finfo = ml_dtypes.finfo(dtype) + eps = finfo.eps + c = self._NewComputation() + # Each row is (value, direction, expected), where + # 'nextafter(value, direction)' should be 'expected'. + data = np.array( + [ + [1, 2, 1 + finfo.eps], + [2, 1, 2 - eps], + [-0., 1, finfo.smallest_subnormal], + [0., -1, -finfo.smallest_subnormal], + [-finfo.smallest_subnormal, 1, -0.], + [finfo.smallest_subnormal, 1, 2 * finfo.smallest_subnormal], + [finfo.smallest_subnormal, -1, 0], + ], + dtype=dtype, + ) + + ops.NextAfter(ops.Constant(c, data[:, 0]), ops.Constant(c, data[:, 1])) + out, = self._Execute(c, ()) + np.testing.assert_equal(out, data[:, 2]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testRegularizedIncompleteBeta(self, dtype): + x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538], + dtype=dtype) + a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606], + dtype=dtype) + b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677], + dtype=dtype) + c = self._NewComputation() + ops.RegularizedIncompleteBeta( + ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x)) + expected = np.array( + [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) + self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2) + + tests.append(SingleOpTest) + + class EmbeddedComputationsTest(ComputationTest): + """Tests for XLA graphs with embedded computations (such as maps).""" + + def _CreateConstantComputation(self, in_dtype, out_dtype): + """Computation (A) -> B that returns a constant 1 for any input.""" + c = self._NewComputation("constant_{}_{}_one".format( + in_dtype.__name__, out_dtype.__name__)) + ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 0, dtype=in_dtype)).with_major_to_minor_layout_if_absent()) + ops.Constant(c, out_dtype(1)) + return c.build() + + def _CreateMulBy2Computation(self, dtype): + """Computation (dtype) -> dtype that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f32_by2") + ops.Mul( + ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 0, dtype=dtype)).with_major_to_minor_layout_if_absent()), + ops.Constant(c, dtype(2.0))) + return c.build() + + def _CreateMulF32ByParamComputation(self): + """Computation (f32) -> f32 that multiplies one parameter by the other.""" + c = self._NewComputation("mul_f32_by_param") + ops.Mul( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))), + ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))) + return c.build() + + def _CreateBinaryAddComputation(self, dtype): + """Computation (dtype, dtype) -> dtype that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + shape = shape.with_major_to_minor_layout_if_absent() + ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + def _CreateBinaryGeComputation(self, dtype): + """Computation (dtype, dtype) -> bool 
that tests param0 >= param1.""" + c = self._NewComputation("param0_lt_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + shape = shape.with_major_to_minor_layout_if_absent() + ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + def _MakeSample3DArray(self, dtype): + return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], + dtype=dtype) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testCall(self, dtype): + c = self._NewComputation() + ops.Call( + c, + self._CreateMulBy2Computation(dtype), + operands=(ops.Constant(c, dtype(5.0)),)) + self._ExecuteAndCompareClose(c, expected=[10.0]) + + @parameterized.named_parameters({ + "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__), + "in_dtype": in_dtype, + "out_dtype": out_dtype, + } for in_dtype, out_dtype in [[np.float32, np.int32]]) + def testMapEachElementToConstant(self, in_dtype, out_dtype): + c = self._NewComputation() + ops.Map(c, + [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))], + self._CreateConstantComputation(in_dtype, out_dtype), [0]) + self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testMapMulBy2(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], + self._CreateMulBy2Computation(dtype), [0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSimpleMapChain(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + # Chains a map of constant-out with a map of mul-by-2 + c = self._NewComputation() + const = ops.Map( + c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], + self._CreateConstantComputation(dtype, dtype), [0]) + ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) + + # TODO(b/154752816): bfloat16 crashes in evaluator. 
+ @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDivVectorsWithMap(self, dtype): + + def DivComputation(): + c = self._NewComputation("div_param0_by_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + c = self._NewComputation() + ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)), + ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))), + DivComputation(), [0]) + self._ExecuteAndCompareClose( + c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSelectAndScatter(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + operand = ops.Constant( + c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype)) + window_dimensions = (2, 1) + window_strides = (1, 2) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, + c.get_shape(operand).dimensions(), window_dimensions, window_strides) + ops.SelectAndScatterWithGeneralPadding( + operand, + select=self._CreateBinaryGeComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)), + init_value=ops.Constant(c, np.array(1, dtype=dtype)), + scatter=self._CreateBinaryAddComputation(dtype)) + self._ExecuteAndCompareClose( + c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduce1DtoScalar(self, dtype): + c = self._NewComputation() + ops.Reduce( + c, + operands=[ + ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)) + ], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=[0]) + self._ExecuteAndCompareClose(c, expected=[10]) + + # TODO(phawkins): test comparison harness doesn't support bfloat16 + @unittest.skipIf(pjrt_c_api, "b/264473047: hangs") + @parameterized.named_parameters({ + "testcase_name": "_{}_dim{}".format(dtype.__name__, dim), + "dtype": dtype, + "dim": dim, + } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2)) + def testReduce2DTo1D(self, dtype, dim): + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + ops.Reduce( + c, + operands=[ops.Constant(c, input_array)], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=[dim]) + self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)]) + + @unittest.skipIf(pjrt_c_api, "b/264473047: hangs") + @parameterized.named_parameters({ + "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims), + "dtype": dtype, + "dims": tuple(dims) + } for dtype in float_dtypes for dims in itertools.permutations(range(3))) + def testReduce3DAllPossibleWaysF32(self, dtype, dims): + input_array = self._MakeSample3DArray(dtype) + c = self._NewComputation() + ops.Reduce( + c, + operands=[ops.Constant(c, input_array)], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + 
dimensions_to_reduce=dims) + self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowValidUnitStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowSameUnitStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.SAME, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowValidGeneralStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 2) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) + + @unittest.skipIf(pjrt_c_api, "b/264473047: hangs") + def testReduceWindowVariadic(self): + c = self._NewComputation("reducer") + shape = xla_client.shape_from_pyval(np.array(0, dtype=np.int32)) + shape = shape.with_major_to_minor_layout_if_absent() + ps = [ops.Parameter(c, i, shape) for i in range(4)] + which = ops.Ge(ps[0], ps[2]) + ops.Tuple( + c, [ops.Select(which, ps[0], ps[2]), + ops.Select(which, ps[1], ps[3])]) + reducer = c.build() + + key_array = np.array([[1, 5, 6], [4, 2, 3]], dtype=np.int32) + val_array = 
np.array([[7, 8, 9], [10, 11, 12]], dtype=np.int32) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, key_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operands=[ops.Constant(c, key_array), + ops.Constant(c, val_array)], + init_values=[ + ops.Constant(c, np.int32(0)), + ops.Constant(c, np.int32(0)) + ], + computation=reducer, + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[4, 5, 6]], [[10, 8, 9]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testWhile(self, dtype): + + def LessThan10Cond(): + c = self._NewComputation("test_lt_10") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.))) + return c.build() + + cond = LessThan10Cond() + body = self._CreateMulBy2Computation(dtype) + c = self._NewComputation() + init = ops.Constant(c, dtype(1.)) + ops.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=[16.]) + + def testConditionalTrue(self): + c = self._NewComputation() + pred = ops.Constant(c, np.bool_(True)) + true_operand = ops.Constant(c, np.float32(3.)) + true_computation = self._CreateMulBy2Computation(np.float32) + false_operand = ops.Constant(c, np.float32(2.)) + false_computation = self._CreateConstantComputation( + np.float32, np.float32) + ops.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=[6.]) + + def testConditionalFalse(self): + c = self._NewComputation() + pred = ops.Constant(c, np.bool_(False)) + true_operand = ops.Constant(c, np.float32(3.)) + true_computation = self._CreateMulBy2Computation(np.float32) + false_operand = ops.Constant(c, np.float32(2.)) + false_computation = self._CreateConstantComputation( + np.float32, np.float32) + ops.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=[1.]) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testInfeedS32Values(self): + to_infeed = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + ops.GetTupleElement( + ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_infeed[0]).with_major_to_minor_layout_if_absent()), 0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + device = self.backend.local_devices()[0] + for item in to_infeed: + device.transfer_to_infeed(item) + + for item in to_infeed: + result, = execute_with_python_values( + compiled_c, (), backend=self.backend) + self.assertEqual(result, item) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testInfeedTuple(self): + to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]])) + c = self._NewComputation() + ops.GetTupleElement( + ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_infeed).with_major_to_minor_layout_if_absent()), 0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + device = self.backend.local_devices()[0] + device.transfer_to_infeed(to_infeed) + + result = execute_with_python_values( + 
compiled_c, (), backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_equal(result[0], to_infeed[0]) + np.testing.assert_equal(result[1], to_infeed[1]) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testInfeedThenOutfeedS32(self): + to_round_trip = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + x_and_token = ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_round_trip[0]).with_major_to_minor_layout_if_absent()) + x = ops.GetTupleElement(x_and_token, 0) + token = ops.GetTupleElement(x_and_token, 1) + outfeed_shape = xla_client.shape_from_pyval( + to_round_trip[0]).with_major_to_minor_layout_if_absent() + ops.OutfeedWithToken(x, token, outfeed_shape) + ops.Tuple(c, ()) + + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + device = self.backend.local_devices()[0] + + for want in to_round_trip: + execution = threading.Thread(target=lambda: compiled_c.execute([])) + execution.start() + device.transfer_to_infeed(want) + got = device.transfer_from_outfeed(outfeed_shape) + execution.join() + self.assertEqual(want, got) + + def testScatter(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + scatter_indices = np.array([0, 2], dtype=np.int32) + updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + + dnums = xla_client.ScatterDimensionNumbers() + dnums.update_window_dims.append(1) + dnums.inserted_window_dims.append(0) + dnums.scatter_dims_to_operand_dims.append(0) + dnums.index_vector_dim = 1 + + c = self._NewComputation() + ops.Scatter( + ops.Constant(c, a), ops.Constant(c, scatter_indices), + ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32), + dnums) + expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], + dtype=np.int32) + self._ExecuteAndCompareClose(c, expected=[expected]) + + class DeviceTest(ComputationTest): + + def testDevices(self): + self.assertNotEmpty(self.backend.devices()) + + def testLocalDevices(self): + self.assertNotEmpty(self.backend.local_devices()) + if self.backend.platform == "cpu": + self.assertLen(self.backend.local_devices(), 2) + + def testGetAllDevices(self): + # TODO(hyeontaek): Remove this method once we have a unified API for + # enumerating devices with different criteria. 
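For readers unfamiliar with ScatterDimensionNumbers, the configuration in testScatter above (update_window_dims=[1], inserted_window_dims=[0], scatter_dims_to_operand_dims=[0], index_vector_dim=1) selects whole rows of the operand by index and combines each with the matching row of `updates` through the add reducer. A rough numpy analogue, for illustration only:

import numpy as np

a = np.arange(9, dtype=np.int32).reshape((3, 3))
scatter_indices = np.array([0, 2], dtype=np.int32)
updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32)

result = a.copy()
for i, row in enumerate(scatter_indices):
  # The add computation combines the existing row with the update row;
  # rows that are never indexed pass through unchanged.
  result[row] += updates[i]

print(result)  # [[10 21 32] [ 3  4  5] [76 87 98]] -- matches `expected` above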
+ self.assertNotEmpty(self.backend._get_all_devices()) # pylint: disable=protected-access + + def testPlatform(self): + for device in self.backend.local_devices(): + self.assertEqual(device.platform, self.backend.platform) + + def testCoreCount(self): + if self.backend.platform != "gpu": + self.skipTest("core_count is only supported on GPU") + for device in self.backend.local_devices(): + self.assertGreater(device.core_count, 0) + + def testLocalHardwareId(self): + for device in self.backend.devices(): + local_hardware_id = device.local_hardware_id + if local_hardware_id is not None: + self.assertGreaterEqual(local_hardware_id, 0) + + @unittest.skipIf(pathways_ifrt, "not implemented") + def testLocalDeviceFromLocalHardwareId(self): + for device in self.backend.local_devices(): + if device.local_hardware_id is not None: + lookup_device = self.backend.device_from_local_hardware_id( + device.local_hardware_id) + self.assertEqual(lookup_device, device) + + @unittest.skipIf(pathways, "not implemented") + @unittest.skipIf(pathways_ifrt, "not implemented") + def testMemoryStats(self): + for device in self.backend.local_devices(): + stats = device.memory_stats() + if ( + self.backend.platform != "tpu" or not tfrt_tpu + ) and self.backend.platform not in ("gpu", "cuda", "rocm"): + self.assertIsNone(stats) + else: + self.assertIsNotNone(stats) + # Spot check a few fields + self.assertEqual(type(stats["num_allocs"]), int) + self.assertGreaterEqual(stats["num_allocs"], 0) + self.assertEqual(type(stats["bytes_in_use"]), int) + self.assertGreaterEqual(stats["bytes_in_use"], 0) + self.assertEqual(type(stats["peak_bytes_in_use"]), int) + self.assertGreaterEqual(stats["peak_bytes_in_use"], 0) + self.assertEqual(type(stats["largest_alloc_size"]), int) + self.assertGreaterEqual(stats["largest_alloc_size"], 0) + + @unittest.skipIf(pathways, "not implemented") + def testMemory(self): + for device in self.backend.local_devices(): + for memory in device.addressable_memories(): + self.assertEqual(memory.process_index, device.process_index) + self.assertEqual(memory.platform, device.platform) + self.assertIn(device, memory.addressable_by_devices()) + self.assertEqual(memory, device.memory(memory.kind)) + + tests.append(DeviceTest) + + class ErrorTest(ComputationTest): + + def setUp(self): + super(ErrorTest, self).setUp() + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.s32_scalar_2 = NumpyArrayS32(2) + + def testCompileWithWrongElementTypeInLayout(self): + c = self._NewComputation() + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) + ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) + c.clear_op_metadata() + + options = xla_client.CompileOptions() + options.argument_layouts = [ + xla_client.Shape.array_shape(np.dtype(np.float32), []) + ] + + def TestFun(): + return self.backend.compile(c.build(), compile_options=options) + + self.assertRaisesRegex( + RuntimeError, r".*Invalid argument shape.*" + r"expected s32\[\], got f32\[\].*", TestFun) + + def testInvokeWithWrongElementType(self): + c = self._NewComputation() + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) + ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) + c.clear_op_metadata() + + def TestFun(): + return execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), + [self.f32_scalar_2], self.backend) + + self.assertRaisesRegex( + RuntimeError, r"Invalid argument: Argument does not match.*" + r"want s32\[\], got f32\[\].*", TestFun) + + 
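As a usage sketch of the device introspection APIs that DeviceTest exercises (memory_stats, addressable_memories), assuming `xla_client` is the module imported at the top of this file:

backend = xla_client.make_cpu_client()
for device in backend.local_devices():
  stats = device.memory_stats()  # None on CPU; a dict of counters on GPU/TPU
  if stats is not None:
    print(stats["bytes_in_use"], stats["peak_bytes_in_use"])
  for memory in device.addressable_memories():
    # Each memory space reports its kind and which devices can address it.
    print(memory.kind, device in memory.addressable_by_devices())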
tests.append(EmbeddedComputationsTest) + + class ComputationRootTest(ComputationTest): + """Tests related to setting the root of the computation.""" + + def testComputationRootDifferentFromLastOp(self): + c = self._NewComputation() + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) + result = ops.Add(x, ops.Constant(c, np.float32(3.14))) + ops.Add(result, ops.Constant(c, np.float32(1.618))) + + arg = NumpyArrayF32(1.0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build(result))) + ans, = execute_with_python_values( + compiled_c, [arg], backend=self.backend) + np.testing.assert_allclose(ans, 4.14) + + tests.append(ComputationRootTest) + + class SetShardingTest(ComputationTest): + """Tests related to set OpSharding.""" + + def testSetSharding(self): + c = self._NewComputation() + sharding = xla_client.OpSharding() + sharding.type = xla_client.OpSharding.Type.REPLICATED + sharding.tile_assignment_dimensions = [1] + sharding.tile_assignment_devices = [0] + c.set_sharding(sharding) + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) + c.clear_sharding() + + result = ops.Add(x, ops.Constant(c, np.float32(3.14))) + ops.Add(result, ops.Constant(c, np.float32(1.618))) + arg = NumpyArrayF32(1.0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build(result))) + ans, = execute_with_python_values( + compiled_c, [arg], backend=self.backend) + np.testing.assert_allclose(ans, 4.14) + + tests.append(SetShardingTest) + + testcase_shapes = [ + (), + (1,), + (2, 3), + (2, 0), + (0, 7), + (4, 1, 2), + (2, 1, 3), + (2, 4, 1), + (3, 1), + (1, 3), + ] + + def FormatShapeAndDtype(shape, dtype): + return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape))) + + class DLPackTest(parameterized.TestCase): + + def setUp(self): + super(DLPackTest, self).setUp() + self.backend = xla_backend() + if self.backend.platform not in ("cpu", "gpu", "cuda", "rocm"): + self.skipTest("DLPack requires CPU or GPU") + self.cpu_backend = ( + self.backend + if self.backend.platform == "cpu" else xla_client.make_cpu_client()) + self.gpu_backend = ( + self.backend + if self.backend.platform in ("gpu", "cuda", "rocm") + else None + ) + + def tearDown(self): + super().tearDown() + del self.backend + del self.cpu_backend + del self.gpu_backend + + @classmethod + def _GetStreamFromDevice(cls, device): + try: + return device.get_stream_for_external_ready_events() + except xla_client.XlaRuntimeError as err: # type: ignore + if "UNIMPLEMENTED" in str(err): + return None + else: + raise + + def _DLPackManagedTensorToBuffer( + self, tensor, use_legacy_api, backend=None + ): + if use_legacy_api: + return xla_client._xla.dlpack_managed_tensor_to_buffer( + tensor, self.cpu_backend, self.gpu_backend + ) + else: + if not backend: + backend = self.backend + device = backend.local_devices()[0] + stream = DLPackTest._GetStreamFromDevice(device) + return xla_client._xla.dlpack_managed_tensor_to_buffer( + tensor, device, stream + ) + + # pylint: disable=g-complex-comprehension + # pyformat: disable + @parameterized.named_parameters( + { + "testcase_name": "{}_gpu={}{}".format( + FormatShapeAndDtype(shape, dtype), + gpu, + "_legacy" if use_legacy_api else "", + ), + "dtype": dtype, + "shape": shape, + "gpu": gpu, + "use_legacy_api": use_legacy_api, + } + for dtype in dlpack_dtypes + for shape in testcase_shapes + for gpu in [False, True] + for use_legacy_api in [False, True] + ) + # pyformat: enable + def testRoundTrip(self, dtype, shape, gpu, 
use_legacy_api): + if gpu and self.gpu_backend is None: + raise unittest.SkipTest("Test not running with GPU support") + backend = self.gpu_backend if gpu else self.cpu_backend + if dtype == np.bool_: + x = np.random.randint(0, 2, size=shape).astype(np.bool_) + else: + x = np.array(np.random.rand(*shape) * 100, dtype=dtype) + buffer = backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + del buffer # Free "buffer" to make sure dlt retains ownership. + self.assertEqual(type(dlt).__name__, "PyCapsule") + y = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api, backend) + np.testing.assert_array_equal( + x.astype(np.uint8) if dtype == np.bool_ else x, np.asarray(y)) + + @parameterized.named_parameters( + { + "testcase_name": "{}".format("_legacy" if use_legacy_api else ""), + "use_legacy_api": use_legacy_api, + } + for use_legacy_api in [False, True] + ) + def testTensorsCanBeConsumedOnceOnly(self, use_legacy_api): + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + buffer = self.backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + + def ConsumeDLPackTensor(): + _ = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api) + + ConsumeDLPackTensor() + self.assertRaisesRegex( + RuntimeError, ".*a DLPack tensor may be consumed at most once.*", + ConsumeDLPackTensor) + + @parameterized.named_parameters( + { + "testcase_name": "{}".format("_legacy" if use_legacy_api else ""), + "use_legacy_api": use_legacy_api, + } + for use_legacy_api in [False, True] + ) + def testNonOwnedDlpackCanBeViewedTwice(self, use_legacy_api): + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + buffer = self.backend.buffer_from_pyval(x) + d1 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + d2 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + + y = self._DLPackManagedTensorToBuffer(d1, use_legacy_api) + z = self._DLPackManagedTensorToBuffer(d2, use_legacy_api) + del d1, d2 + np.testing.assert_array_equal(x, np.asarray(buffer)) + np.testing.assert_array_equal(x, np.asarray(y)) + np.testing.assert_array_equal(x, np.asarray(z)) + + @parameterized.parameters(False, True) + def testZeroCopyOnAlignedDlpackTensor(self, use_legacy_api): + # Using CPU only, since this test is about CPU memory alignment. + if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + # Create a numpy array that is aligned to XLA requirements. + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + x = _Aligned(x) + + # Convert it to a DLPack tensor, and then to an XLA buffer. + dlpack_tensor = x.__dlpack__() + buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api) + y = np.array(buffer, copy=False) + + # The input was sufficiently aligned, so input and output should alias. + x_ptr = x.__array_interface__["data"][0] + y_ptr = y.__array_interface__["data"][0] + self.assertEqual( + x_ptr, + y_ptr, + msg=f"Buffers are not aliased ({hex(x_ptr)} != {hex(y_ptr)}).", + ) + + @parameterized.named_parameters( + { + "testcase_name": "{}{}".format( + "_legacy" if use_legacy_api else "", + "_transpose" if transpose else "", + ), + "use_legacy_api": use_legacy_api, + "transpose": transpose, + } + for use_legacy_api in [False, True] + for transpose in [False, True] + ) + def testReturnCopyOnUnalignedDlpackTensor(self, use_legacy_api, transpose): + # Using CPU only, since this test is about CPU memory alignment. 
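The aligned/unaligned DLPack tests here rely on the _Aligned and _Unaligned helpers defined earlier in this file (not shown in this hunk). A minimal sketch of how such helpers can be written in plain numpy; the names and the 16-byte constant below are assumptions for illustration, not the file's actual definitions:

import numpy as np

MIN_ALIGNMENT = 16  # assumed: the smallest alignment XLA's CPU backend accepts

def make_aligned(x, alignment=MIN_ALIGNMENT):
  # Copy `x` into a buffer whose data pointer is a multiple of `alignment`.
  raw = np.empty(x.nbytes + alignment, dtype=np.uint8)
  offset = (-raw.ctypes.data) % alignment
  out = raw[offset:offset + x.nbytes].view(x.dtype).reshape(x.shape)
  out[...] = x
  return out

def make_unaligned(x, alignment=MIN_ALIGNMENT):
  # Copy `x` so that its data pointer is deliberately not a multiple of `alignment`.
  raw = np.empty(x.nbytes + alignment + 1, dtype=np.uint8)
  offset = (-raw.ctypes.data) % alignment + 1
  out = raw[offset:offset + x.nbytes].view(x.dtype).reshape(x.shape)
  out[...] = x
  return out

print(make_unaligned(np.arange(12, dtype=np.float32)).ctypes.data % MIN_ALIGNMENT)  # non-zero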
+ if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + if transpose and use_legacy_api: + self.skipTest("Non-default layout is not supported in legacy API") + + # Create a numpy array that is not aligned to XLA requirements. XLA's + # alignment requirements differ for different hardware, so we use the + # smallest possible value. If we make sure the buffer is not aligned to + # this value (16 bytes), then it is also not aligned to its multiples (32, + # 64 etc.) + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + x = _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT) + + # Transpose the array to test non-default layout with trivial striding. + if transpose: + x = x.transpose((0, 2, 1, 3)) + + # Convert it to a DLPack tensor, and then to an XLA buffer. + dlpack_tensor = x.__dlpack__() + buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api) + y = np.array(buffer, copy=False) + + # The input was not sufficiently aligned, so input and output should not + # alias (output should be a copy of input, and it should be aligned). + x_ptr = x.__array_interface__["data"][0] + y_ptr = y.__array_interface__["data"][0] + self.assertNotEqual( + x_ptr, + y_ptr, + msg=( + f"Buffers aliased, but should not be ({hex(x_ptr)} ==" + f" {hex(y_ptr)})" + ), + ) + self.assertEqual( + y_ptr % _XLA_CPU_MIN_ALIGNMENT, + 0, + msg="Output buffer not aligned: {hex(y_ptr)}", + ) + np.testing.assert_array_equal(y, x) + + tests.append(DLPackTest) + + class BufferProtocolTest(parameterized.TestCase): + + def setUp(self): + super(BufferProtocolTest, self).setUp() + self.backend = xla_backend() + if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + # pylint: disable=g-complex-comprehension + @parameterized.named_parameters({ + "testcase_name": FormatShapeAndDtype(shape, dtype), + "dtype": dtype, + "shape": shape + } for dtype in standard_dtypes if dtype != bfloat16 + for shape in testcase_shapes) + def testRoundTrip(self, dtype, shape): + x = np.array(np.random.rand(*shape) * 100, dtype=dtype) + + x = _Aligned(x) + x_ptr = x.__array_interface__["data"][0] + buffer = self.backend.buffer_from_pyval( + x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY) + y = np.array(buffer, copy=False) + y_ptr = y.__array_interface__["data"][0] + np.testing.assert_array_equal(x, y) + + # The input was sufficiently aligned, so input and output should alias. + self.assertEqual(x_ptr, y_ptr) + self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer()) + + during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL + buffer2 = self.backend.buffer_from_pyval( + x, host_buffer_semantics=during_call) + z = np.array(buffer2, copy=False) + self.assertNotEqual(x.__array_interface__["data"][0], + z.__array_interface__["data"][0]) + + def testDeleteWithActiveView(self): + x = np.random.randn(20, 10) + buffer = self.backend.buffer_from_pyval(x) + buffer_ptr = buffer.unsafe_buffer_pointer() + y = np.array(buffer, copy=False) + buffer.delete() + # It is still legal to access `y`; the array view must keep it alive. 
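A short sketch of the zero-copy path that BufferProtocolTest exercises, assuming `xla_client` is the module imported at the top of this file; whether the pointers actually alias depends on the numpy buffer meeting XLA's CPU alignment requirement:

import numpy as np

backend = xla_client.make_cpu_client()
x = np.arange(16, dtype=np.float32)

buf = backend.buffer_from_pyval(
    x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
y = np.array(buf, copy=False)  # numpy view over the buffer's memory

x_ptr = x.__array_interface__["data"][0]
print(x_ptr == buf.unsafe_buffer_pointer())  # True when zero-copy succeeded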
+ np.testing.assert_array_equal(x, y) + self.assertEqual(y.__array_interface__["data"][0], buffer_ptr) + + tests.append(BufferProtocolTest) + + class TracebackTest(absltest.TestCase): + + def setUp(self): + super(TracebackTest, self).setUp() + self.backend = xla_backend() + + def testNoTracebacksIfDisabled(self): + with xla_client.tracebacks(enabled=False): + self.assertEqual(None, xla_client.Traceback.get_traceback()) + buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) + self.assertEqual(None, buffer.traceback) + + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + e = self.backend.compile(xla_computation_to_mlir_module(b.build())) + self.assertEqual(None, e.traceback) + + def assertIsTracebackContaining(self, tb, function): + self.assertIsInstance(tb, xla_client.Traceback) + self.assertIn(function, str(tb)) + self.assertTrue(any(f.function_name == function for f in tb.frames)) + + def testTracebacks(self): + with xla_client.tracebacks(enabled=True): + tb = xla_client.Traceback.get_traceback() + self.assertIsTracebackContaining(tb, "testTracebacks") + + # Tracebacks are not implemented on the TPU driver extension's variant + # of buffers and executables. + if not isinstance(self.backend, xla_client.Client): + return + + buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) + self.assertIsTracebackContaining(buffer.traceback, "testTracebacks") + + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + e = self.backend.compile(xla_computation_to_mlir_module(b.build())) + self.assertIsTracebackContaining(e.traceback, "testTracebacks") + + def testNestedFunction(self): + + def AFunction(): + + def AnotherFunction(): + return xla_client.Traceback.get_traceback() + + return AnotherFunction() + + with xla_client.tracebacks(enabled=True): + tb = AFunction() + self.assertIsInstance(tb, xla_client.Traceback) + frames = tb.frames + i = next( + i for (i, f) in enumerate(frames) if f.function_name == "AFunction") + self.assertEqual(frames[i - 1].function_name, "AnotherFunction") + self.assertEqual(frames[i + 1].function_name, "testNestedFunction") + + def testPythonTracebackHasCorrectLineNumbers(self): + def B(): + return xla_client.Traceback.get_traceback() + + def A(): + return B() + + tb = A().as_python_traceback() + for frame, lineno in traceback.walk_tb(tb): + if frame.f_code.co_name == "A": + line = A.__code__.co_firstlineno + self.assertBetween(lineno, line, line + 2) + elif frame.f_code.co_name == "B": + line = B.__code__.co_firstlineno + self.assertBetween(lineno, line, line + 2) + + def testAccessingLocalsDoesNotCrash(self): + # https://github.com/google/jax/issues/16027 + tb = xla_client.Traceback.get_traceback() + python_tb = tb.as_python_traceback() + for frame, _ in traceback.walk_tb(python_tb): + _ = frame.f_locals # should not crash + + def testTracebackFromFrames(self): + def FooFn(x): + return x + 1 + + def BarFn(y): + y = y + 1 + y = y + 2 + return y * 2 + + frame_foo = xla_client.Frame( + __file__, + FooFn.__code__.co_name, + FooFn.__code__.co_firstlineno, + FooFn.__code__.co_firstlineno + 1, + ) + frame_bar = xla_client.Frame( + __file__, + BarFn.__code__.co_name, + BarFn.__code__.co_firstlineno, + BarFn.__code__.co_firstlineno + 2, + ) + frames = [frame_foo, frame_bar] + tb = xla_client.Traceback.traceback_from_frames(frames) + + with self.subTest("WalkDoesNotError"): + for frame, _ in traceback.walk_tb(tb): + _ = frame.f_locals # should 
not crash + + with self.subTest("TracebackCorrectness"): + tb_string = traceback.format_tb(tb) + # The traceback should have the format: + # File , line N in BarFn + # y = y + 2 + # File , line N in FooFn + # return x + 1 + self.assertLen(tb_string, len(frames)) + bar_frame = tb_string[0].split("\n") + self.assertEndsWith(bar_frame[0], "BarFn") + self.assertEqual(bar_frame[1].strip(), "y = y + 2") + foo_frame = tb_string[1].split("\n") + self.assertEndsWith(foo_frame[0], "FooFn") + self.assertEqual(foo_frame[1].strip(), "return x + 1") + + tests.append(TracebackTest) + + class ClientTest(ComputationTest): + + def setUp(self): + super(ClientTest, self).setUp() + self.backend = xla_backend() + + def testPlatformVersion(self): + version = self.backend.platform_version + logging.info("platform_version:\n%s", version) + if self.backend.platform == "cpu": + self.assertEqual(version, "cpu") + elif self.backend.platform in ("gpu", "cuda", "rocm"): + # Following is false if not built with --config=cuda + if version != "": + self.assertTrue( + re.match(r"^cuda \d{4,}$", version), + msg=f"Expected CUDA version string; got {repr(version)}") + elif self.backend.platform == "tpu" and not (pathways or pathways_ifrt): + self.assertIn("tpu", version.lower()) + self.assertIn("cl/", version) + self.assertIn("Built on ", version) + + @unittest.skipIf( + not cloud_tpu and not pjrt_c_api, "PJRT version only exist for plugins" + ) + def testPjRtCApiVersion(self): + self.assertGreaterEqual(self.backend.pjrt_c_api_major_version, 0) + self.assertGreaterEqual(self.backend.pjrt_c_api_minor_version, 0) + + @unittest.skipUnless( + not pjrt_c_api and tfrt_tpu, + "Test that attributes are zero for non-plugin tfrt_tpu", + ) + def testStaticTfrtTpuAttributes(self): + self.assertEqual(self.backend.pjrt_c_api_major_version, 0) + self.assertEqual(self.backend.pjrt_c_api_minor_version, 0) + # CL number is defined as -1 when running as test. 
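The traceback helpers exercised by TracebackTest can also be driven directly; a small sketch, again assuming `xla_client` is the module imported at the top of this file:

import traceback

with xla_client.tracebacks(enabled=True):
  tb = xla_client.Traceback.get_traceback()

print(str(tb))  # human-readable frame list
for frame, lineno in traceback.walk_tb(tb.as_python_traceback()):
  # Each reconstructed frame carries the original file, function name and line.
  print(frame.f_code.co_name, lineno)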
+ self.assertEqual(self.backend.__getattr__("cl_number"), -1) + + @unittest.skipIf( + cloud_tpu or pjrt_c_api or (not pjrt_c_api and tfrt_tpu), + "PJRT version only exist for plugins", + ) + def testNotExistPjRtCApiVersion(self): + with self.assertRaises(AttributeError): + self.backend.pjrt_c_api_major_version # pylint: disable=pointless-statement + with self.assertRaises(AttributeError): + self.backend.pjrt_c_api_minor_version # pylint: disable=pointless-statement + + @unittest.skipIf(pathways or pathways_ifrt, "has different behavior") + def testPluginProgramDoesNotCompile(self): + program = xla_client.ifrt_programs.make_plugin_program("foobar") + options = xla_client.ifrt_programs.make_plugin_compile_options() + with self.assertRaisesRegex( + xla_client.XlaRuntimeError, "PjRtCompiler requires an HloProgram" + ): + self.backend.compile_ifrt_program(program, options) + + @unittest.skipIf(pathways, "does not work with non-ifrt legacy pathways") + def testHloProgramViaIfrtProgram(self): + c = self._NewComputation() + ops.Iota(c, xla_client.PrimitiveType.F32, 10) + program = xla_client.ifrt_programs.make_hlo_program( + xla_computation_to_mlir_module(c.build()) + ) + options = xla_client.ifrt_programs.make_xla_compile_options( + xla_client.CompileOptions(), [] + ) + + compiled_c = self.backend.compile_ifrt_program(program, options) + results = execute_with_python_values( + compiled_c, arguments=(), backend=self.backend + ) + + self.assertLen(results, 1) + np.testing.assert_equal(results[0], np.arange(10, dtype=np.float32)) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or tfrt_tpu, + "not implemented") + def testExecutableSerialization(self): + if self.backend.platform != "tpu": + self.skipTest("Test requires tpu platform") + + c = self._NewComputation() + ops.Add( + ops.Constant(c, NumpyArrayS32([1, 2])), + ops.Constant(c, NumpyArrayS32([3, 4]))) + + options = xla_client.CompileOptions() + executable = self.backend.compile( + xla_computation_to_mlir_module(c.build()), options) + self.assertLen(executable.hlo_modules(), 1) + + serialized = self.backend.serialize_executable(executable) + deserialized = self.backend.deserialize_executable(serialized, options) + + expected, = execute_with_python_values(executable, (), self.backend) + actual, = execute_with_python_values(deserialized, (), self.backend) + self.assertTrue(np.all(actual == expected)) + + def testCompileOptionsSerialization(self): + options = xla_client.CompileOptions() + executable_build_options = options.executable_build_options + options.num_replicas = 3 + options.num_partitions = 2 + options.profile_version = 1337 + options.compile_portable_executable = True + executable_build_options.num_replicas = 3 + executable_build_options.num_partitions = 2 + deb_opt = executable_build_options.debug_options + deb_opt.xla_cpu_enable_fast_math = True + deb_opt.xla_test_all_input_layouts = True + deb_opt.xla_gpu_kernel_cache_file = "/foo/bar" + deb_opt.xla_gpu_enable_llvm_module_compilation_parallelism = True + deb_opt.xla_gpu_per_fusion_autotune_cache_dir = "/bar/foo/" + deb_opt.xla_gpu_experimental_autotune_cache_mode = ( + xla_client.AutotuneCacheMode.READ + ) + + b = options.SerializeAsString() + restored = xla_client.CompileOptions.ParseFromString(b) + + for name in ("num_replicas", "num_partitions", "profile_version", + "compile_portable_executable"): + self.assertEqual(getattr(options, name), getattr(restored, name), + msg=name) + + for name in ("num_replicas", "num_partitions"): + 
self.assertEqual(getattr(options.executable_build_options, name), + getattr(restored.executable_build_options, name), + msg=name) + + for name in ( + "xla_cpu_enable_fast_math", + "xla_test_all_input_layouts", + "xla_gpu_kernel_cache_file", + "xla_gpu_enable_llvm_module_compilation_parallelism", + "xla_gpu_per_fusion_autotune_cache_dir", + "xla_gpu_experimental_autotune_cache_mode", + ): + self.assertEqual( + getattr(options.executable_build_options.debug_options, name), + getattr(restored.executable_build_options.debug_options, name), + msg=name) + + tests.append(ClientTest) + + # TODO(b/182461453): Add TFRT and cloud TPU implementation of + # ReadDynamicShapes + @unittest.skip("Test fails HLO -> MHLO conversion") + class DynamicReshapeTest(ComputationTest): + """Tests related to DynamicReshape.""" + + def _CompareToPyAndBufferProtocol(self, builder, args, expected_results, + test_fn): + compiled = self.backend.compile( + xla_computation_to_mlir_module(builder.build())) + output_buffers = compiled.execute([ + self.backend.buffer_from_pyval( + arg, device=compiled.local_devices()[0]) for arg in args + ]) + self.assertLen(output_buffers, len(expected_results)) + for buf, expected in zip(output_buffers, expected_results): + to_py_result = np.asarray(buf) + self.assertEqual(expected.shape, to_py_result.shape) + test_fn(expected, to_py_result) + if self.backend.platform == "cpu" and buf.dtype != bfloat16: + mview = memoryview(buf) + self.assertEqual(expected.shape, mview.shape) + test_fn(expected, np.asarray(mview)) + else: + # Buffer protocol expected to fail on non-cpu platforms and bfloat16 + # Note that np.asarray(buf) doesn't throw an exception. To test if the + # error was thrown properly we must use memoryview(buf). + with self.assertRaises(BufferError): + memoryview(buf) + + # 1D reshape of full size, half size, and size of 0. + @unittest.skip("not implemented") + @parameterized.parameters((5), (3), (0)) + def testReshape1D(self, reshape_size): + full_size = 5 + c = self._NewComputation() + arg = np.array(reshape_size, dtype=np.int32) + expected = np.array(range(reshape_size), dtype=np.int32) + p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) + ops.DynamicReshape( + ops.Constant(c, NumpyArrayS32(range(full_size))), [p], [full_size], + [True]) + self._CompareToPyAndBufferProtocol(c, [arg], [expected], + np.testing.assert_equal) + + # 2D reshape with an slice on the minor dimension. We test different types + # where the strides may differ between the host and devices. The reshaped + # physical memory layout is not consecutive, and we test if the program can + # return the correct logical view of the data. 
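In plain numpy terms, the logical view that the 2D dynamic-reshape test below expects is just a slice of the minor dimension:

import numpy as np

arg0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
dynamic_size = 2  # runtime value fed to the dynamic minor dimension

# DynamicReshape with dims [2, 3] and dynamic_dimensions [False, True] keeps
# the first `dynamic_size` elements along the minor dimension of each row.
print(arg0[:, :dynamic_size])  # [[1 2] [4 5]] -- matches `expected` in testReshape2D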
+ @unittest.skipIf( + cloud_tpu or pathways or tfrt_tpu or pjrt_c_api, + "not implemented") + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testReshape2D(self, dtype): + arg0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + arg1 = np.array(2, dtype=np.int32) + expected = np.array([[1, 2], [4, 5]], dtype=np.int32) + c = self._NewComputation() + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + ops.DynamicReshape(p0, [p1, p1], [2, 3], [False, True]) + self._CompareToPyAndBufferProtocol(c, [arg0, arg1], [expected], + np.testing.assert_equal) + + @unittest.skipIf(cloud_tpu or pathways or tfrt_tpu, "not implemented") + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testDynamicShapeArgs(self, dtype): + full_size = 10 + dynamic_shape_size = 4 + # subcomputation 1 + binary_add_builder = self._NewComputation() + scalar_shape = xla_client.Shape.scalar_shape(np.dtype(dtype)) + ops.Add( + ops.Parameter(binary_add_builder, 0, scalar_shape), + ops.Parameter(binary_add_builder, 1, scalar_shape)) + # subcomputation 2 + reshape_reduce_builder = self._NewComputation() + dshape = xla_client.Shape.array_shape( + np.dtype(dtype), dims=[full_size], dynamic_dimensions=[True]) + reshape_reduce_p = ops.Parameter(reshape_reduce_builder, 0, dshape) + ops.Reduce( + reshape_reduce_builder, + operands=[reshape_reduce_p], + init_values=[ops.Constant(reshape_reduce_builder, dtype(0))], + computation=binary_add_builder.build(), + dimensions_to_reduce=[0]) + # main computation: sum(range(full_size)[:dynamic_shape_size]) + c = self._NewComputation() + arg = np.array(dynamic_shape_size, dtype=np.int32) + p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) + reshaped = ops.DynamicReshape( + ops.Constant(c, np.array(range(full_size), dtype=dtype)), [p], + [full_size], [True]) + ops.Call(c, reshape_reduce_builder.build(), operands=(reshaped,)) + self._ExecuteAndCompareClose(c, [arg], [dtype(6)]) + + tests.append(DynamicReshapeTest) + + class DeviceAssignmentTest(ComputationTest): + + def testSerialize(self): + shape = (3, 4) + device_assignment = xla_client.DeviceAssignment.create( + np.arange(np.prod(shape)).reshape(*shape)) + self.assertEqual(device_assignment.replica_count(), shape[0]) + self.assertEqual(device_assignment.computation_count(), shape[1]) + serialized = device_assignment.serialize() + self.assertIsInstance(serialized, bytes) + self.assertNotEmpty(serialized) + + tests.append(DeviceAssignmentTest) + + class TokenTest(ComputationTest): + """Tests related to PyToken.""" + + def testExecuteWithToken(self): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + results, token = compiled_c.execute_with_token([]) + token.block_until_ready() + self.assertLen(results, 1) + np.testing.assert_allclose( + np.asarray(results[0]), np.float32([-3, 6.6, 2.4, -2.1]), rtol=3e-3) + + def testExecuteShardedOnLocalDevicesWithTokens(self): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) + num_replicas = 1 + options = xla_client.CompileOptions() 
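A usage sketch for the DeviceAssignment factory covered by DeviceAssignmentTest above, assuming `xla_client` is the module imported at the top of this file:

import numpy as np

shape = (3, 4)  # (replica_count, computation_count)
assignment = xla_client.DeviceAssignment.create(
    np.arange(np.prod(shape)).reshape(*shape))
print(assignment.replica_count(), assignment.computation_count())  # 3 4
serialized = assignment.serialize()  # serialized proto bytes
print(type(serialized), len(serialized) > 0)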
+ options.num_replicas = num_replicas + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + py_results = compiled_c.execute_sharded([], with_tokens=True) + results = py_results.disassemble_into_single_device_arrays() + sharded_token = py_results.consume_token() + sharded_token.block_until_ready() + self.assertLen(results, 1) + self.assertLen(results[0], 1) + np.testing.assert_allclose( + np.asarray(results[0][0]), + np.float32([-3, 6.6, 2.4, -2.1]), + rtol=3e-3) + + tests.append(TokenTest) + + class ExecutePortableTest(ComputationTest): + + @unittest.skip("Test does not work under IFRT") + def testExecutePortable(self): + devices_by_kind = collections.defaultdict(list) + for device in self.backend.devices(): + devices_by_kind[device.device_kind].append(device) + multi_devices = [d for d in devices_by_kind.values() if len(d) > 1] + if not multi_devices: + raise unittest.SkipTest("Test needs multiple identical devices") + devices = multi_devices[0] + + c = self._NewComputation() + args = [ + np.array(3, dtype=np.int32), + np.array([10, 15, -2, 7], dtype=np.int32) + ] + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(args[0])) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(args[1])) + ops.Mul(p0, p1) + options = xla_client.CompileOptions() + options.compile_portable_executable = True + compiled_c = self.backend.compile(c.build(), compile_options=options) + for device in devices: + out, = compiled_c.execute( + [self.backend.buffer_from_pyval(a, device=device) for a in args], + device=device) + np.testing.assert_array_equal(np.asarray(out), args[0] * args[1]) + + tests.append(ExecutePortableTest) + + class ExecuteShardedOverloadTest(ComputationTest): + + def testExecuteShardedOverloadEmptyInput(self): + c = self._NewComputation() + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)) + options = xla_client.CompileOptions() + options.num_replicas = 1 + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + + results = compiled_c.execute_sharded( + []).disassemble_into_single_device_arrays() + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + results = compiled_c.execute_sharded( + [], with_tokens=True).disassemble_into_single_device_arrays() + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + def testExecuteShardedOverloadBufferInput(self): + arg = np.arange(12, dtype=np.int16).reshape(3, 4) + c = self._NewComputation() + ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) + + options = xla_client.CompileOptions() + options.num_replicas = 1 + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + + buffer = self.backend.buffer_from_pyval(arg) + + results = compiled_c.execute_sharded( + [[buffer]]).disassemble_into_single_device_arrays() + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + results = compiled_c.execute_sharded( + [[buffer]], with_tokens=True).disassemble_into_single_device_arrays() + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + 
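The pattern behind these execute_sharded tests, as a sketch that assumes `xla_client`, the `ops` alias, and the xla_computation_to_mlir_module helper used throughout this file:

import numpy as np

backend = xla_client.make_cpu_client()
c = xla_client.XlaBuilder("scale")
ops.Mul(ops.Constant(c, np.float32(2.0)), ops.Constant(c, np.float32(3.0)))

options = xla_client.CompileOptions()
options.num_replicas = 1
compiled = backend.compile(
    xla_computation_to_mlir_module(c.build()), compile_options=options)

sharded = compiled.execute_sharded([], with_tokens=True)
per_output = sharded.disassemble_into_single_device_arrays()  # [output][device]
sharded.consume_token().block_until_ready()
print(np.asarray(per_output[0][0]))  # 6.0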
self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + tests.append(ExecuteShardedOverloadTest) + + return tests + + +def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): + # Avoid creating a new backend per test (this causes GPU OOM, and is probably + # inefficient). + backend_fn = functools.lru_cache(maxsize=None)(backend_fn) + for klass in TestFactory(backend_fn, **kw): + test = type(test_prefix + klass.__name__, (klass,), {}) + # Clean up the qualified names of the tests to not include the test factory. + test.__qualname__ = test.__name__ + globals_dict[test.__name__] = test + + +backends = { + "cpu": functools.partial(xla_client.make_cpu_client, num_devices=2), +} + +if __name__ == "__main__": + flags.DEFINE_string("backend", "cpu", "Target platform.") + jax.config.parse_flags_with_absl() + # pylint: disable=unnecessary-lambda + InstantiateTests(globals(), lambda: backends[FLAGS.backend]()) + # pylint: enable=unnecessary-lambda + absltest.main() diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc new file mode 100644 index 000000000000..bea3062c64e4 --- /dev/null +++ b/jaxlib/xla/xla_compiler.cc @@ -0,0 +1,1640 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla/xla_compiler.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/dlpack.h" +#include "jaxlib/xla/py_client.h" +#include "xla/array.h" +#include "xla/client/executable_build_options.h" +#include "xla/debug_options_flags.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_print_options.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/service/call_inliner.h" +#include "xla/service/computation_placer.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_graph_dumper.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/name_uniquer.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" + +namespace nanobind { +namespace detail { + +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::OpMetadata, + const_name("xla::OpMetadata")); + + bool from_python(handle h, uint8_t, cleanup_list*) noexcept { + handle op_type = getattr(h, "op_type"); + if (!op_type.is_none()) { + value.set_op_type(cast(op_type)); + } + handle op_name = getattr(h, "op_name"); + if (!op_name.is_none()) { + value.set_op_name(cast(op_name)); + } + handle source_file = getattr(h, "source_file"); + if (!source_file.is_none()) { + value.set_source_file(cast(source_file)); + } + handle source_line = getattr(h, "source_line"); + if 
(!source_line.is_none()) { + value.set_source_line(cast(source_line)); + } + return true; + } +}; + +} // namespace detail +} // namespace nanobind + +namespace xla { +namespace { + +namespace nb = nanobind; + +struct Uniquer { + absl::Mutex mu; + NameUniquer name_uniquer ABSL_GUARDED_BY(mu); +}; + +Uniquer* GetUniquer() { + static Uniquer* uniquer = new Uniquer; + return uniquer; +} + +static std::string UniquifyName(const std::string& name) { + Uniquer* uniquer = GetUniquer(); + absl::MutexLock lock(&uniquer->mu); + return uniquer->name_uniquer.GetUniqueName(name); +} + +// Converts a computation to a serialized HloModuleProto. +absl::StatusOr GetComputationSerializedProto( + const XlaComputation& computation) { + std::string result; + if (!tsl::SerializeToStringDeterministic(computation.proto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a hlo module to a serialized HloModuleProto. +absl::StatusOr GetHloModuleSerializedProto(const HloModule& module) { + std::string result; + if (!tsl::SerializeToStringDeterministic(module.ToProto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a serialized HloModuleProto into a HloModule. +absl::StatusOr> HloModuleFromSerializedProto( + const nb::bytes& bytes) { + HloModuleProto proto; + proto.ParseFromArray(bytes.c_str(), bytes.size()); + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + proto, GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(proto, module_config)); + return std::shared_ptr(std::move(module)); +} + +absl::StatusOr> GetHloModule( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(computation.proto(), module_config)); + return std::shared_ptr(std::move(module)); +} + +// Converts a computation to textual HLO form. +absl::StatusOr GetComputationHloText( + const XlaComputation& computation, bool print_large_constants = false) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(print_large_constants); + return hlo_module->ToString(options); +} + +// Converts a computation to HLO dot graph form. +absl::StatusOr GetComputationHloDotGraph( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return RenderGraph(*hlo_module->entry_computation(), /*label=*/"", + hlo_module->config().debug_options(), + RenderedGraphFormat::kDot); +} + +// Hashes the HLO module. +absl::StatusOr HashComputation(const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return absl::HashOf(*hlo_module); +} +// Safe version of ShapeUtil::MakeShapeWithDenseLayout that fails gracefully on +// invalid input. 
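On the Python side, the helpers above surface as XlaComputation methods (bound further down in this file); a small sketch, assuming `xla_client` and the `ops` alias used throughout the test file above:

import numpy as np

b = xla_client.XlaBuilder("add")
ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
computation = b.build()

print(computation.as_hlo_text())                      # GetComputationHloText
proto = computation.as_serialized_hlo_module_proto()  # deterministic proto bytes
rebuilt = xla_client.XlaComputation(proto)            # HloModuleProto -> XlaComputation
print(rebuilt.as_hlo_text() == computation.as_hlo_text())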
+absl::StatusOr MakeShapeWithDenseLayout( + PrimitiveType element_type, absl::Span dims, + std::optional> minor_to_major, + std::optional> dynamic_dimensions) { + Shape shape; + if (dynamic_dimensions) { + TF_ASSIGN_OR_RETURN( + shape, ShapeUtil::MakeValidatedShape(element_type, dims, + dynamic_dimensions.value())); + } else { + TF_ASSIGN_OR_RETURN(shape, + ShapeUtil::MakeValidatedShape(element_type, dims)); + } + if (minor_to_major) { + *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major); + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(shape.layout(), shape)); + } + + return shape; +} + +// Pybind function for HloSharding.iota_tile, which is a non-crashing factory +// that produces a HloSharding instance backed by tile assignment of a +// transposed and reshaped iota array of device ids. More specifically the tile +// assignment array is as if it is produced by the following numpy code: +// numpy.arange(math.prod(dims)).reshape(reshape_dims) +// .transpose(transpose_perm).reshape(math.prod(dims)) +// where: +// `dims`: is the dimensions of the tile assignment array, which corresponds to +// OpSharding.tile_assignment_dimensions. +// `reshape_dims`: is the dimensions the 1D iota array is reshaped to. +// `transpose_perm`: is the dimension permutation to transpose `reshape_dims`. +// `subgroup_types`: indicates the subgroups of the last `subgroup_types.size()` +// dimensions in `dims`. +// +// In practice, `reshape_dims` often maps to the axises of user defined device +// mesh, and `transpose_perm` often maps to the user specification of how a +// tensor is partitioned based on the axes defined in the mesh, e.g. for a mesh +// of size 4x2x2 as AxBxC: +// PartitionSpec('A', 'B', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[0,1,2] (no transpose) +// PartitionSpec('B', 'A', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[1,0,2] (swap A and B) +absl::StatusOr IotaTileHelper( + absl::Span dims, absl::Span reshape_dims, + absl::Span transpose_perm, + absl::Span subgroup_types) { + if (dims.empty()) { + return InvalidArgument("`dims` should not be empty."); + } + if (reshape_dims.size() != transpose_perm.size()) { + return InvalidArgument( + "`reshape_dims` and `transpose_perm` should have the same size, saw " + "[%s] v.s. [%s]", + absl::StrJoin(reshape_dims, ","), absl::StrJoin(transpose_perm, ",")); + } + if (!reshape_dims.empty() && Product(dims) != Product(reshape_dims)) { + return InvalidArgument( + "Cannot reshape from `dims` [%s] to `reshape_dims` [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(reshape_dims, ",")); + } + if (subgroup_types.size() > dims.size()) { + return InvalidArgument( + "`subgroup_types`(%lld) should not have more dimensions than " + "`dims`(%lld).", + subgroup_types.size(), dims.size()); + } + if (reshape_dims.empty()) { + return subgroup_types.empty() + ? HloSharding::IotaTile(dims) + : HloSharding::Subgroup(TileAssignment(dims), subgroup_types); + } + return subgroup_types.empty() + ? HloSharding::IotaTile(dims, reshape_dims, transpose_perm) + : HloSharding::Subgroup( + TileAssignment(dims, reshape_dims, transpose_perm), + subgroup_types); +} + +// Registers a 'fn' as a custom call target. +// +// `fn` must be a custom call implementation function pointer (XLA_FFI_Handler* +// when implemented as FFI handler) encapsulated in a PyCapsule object or a +// a dictionary of function pointers (also encapsulated in a PyCapsule). 
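The numpy recipe in the IotaTileHelper comment above can be spelled out concretely; for the 4x2x2 AxBxC mesh example, PartitionSpec('B', 'A', 'C') yields:

import math
import numpy as np

mesh = (4, 2, 2)            # device mesh A x B x C
reshape_dims = mesh
transpose_perm = (1, 0, 2)  # PartitionSpec('B', 'A', 'C'): swap A and B
dims = tuple(mesh[p] for p in transpose_perm)  # tile assignment dims (2, 4, 2)

tile_assignment = (np.arange(math.prod(dims))
                   .reshape(reshape_dims)
                   .transpose(transpose_perm)
                   .reshape(dims))
print(tile_assignment.shape)     # (2, 4, 2)
print(tile_assignment[0, :, 0])  # [ 0  4  8 12]: stepping along A advances by 4 devices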
+// +// See XLA_FFI_ExecutionStage documentation for more details about the +// custom execution stages. +absl::Status PyRegisterCustomCallTarget(const std::string& fn_name, + nb::object fn, + const std::string& platform, + int api_version, + XLA_FFI_Handler_Traits traits) { + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + if (traits != 0) { + return absl::InvalidArgumentError( + "Custom call target registration with traits is not supported for " + "api_version=0"); + } + + nb::capsule capsule; + if (!nb::try_cast(fn, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=0 requires a " + "PyCapsule fn object"); + } + + CustomCallTargetRegistry::Global()->Register( + fn_name, static_cast(capsule.data()), platform); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + nb::capsule capsule; + if (nb::try_cast(fn, capsule)) { + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, + reinterpret_cast( + static_cast(capsule.data())))); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + + nb::capsule capsule; + if (!nb::try_cast(bundle[name], capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=1 requires a " + "PyCapsule fn object for all dict keys"); + } + + return reinterpret_cast(capsule.data()); + }; + + XLA_FFI_Handler_Bundle bundle; + TF_ASSIGN_OR_RETURN(bundle.instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(bundle.prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(bundle.initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(bundle.execute, handler("execute")); + + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, bundle, traits)); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. 
" + "Supported versions are 0 and 1.", + api_version)); +} + +absl::Status PyRegisterCustomTypeId(absl::string_view type_name, + nb::object type_id) { + nb::capsule capsule; + if (!nb::try_cast(type_id, capsule)) { + return absl::InvalidArgumentError( + "The type_id argument to register_custom_call_type_id must be a " + "PyCapsule object holding a pointer to a XLA_FFI_TypeId."); + } + XLA_FFI_TypeId* type_id_ptr = + reinterpret_cast(static_cast(capsule.data())); + return ffi::TakeStatus(ffi::Ffi::RegisterTypeId(xla::ffi::GetXlaFfiApi(), + type_name, type_id_ptr)); +} + +template +void DefRepeatedProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T& obj, std::vector new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + elems->Reserve(new_elems.size()); + for (typename Container::value_type& e : new_elems) { + elems->Add(std::move(e)); + } + }); +} + +template +void DefRepeatedEnumProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T& obj, nb::sequence new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + for (nb::handle e : new_elems) { + elems->Add(nb::cast(e.attr("value"))); + } + }); +} + +template +Array NDArrayToArray(nb::ndarray ndarray) { + std::vector shapes; + shapes.reserve(ndarray.ndim()); + for (int i = 0; i < ndarray.ndim(); ++i) { + shapes.push_back(ndarray.shape(i)); + } + xla::Array array(shapes); + array.Each([&](absl::Span indices, int64_t* val) { + int64_t offset = indices.back(); + int64_t multiplier = 1; + for (int i = ndarray.ndim() - 1; i > 0; --i) { + multiplier *= ndarray.shape(i); + offset += indices[i - 1] * multiplier; + } + *val = *(ndarray.data() + offset); + }); + return array; +} + +absl::StatusOr SubgroupWithTileAssignmentHelper( + nb::ndarray tile_assignment, + absl::Span subgroup_types) { + return HloSharding::Subgroup(NDArrayToArray(tile_assignment), subgroup_types); +} + +nb::ndarray<> LiteralToNdarray(Literal& obj) { + const Shape& shape = obj.shape(); + + if (!shape.has_layout()) { + throw XlaRuntimeError( + "Creating an array is only supported for Literals with a layout."); + } + + const Layout& layout = shape.layout(); + + if (!layout.tiles().empty()) { + throw XlaRuntimeError( + "Creating an array from a tiled Literal is not supported."); + } + + if (!LayoutUtil::IsDenseArray(shape)) { + throw XlaRuntimeError( + "Creating an array is only supported for dense Literals."); + } + + xla::PrimitiveType primitive_type = shape.element_type(); + nb::dlpack::dtype dtype = + ValueOrThrow(PrimitiveTypeToNbDLDataType(primitive_type)); + + absl::Span dimensions = shape.dimensions(); + std::vector unsigned_dimensions(dimensions.begin(), dimensions.end()); + auto strides = StridesForShape(primitive_type, dimensions, layout); + + return nb::ndarray<>(obj.untyped_data(), unsigned_dimensions.size(), + unsigned_dimensions.data(), {}, strides.data(), dtype, + nb::device::cpu::value, 0); +} + +} // namespace + +void BuildXlaCompilerSubmodule(nb::module_& m) { + // Shapes + nb::class_ layout_class(m, "Layout"); 
+ layout_class.def(nb::init>()) + .def("__init__", + [](Layout* self, nb::sequence minor_to_major, nb::sequence tiling, + int64_t element_size_in_bits) { + std::vector xla_tiles; + xla_tiles.reserve(nb::len(tiling.ptr())); + for (auto tile : tiling) { + xla_tiles.push_back(Tile( + SequenceToVector(nb::cast(tile)))); + } + std::vector xla_minor_to_major = + SequenceToVector(minor_to_major); + new (self) + Layout(xla_minor_to_major, xla_tiles, element_size_in_bits); + }) + .def("minor_to_major", + [](Layout layout) { return SpanToNbTuple(layout.minor_to_major()); }) + .def("element_size_in_bits", &Layout::element_size_in_bits) + .def("tiling", + [](Layout layout) { + std::vector result; + result.reserve(layout.tiles().size()); + for (auto& t : layout.tiles()) { + result.push_back(SpanToNbTuple(t.dimensions())); + } + return result; + }) + .def("__eq__", [](const Layout& layout, + const Layout& other) { return layout == other; }) + .def("__ne__", [](const Layout& layout, + const Layout& other) { return layout != other; }) + .def("__str__", &Layout::ToString) + .def("__hash__", + [](const Layout& layout) { return absl::HashOf(layout); }) + .def("to_string", &Layout::ToString) + .def("__getstate__", + [](const Layout& self) -> nb::tuple { + auto proto = self.ToProto(); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("Layout.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", [](Layout* self, nb::tuple t) { + LayoutProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) Layout(Layout::CreateFromProto(result)); + }); + + nb::class_ shape_class(m, "Shape"); + shape_class + .def("__init__", + [](Shape* self, const std::string& s) { + new (self) Shape(ValueOrThrow(ParseShape(s))); + }) + .def_static( + "tuple_shape", + [](std::vector shapes) -> Shape { + return ShapeUtil::MakeTupleShape(shapes); + }, + "Constructs a tuple shape.") + .def_static("array_shape", + xla::ValueOrThrowWrapper( + [](PrimitiveType type, nb::sequence dims_seq, + std::optional layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + std::vector dims = + SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout( + type, dims, std::nullopt, dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), + nb::arg("dims"), nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) + .def_static( + "array_shape", + xla::ValueOrThrowWrapper( + [](nb_dtype dtype, nb::sequence dims_seq, + std::optional layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); + std::vector dims = SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout(type, dims, std::nullopt, + dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), nb::arg("dims"), + nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = 
std::nullopt) + .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); }) + .def_static( + "scalar_shape", + [](PrimitiveType type) -> Shape { + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def_static( + "scalar_shape", + [](nb_dtype dtype) -> Shape { + PrimitiveType type = xla::ValueOrThrow(DtypeToPrimitiveType(dtype)); + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def("dimensions", + [](const Shape& shape) -> nb::tuple { + return SpanToNbTuple(shape.dimensions()); + }) + .def("layout", + [](const Shape& shape) -> Layout { return shape.layout(); }) + .def("xla_element_type", &Shape::element_type) + .def("element_type", + [](const Shape& shape) { + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("numpy_dtype", + [](const Shape& shape) { + if (shape.IsTuple()) { + return nb_dtype("O"); + } + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("is_tuple", &Shape::IsTuple) + .def("is_array", &Shape::IsArray) + .def("is_token", &Shape::IsToken) + .def("is_static", &Shape::is_static) + .def("is_dynamic", &Shape::is_dynamic) + .def("is_dynamic_dimension", &Shape::is_dynamic_dimension, + nb::arg("dimension")) + .def("set_dynamic_dimension", &Shape::set_dynamic_dimension, + nb::arg("dimension"), nb::arg("is_dynamic")) + .def("rank", &Shape::dimensions_size) + .def("to_serialized_proto", + [](const Shape& shape) { + ShapeProto proto = shape.ToProto(); + std::string s = proto.SerializeAsString(); + return nb::bytes(s.data(), s.size()); + }) + .def("tuple_shapes", + [](const Shape& shape) { + return std::vector(shape.tuple_shapes()); + }) + .def("leaf_count", + [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); }) + .def( + "with_major_to_minor_layout_if_absent", + [](const Shape& shape) { + Shape out = shape; + ShapeUtil::ForEachMutableSubshape( + &out, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + return out; + }, + "Returns a copy of a shape with missing layouts set to " + "major-to-minor.") + .def("__eq__", [](const Shape& shape, + const Shape& other) { return shape == other; }) + .def("__ne__", [](const Shape& shape, + const Shape& other) { return shape != other; }) + .def("__hash__", [](const Shape& shape) { return absl::HashOf(shape); }) + .def("__repr__", [](const Shape& shape) { + return shape.ToString(/*print_layout=*/true); + }); + + nb::class_(m, "ProgramShape") + .def( + "__init__", + [](ProgramShape* self, absl::Span params, Shape result) { + new (self) ProgramShape(); + for (const Shape& param : params) { + *self->add_parameters() = param; + } + *self->mutable_result() = result; + }) + .def("parameter_shapes", + static_cast& (ProgramShape::*)() const>( + &ProgramShape::parameters)) + .def("result_shape", &ProgramShape::result) + .def("__repr__", &ProgramShape::ToString); + + nb::class_(m, "ShapeIndex") + .def("__init__", + [](ShapeIndex* self, const std::vector& v) { + new (self) ShapeIndex(v.begin(), v.end()); + }) + .def("__repr__", &ShapeIndex::ToString) + .def("__eq__", [](const ShapeIndex& shape_ind, + const ShapeIndex& other) { return shape_ind == other; }) + .def("__ne__", [](const ShapeIndex& shape_ind, + const ShapeIndex& other) { return shape_ind != other; }) + .def("__hash__", + [](const ShapeIndex& shape_ind) { return absl::HashOf(shape_ind); }); + + // Literals + nb::class_(m, 
"Literal") + .def(nb::init()) + .def("__repr__", &Literal::ToString) + .def( + "__array__", + [](std::shared_ptr obj, std::optional dtype, + std::optional copy) { + // Provides the interface required by numpy to create a np.ndarray. + // Currently don't support the __dl_pack__ interface but can be + // added with very little effort it if needed. + + nb::ndarray np_array(LiteralToNdarray(*obj)); + + if (dtype.has_value()) { + throw XlaRuntimeError( + "Passing of dtype to __array__ not currently supported."); + } + + if (copy.has_value() && *copy) { + // when a copy is requested we _must_ return a copy: + // https://numpy.org/doc/2.1/reference/generated/numpy.ndarray.__array__.html + return np_array.cast(nb::rv_policy::copy); + } + + return np_array.cast(nb::rv_policy::reference_internal, + nb::cast(obj)); + }, + nb::arg("dtype").none() = nb::none(), + nb::arg("copy").none() = nb::none()) + .def("shape", &Literal::shape); + + nb::class_(m, "XlaComputation") + .def("__init__", + [](XlaComputation* self, + const nb::bytes& serialized_hlo_module_proto) { + HloModuleProto proto; + proto.ParseFromArray(serialized_hlo_module_proto.c_str(), + serialized_hlo_module_proto.size()); + new (self) XlaComputation(proto); + }) + .def("get_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)) + .def("program_shape", + xla::ValueOrThrowWrapper(&XlaComputation::GetProgramShape)) + .def("name", &XlaComputation::name) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetComputationSerializedProto)) + .def("as_hlo_text", xla::ValueOrThrowWrapper(GetComputationHloText), + nb::arg("print_large_constants") = false) + .def("as_hlo_dot_graph", + xla::ValueOrThrowWrapper(GetComputationHloDotGraph)) + .def("hash", xla::ValueOrThrowWrapper(HashComputation)) + .def("as_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)); + + nb::class_ hlo_print_options_class(m, "HloPrintOptions"); + hlo_print_options_class.def(nb::init<>()) + .def_static("short_parsable", &HloPrintOptions::ShortParsable) + .def_static("canonical", &HloPrintOptions::Canonical) + .def_static("fingerprint", &HloPrintOptions::Fingerprint) + .def_prop_rw("print_large_constants", + &HloPrintOptions::print_large_constants, + &HloPrintOptions::set_print_large_constants) + .def_prop_rw("print_metadata", &HloPrintOptions::print_metadata, + &HloPrintOptions::set_print_metadata) + .def_prop_rw("print_backend_config", + &HloPrintOptions::print_backend_config, + &HloPrintOptions::set_print_backend_config) + .def_prop_rw("print_result_shape", &HloPrintOptions::print_result_shape, + &HloPrintOptions::set_print_result_shape) + .def_prop_rw("print_operand_shape", &HloPrintOptions::print_operand_shape, + &HloPrintOptions::set_print_operand_shape) + .def_prop_rw("print_operand_names", &HloPrintOptions::print_operand_names, + &HloPrintOptions::set_print_operand_names) + .def_prop_rw("print_ids", &HloPrintOptions::print_ids, + &HloPrintOptions::set_print_ids) + .def_prop_rw("print_extra_attributes", + &HloPrintOptions::print_extra_attributes, + &HloPrintOptions::set_print_extra_attributes) + .def_prop_rw("print_program_shape", &HloPrintOptions::print_program_shape, + &HloPrintOptions::set_print_program_shape) + .def_prop_rw("print_percent", &HloPrintOptions::print_percent, + &HloPrintOptions::set_print_percent) + .def_prop_rw("print_control_dependencies", + &HloPrintOptions::print_control_dependencies, + &HloPrintOptions::set_print_control_dependencies) + .def_prop_rw("compact_operands", &HloPrintOptions::compact_operands, + 
                   &HloPrintOptions::set_compact_operands)
+      .def_prop_rw("include_layout_in_shapes",
+                   &HloPrintOptions::include_layout_in_shapes,
+                   &HloPrintOptions::set_include_layout_in_shapes)
+      .def_prop_rw("canonicalize_instruction_names",
+                   &HloPrintOptions::canonicalize_instruction_names,
+                   &HloPrintOptions::set_canonicalize_instruction_names)
+      .def_prop_rw("canonicalize_computations",
+                   &HloPrintOptions::canonicalize_computations,
+                   &HloPrintOptions::set_canonicalize_computations)
+      .def_prop_rw("indent_amount", &HloPrintOptions::indent_amount,
+                   &HloPrintOptions::set_indent_amount)
+      .def_prop_rw("is_in_nested_computation",
+                   &HloPrintOptions::is_in_nested_computation,
+                   &HloPrintOptions::set_is_in_nested_computation);
+
+  // HloModule.computations() returns raw pointers.
+  // pybind seems to prefer smart pointers.
+  // We give pybind a smart pointer to a wrapper around a raw pointer to satisfy
+  // pybind and avoid double frees.
+  class ComputationWrapper {
+   public:
+    ComputationWrapper(const HloComputation* comp,
+                       const std::shared_ptr<HloModule> module)
+        : comp_(comp), module_(module) {}
+    absl::string_view name() const { return comp_->name(); }
+    void render_html(const std::string& filename) {
+      std::string html = xla::ValueOrThrow(RenderGraph(
+          *comp_, /*label=*/"", comp_->parent()->config().debug_options(),
+          RenderedGraphFormat::kHtml, HloRenderOptions()));
+      xla::ThrowIfError(tsl::WriteStringToFile(
+          tsl::Env::Default(), absl::StrCat(filename, ".html"), html));
+    }
+
+   private:
+    const HloComputation* comp_;
+    // The module owns the computations: if its destructor is called, the
+    // computations are freed. To prevent that from happening in cases where the
+    // module Python object goes out of scope and gets garbage collected before
+    // the computations, we keep a shared_ptr to the module that originated the
+    // computation.
+ const std::shared_ptr module_; + }; + + nb::class_ hlo_computation_class(m, "HloComputation"); + + hlo_computation_class.def_prop_ro("name", &ComputationWrapper::name) + .def("render_html", &ComputationWrapper::render_html); + + nb::class_ hlo_module_class(m, "HloModule"); + hlo_module_class.def_prop_ro("name", &HloModule::name) + .def( + "to_string", + static_cast( + &HloModule::ToString), + nb::arg("options") = HloPrintOptions()) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetHloModuleSerializedProto)) + .def("from_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(HloModuleFromSerializedProto)) + .def("computations", + [](const std::shared_ptr m) + -> std::vector> { + std::vector> computations; + for (HloComputation* comp : m->computations()) + computations.push_back( + std::make_shared(comp, m)); + return computations; + }) + .def_prop_ro("spmd_output_sharding", + [](const HloModule& m) -> std::optional { + if (!m.has_spmd_output_sharding()) return std::nullopt; + return m.spmd_output_sharding().ToProto(); + }) + .def_prop_ro("spmd_parameters_shardings", + [](const HloModule& m) + -> std::optional> { + if (!m.has_spmd_parameters_shardings()) + return std::nullopt; + std::vector param_shardings; + for (const auto& parameter_sharding : + m.spmd_parameters_shardings()) { + param_shardings.push_back(parameter_sharding.ToProto()); + } + return param_shardings; + }); + + nb::class_ hlo_module_group_class(m, "HloModuleGroup"); + hlo_module_group_class + .def("__init__", + [](HloModuleGroup* self, const std::string& name, + const std::vector>& hlo_modules) { + std::vector> modules; + modules.reserve(hlo_modules.size()); + for (const auto& m : hlo_modules) { + modules.push_back(m->Clone(/*suffix=*/"")); + } + new (self) HloModuleGroup(name, std::move(modules)); + }) + .def_prop_ro("name", &HloModuleGroup::name) + .def("to_string", &HloModuleGroup::ToString) + .def("to_modules", + [](HloModuleGroup& m) -> std::vector> { + std::vector> modules = + m.ConsumeModules(); + std::vector> shared_modules; + shared_modules.reserve(modules.size()); + for (auto& module : modules) { + shared_modules.push_back(std::move(module)); + } + return shared_modules; + }); + + m.def("hlo_module_to_dot_graph", + [](const HloModule& hlo_module) -> std::string { + return xla::ValueOrThrow(RenderGraph( + *hlo_module.entry_computation(), /*label=*/"", + hlo_module.config().debug_options(), RenderedGraphFormat::kDot)); + }); + m.def( + "hlo_module_cost_analysis", + xla::ValueOrThrowWrapper([](PyClient* client, const HloModule& module) + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(auto analysis, + client->pjrt_client()->GetHloCostAnalysis()); + TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); + + // Convert from HloCostAnalysis::Properties to a standard map. 
+ nb::dict ret; + analysis->properties().ForEach([&](absl::string_view key, float val) { + ret[nb::str(key.data(), key.size())] = nb::cast(val); + }); + return ret; + })); + m.def("hlo_module_from_text", + xla::ValueOrThrowWrapper( + [](const std::string& hlo_module_text) + -> absl::StatusOr> { + auto hlo_module = + xla::ParseAndReturnUnverifiedModule(hlo_module_text); + TF_RETURN_IF_ERROR(hlo_module.status()); + std::shared_ptr result(std::move(*hlo_module)); + return result; + })); + + nb::class_ xla_op_class(m, "XlaOp"); + + nb::class_(m, "XlaBuilder") + .def("__init__", + [](XlaBuilder* self, const std::string& name) { + new (self) XlaBuilder(UniquifyName(name)); + }) + // TODO(phawkins): delete capitalized names after updating callers. + .def("Build", + xla::ValueOrThrowWrapper( + [](XlaBuilder& builder, std::optional root) { + return root ? builder.Build(*root) : builder.Build(); + }), + "Builds a computation from the contents of the builder.", + nb::arg("root") = std::nullopt) + .def("GetShape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) + .def("build", + xla::ValueOrThrowWrapper( + [](XlaBuilder& builder, std::optional root) { + return root ? builder.Build(*root) : builder.Build(); + }), + "Builds a computation from the contents of the builder.", + nb::arg("root") = std::nullopt) + .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata) + .def("get_shape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) + .def( + "get_program_shape", + [](const XlaBuilder& builder, + std::optional root) -> absl::StatusOr { + return root ? builder.GetProgramShape(*root) + : builder.GetProgramShape(); + }, + nb::arg("root") = std::nullopt) + .def("is_constant", xla::ValueOrThrowWrapper(&XlaBuilder::IsConstant)) + .def("set_op_metadata", &XlaBuilder::SetOpMetadata) + .def("set_sharding", &XlaBuilder::SetSharding) + .def("clear_sharding", &XlaBuilder::ClearSharding) + .def("set_frontend_attributes", &XlaBuilder::SetFrontendAttributes) + .def("clear_frontend_attributes", &XlaBuilder::ClearFrontendAttributes) + .def("setup_alias", + [](XlaBuilder& builder, const std::vector& output_index, + int64_t param_number, const std::vector& param_index) { + builder.SetUpAlias( + ShapeIndex(output_index.begin(), output_index.end()), + param_number, + ShapeIndex(param_index.begin(), param_index.end())); + }); + + // Device assignments + nb::class_(m, "DeviceAssignment") + .def_static( + "create", + xla::ValueOrThrowWrapper([](nb::ndarray> array) + -> absl::StatusOr { + if (array.ndim() != 2) { + return InvalidArgument( + "Argument to DeviceAssignment constructor must be a " + "2D array, received an %dD array.", + array.ndim()); + } + DeviceAssignment result(array.shape(0), array.shape(1)); + for (int i = 0; i < array.shape(0); ++i) { + for (int j = 0; j < array.shape(1); ++j) { + result(i, j) = array(i, j); + } + } + return result; + })) + .def("replica_count", &DeviceAssignment::replica_count) + .def("computation_count", &DeviceAssignment::computation_count) + .def("__repr__", &DeviceAssignment::ToString) + .def("serialize", + xla::ValueOrThrowWrapper( + [](const DeviceAssignment& da) -> absl::StatusOr { + DeviceAssignmentProto proto; + da.Serialize(&proto); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + return Unknown( + "Failed to serialize the DeviceAssignmentProto."); + } + return nb::bytes(result.data(), result.size()); + })); + + nb::class_ compile_options(m, "CompileOptions"); + compile_options + .def("__init__", + [](CompileOptions* self) { + new (self) 
CompileOptions(); + DebugOptions* debug_options = + self->executable_build_options.mutable_debug_options(); + // Sets fast-math-disabling default options expected by JAX. + debug_options->set_xla_cpu_enable_fast_min_max(false); + debug_options->set_xla_gpu_enable_fast_min_max(false); + }) + .def("__getstate__", + [](const CompileOptions& self) -> nb::tuple { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", + [](CompileOptions* self, nb::tuple t) { + CompileOptionsProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) CompileOptions( + ValueOrThrow(CompileOptions::FromProto(result))); + }) + .def("SerializeAsString", + [](const CompileOptions& self) -> nb::bytes { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.SerializeAsString: ", + "SerializeToStringDeterministic failed")); + } + return nb::bytes(result.data(), result.size()); + }) + .def_static("ParseFromString", + [](nb::bytes s) { + CompileOptionsProto result; + result.ParseFromArray(s.c_str(), s.size()); + return ValueOrThrow(CompileOptions::FromProto(result)); + }) + .def_rw("argument_layouts", &CompileOptions::argument_layouts) + .def_rw("parameter_is_tupled_arguments", + &CompileOptions::parameter_is_tupled_arguments) + .def_rw("compile_portable_executable", + &CompileOptions::compile_portable_executable) + .def_ro("executable_build_options", + &CompileOptions::executable_build_options) + .def_rw("env_option_overrides", &CompileOptions::env_option_overrides) + // TODO(phawkins): the following fields exist for backward compatibility. + // Remove them after JAX has been updated not to use them. + .def_rw("tuple_arguments", &CompileOptions::parameter_is_tupled_arguments) + .def_prop_rw( + "num_replicas", + [](const CompileOptions& options) { + return options.executable_build_options.num_replicas(); + }, + [](CompileOptions& options, int num_replicas) { + options.executable_build_options.set_num_replicas(num_replicas); + }) + .def_prop_rw( + "num_partitions", + [](const CompileOptions& options) { + return options.executable_build_options.num_partitions(); + }, + [](CompileOptions& options, int num_partitions) { + options.executable_build_options.set_num_partitions(num_partitions); + }) + .def_prop_rw( + "profile_version", + [](const CompileOptions& options) { return options.profile_version; }, + [](CompileOptions& options, int64_t profile_version) { + options.profile_version = profile_version; + }) + .def_prop_rw( + "device_assignment", + [](const CompileOptions& options) -> std::optional { + return options.executable_build_options.has_device_assignment() + ? std::optional( + options.executable_build_options + .device_assignment()) + : std::nullopt; + }, + [](CompileOptions& options, + const DeviceAssignment& device_assignment) { + options.executable_build_options.set_device_assignment( + device_assignment); + }); + + // Custom-call targets. 
+ m.def( + "register_custom_call_target", + [](nb::object fn_name_py, nb::object fn, const std::string& platform, + int api_version, XLA_FFI_Handler_Traits traits) { + std::string fn_name; + if (!nb::try_cast(fn_name_py, fn_name)) { + nb::bytes bytes = nb::cast(fn_name_py); + fn_name = std::string(bytes.c_str(), bytes.size()); + } + xla::ThrowIfError(PyRegisterCustomCallTarget( + fn_name, std::move(fn), platform, api_version, traits)); + }, + nb::arg("fn_name"), nb::arg("fn"), nb::arg("platform"), + nb::arg("api_version") = 0, nb::arg("traits") = 0); + + m.def( + "custom_call_targets", + [](const std::string& platform) -> nb::dict { + nb::dict targets; + for (const auto& [name, target] : + CustomCallTargetRegistry::Global()->registered_symbols(platform)) { + targets[nb::str(name.data(), name.size())] = nb::capsule(target); + } + + auto ffi_handlers = ffi::StaticRegisteredHandlers(platform); + if (!ffi_handlers.ok()) return targets; + + for (const auto& [name, registration] : *ffi_handlers) { + nb::dict bundle; + auto export_handler = [&](absl::string_view name, + XLA_FFI_Handler* h) { + if (h != nullptr) { + bundle[nb::str(name.data(), name.size())] = + nb::capsule(reinterpret_cast(h)); + } + }; + export_handler("prepare", registration.bundle.prepare); + export_handler("initialize", registration.bundle.initialize); + export_handler("execute", registration.bundle.execute); + targets[nb::str(name.data(), name.size())] = std::move(bundle); + } + return targets; + }, + nb::arg("platform")); + + nb::enum_(m, "AutotuneCacheMode") + .value("UNSPECIFIED", DebugOptions::AUTOTUNE_CACHE_MODE_UNSPECIFIED) + .value("UPDATE", DebugOptions::AUTOTUNE_CACHE_MODE_UPDATE) + .value("READ", DebugOptions::AUTOTUNE_CACHE_MODE_READ); + + m.def( + "register_custom_type_id", + [](absl::string_view type_name, nb::object type_id) { + xla::ThrowIfError(PyRegisterCustomTypeId(type_name, type_id)); + }, + nb::arg("type_name"), nb::arg("type_id")); + + nb::class_(m, "DebugOptions") + .def("__repr__", &DebugOptions::DebugString) + .def_prop_rw("xla_backend_optimization_level", + &DebugOptions::xla_backend_optimization_level, + &DebugOptions::set_xla_backend_optimization_level) + .def_prop_rw("xla_cpu_enable_fast_math", + &DebugOptions::xla_cpu_enable_fast_math, + &DebugOptions::set_xla_cpu_enable_fast_math) + .def_prop_rw("xla_cpu_enable_xprof_traceme", + &DebugOptions::xla_cpu_enable_xprof_traceme, + &DebugOptions::set_xla_cpu_enable_xprof_traceme) + .def_prop_rw("xla_cpu_fast_math_honor_infs", + &DebugOptions::xla_cpu_fast_math_honor_infs, + &DebugOptions::set_xla_cpu_fast_math_honor_infs) + .def_prop_rw("xla_cpu_fast_math_honor_nans", + &DebugOptions::xla_cpu_fast_math_honor_nans, + &DebugOptions::set_xla_cpu_fast_math_honor_nans) + .def_prop_rw("xla_cpu_fast_math_honor_division", + &DebugOptions::xla_cpu_fast_math_honor_division, + &DebugOptions::set_xla_cpu_fast_math_honor_division) + .def_prop_rw("xla_cpu_fast_math_honor_functions", + &DebugOptions::xla_cpu_fast_math_honor_functions, + &DebugOptions::set_xla_cpu_fast_math_honor_functions) + .def_prop_rw("xla_detailed_logging", &DebugOptions::xla_detailed_logging, + &DebugOptions::set_xla_detailed_logging) + .def_prop_rw("xla_enable_dumping", &DebugOptions::xla_enable_dumping, + &DebugOptions::set_xla_enable_dumping) + .def_prop_rw("xla_gpu_enable_fast_min_max", + &DebugOptions::xla_gpu_enable_fast_min_max, + &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_prop_rw("xla_gpu_dump_autotune_results_to", + &DebugOptions::xla_gpu_dump_autotune_results_to, + 
[](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_results_to(value); + }) + .def_prop_rw("xla_gpu_load_autotune_results_from", + &DebugOptions::xla_gpu_load_autotune_results_from, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_load_autotune_results_from(value); + }) + .def_prop_rw("xla_gpu_cuda_data_dir", + &DebugOptions::xla_gpu_cuda_data_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_cuda_data_dir(value); + }) + .def_prop_rw("xla_llvm_disable_expensive_passes", + &DebugOptions::xla_llvm_disable_expensive_passes, + &DebugOptions::set_xla_llvm_disable_expensive_passes) + .def_prop_rw( + "xla_disable_hlo_passes", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_disable_hlo_passes(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_disable_hlo_passes(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_disable_hlo_passes(passname); + } + }) + .def_prop_rw( + "xla_enable_hlo_passes_only", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_enable_hlo_passes_only(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_enable_hlo_passes_only(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_enable_hlo_passes_only(passname); + } + }) + .def_prop_rw("xla_test_all_input_layouts", + &DebugOptions::xla_test_all_input_layouts, + &DebugOptions::set_xla_test_all_input_layouts) + .def_prop_rw("xla_force_host_platform_device_count", + &DebugOptions::xla_force_host_platform_device_count, + &DebugOptions::set_xla_force_host_platform_device_count) + .def_prop_rw("xla_dump_to", &DebugOptions::xla_dump_to, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_to(value); + }) + .def_prop_rw("xla_dump_hlo_module_re", + &DebugOptions::xla_dump_hlo_module_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_module_re(value); + }) + .def_prop_rw("xla_dump_hlo_pass_re", &DebugOptions::xla_dump_hlo_pass_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pass_re(value); + }) + .def_prop_rw("xla_dump_hlo_as_text", &DebugOptions::xla_dump_hlo_as_text, + &DebugOptions::set_xla_dump_hlo_as_text) + .def_prop_rw("xla_dump_hlo_as_proto", + &DebugOptions::xla_dump_hlo_as_proto, + &DebugOptions::set_xla_dump_hlo_as_proto) + .def_prop_rw("xla_dump_hlo_as_dot", &DebugOptions::xla_dump_hlo_as_dot, + &DebugOptions::set_xla_dump_hlo_as_dot) + .def_prop_rw("xla_dump_hlo_as_url", &DebugOptions::xla_dump_hlo_as_url, + &DebugOptions::set_xla_dump_hlo_as_url) + .def_prop_rw("xla_dump_hlo_as_html", &DebugOptions::xla_dump_hlo_as_html, + &DebugOptions::set_xla_dump_hlo_as_html) + .def_prop_rw("xla_dump_fusion_visualization", + &DebugOptions::xla_dump_fusion_visualization, + &DebugOptions::set_xla_dump_fusion_visualization) + .def_prop_rw("xla_dump_hlo_snapshots", + &DebugOptions::xla_dump_hlo_snapshots, + &DebugOptions::set_xla_dump_hlo_snapshots) + .def_prop_rw("xla_dump_max_hlo_modules", + &DebugOptions::xla_dump_max_hlo_modules, + &DebugOptions::set_xla_dump_max_hlo_modules) + .def_prop_rw("xla_dump_module_metadata", + &DebugOptions::xla_dump_module_metadata, + &DebugOptions::set_xla_dump_module_metadata) + .def_prop_rw("xla_dump_compress_protos", + &DebugOptions::xla_dump_compress_protos, + &DebugOptions::set_xla_dump_compress_protos) + .def_prop_rw("xla_dump_hlo_as_long_text", + &DebugOptions::xla_dump_hlo_as_long_text, + 
&DebugOptions::set_xla_dump_hlo_as_long_text) + .def_prop_rw("xla_dump_disable_metadata", + &DebugOptions::xla_dump_disable_metadata, + &DebugOptions::set_xla_dump_disable_metadata) + .def_prop_rw("xla_dump_hlo_pipeline_re", + &DebugOptions::xla_dump_hlo_pipeline_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pipeline_re(value); + }) + .def_prop_rw("xla_gpu_dump_autotune_logs_to", + &DebugOptions::xla_gpu_dump_autotune_logs_to, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_logs_to(value); + }) + .def_prop_rw("xla_gpu_kernel_cache_file", + &DebugOptions::xla_gpu_kernel_cache_file, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_kernel_cache_file(value); + }) + .def_prop_rw( + "xla_gpu_enable_llvm_module_compilation_parallelism", + &DebugOptions::xla_gpu_enable_llvm_module_compilation_parallelism, + &DebugOptions::set_xla_gpu_enable_llvm_module_compilation_parallelism) + .def_prop_rw("xla_gpu_per_fusion_autotune_cache_dir", + &DebugOptions::xla_gpu_per_fusion_autotune_cache_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_per_fusion_autotune_cache_dir(value); + }) + .def_prop_rw("xla_gpu_experimental_autotune_cache_mode", + &DebugOptions::xla_gpu_experimental_autotune_cache_mode, + &DebugOptions::set_xla_gpu_experimental_autotune_cache_mode); + + nb::class_(m, "ExecutableBuildOptions") + .def(nb::init<>()) + .def("__repr__", &ExecutableBuildOptions::ToString) + .def_prop_rw( + "fdo_profile", + [](const ExecutableBuildOptions& options) { + return nb::bytes(options.fdo_profile().data(), + options.fdo_profile().size()); + }, + [](ExecutableBuildOptions& options, nb::bytes fdo_profile) { + options.set_fdo_profile( + std::string(fdo_profile.c_str(), fdo_profile.size())); + }) + .def_prop_rw( + "result_layout", + [](const ExecutableBuildOptions& options) -> std::optional { + return options.result_layout() + ? std::optional(*options.result_layout()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_result_layout) + .def_prop_rw("num_replicas", &ExecutableBuildOptions::num_replicas, + &ExecutableBuildOptions::set_num_replicas) + .def_prop_rw("num_partitions", &ExecutableBuildOptions::num_partitions, + &ExecutableBuildOptions::set_num_partitions) + .def_prop_ro("debug_options", + &ExecutableBuildOptions::mutable_debug_options, + nb::rv_policy::reference, nb::keep_alive<1, 0>()) + .def_prop_rw( + "device_assignment", + [](const ExecutableBuildOptions& options) + -> std::optional { + return options.has_device_assignment() + ? 
std::optional( + options.device_assignment()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_device_assignment) + .def("compilation_environments_from_serialized_proto", + [](ExecutableBuildOptions& options, + const nb::bytes& serialized_proto) { + xla::CompilationEnvironmentsProto env_proto; + env_proto.ParseFromArray(serialized_proto.c_str(), + serialized_proto.size()); + auto comp_envs = xla::ValueOrThrow( + xla::CompilationEnvironments::CreateFromProto(env_proto)); + *options.mutable_comp_envs() = std::move(*comp_envs); + }) + .def_prop_rw("exec_time_optimization_effort", + &ExecutableBuildOptions::exec_time_optimization_effort, + &ExecutableBuildOptions::set_exec_time_optimization_effort) + .def_prop_rw("memory_fitting_effort", + &ExecutableBuildOptions::memory_fitting_effort, + &ExecutableBuildOptions::set_memory_fitting_effort) + .def_prop_rw( + "optimization_level", &ExecutableBuildOptions::optimization_level, + [](ExecutableBuildOptions& options, int value) { + options.set_optimization_level( + static_cast(value)); + }) + .def_prop_rw( + "memory_fitting_level", &ExecutableBuildOptions::memory_fitting_level, + [](ExecutableBuildOptions& options, int value) { + options.set_memory_fitting_level( + static_cast(value)); + }) + .def_prop_rw("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning) + .def_prop_rw("use_auto_spmd_partitioning", + &ExecutableBuildOptions::use_auto_spmd_partitioning, + &ExecutableBuildOptions::set_use_auto_spmd_partitioning) + .def_prop_rw( + "auto_spmd_partitioning_mesh_shape", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_shape, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape) + .def_prop_rw("auto_spmd_partitioning_mesh_ids", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_parameters", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_parameters().begin(), + options.allow_spmd_sharding_propagation_to_parameters().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_parameters(v); + }) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_output", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_output().begin(), + options.allow_spmd_sharding_propagation_to_output().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_output(v); + }) + .def_prop_rw("use_shardy_partitioner", + &ExecutableBuildOptions::use_shardy_partitioner, + &ExecutableBuildOptions::set_use_shardy_partitioner); + + nb::enum_ op_sharding_type(m, "OpSharding_Type", + nb::is_arithmetic()); + op_sharding_type.value("REPLICATED", OpSharding::REPLICATED) + .value("MAXIMAL", OpSharding::MAXIMAL) + .value("MANUAL", OpSharding::MANUAL) + .value("TUPLE", OpSharding::TUPLE) + .value("OTHER", OpSharding::OTHER) + .value("UNKNOWN", OpSharding::UNKNOWN); + + nb::enum_ op_sharding_shard_group_type( + m, "OpSharding_ShardGroupType"); + op_sharding_shard_group_type.value("AS", OpSharding::AS) + .value("LIKE", OpSharding::LIKE); + + nb::class_ op_sharding(m, "OpSharding"); + 
op_sharding + .def_prop_ro_static( + "Type", + [op_sharding_type](const nb::object&) { return op_sharding_type; }) + .def_prop_ro_static("ShardGroupType", + [op_sharding_shard_group_type](const nb::object&) { + return op_sharding_shard_group_type; + }) + .def(nb::init<>()) + .def("__getstate__", + [](const OpSharding& self) { + std::string serialized = self.SerializeAsString(); + return nb::make_tuple( + nb::bytes(serialized.data(), serialized.size())); + }) + .def("__setstate__", + [](OpSharding* self, nb::tuple t) { + new (self) OpSharding(); + nb::bytes serialized = nb::cast(t[0]); + self->ParseFromArray(serialized.c_str(), serialized.size()); + }) + .def_prop_rw("type", &xla::OpSharding::type, &xla::OpSharding::set_type) + .def_prop_rw("replicate_on_last_tile_dim", + &xla::OpSharding::replicate_on_last_tile_dim, + &xla::OpSharding::set_replicate_on_last_tile_dim) + .def_prop_rw("is_shard_group", &xla::OpSharding::is_shard_group, + &xla::OpSharding::set_is_shard_group) + .def_prop_rw("shard_group_id", &xla::OpSharding::shard_group_id, + &xla::OpSharding::set_shard_group_id) + .def_prop_rw("shard_group_type", &xla::OpSharding::shard_group_type, + &xla::OpSharding::set_shard_group_type) + .def("__repr__", + [](const xla::OpSharding& self) { return self.DebugString(); }) + .def("ParseFromString", + [](OpSharding& sharding, const nb::bytes& s) { + sharding.ParseFromArray(s.c_str(), s.size()); + }) + .def("SerializeToString", + [](const OpSharding& sharding) { + std::string serialized = sharding.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("clone", + [](const OpSharding& sharding) { return OpSharding(sharding); }); + DefRepeatedProperty(op_sharding, "tile_assignment_dimensions", + &xla::OpSharding::mutable_tile_assignment_dimensions); + DefRepeatedProperty(op_sharding, "tile_assignment_devices", + &xla::OpSharding::mutable_tile_assignment_devices); + DefRepeatedProperty(op_sharding, "iota_reshape_dims", + &xla::OpSharding::mutable_iota_reshape_dims); + DefRepeatedProperty(op_sharding, "iota_transpose_perm", + &xla::OpSharding::mutable_iota_transpose_perm); + DefRepeatedProperty(op_sharding, "tuple_shardings", + &xla::OpSharding::mutable_tuple_shardings); + DefRepeatedEnumProperty(op_sharding, "last_tile_dims", + &xla::OpSharding::mutable_last_tile_dims); + + nb::class_ hlo_sharding(m, "HloSharding"); + hlo_sharding + .def_static("from_proto", + xla::ValueOrThrowWrapper(xla::HloSharding::FromProto)) + .def_static("from_string", xla::ValueOrThrowWrapper(xla::ParseSharding)) + .def_static( + "tuple_sharding", + [](xla::Shape shape, + std::vector shardings) -> xla::HloSharding { + return HloSharding::Tuple(shape, shardings); + }, + "Constructs a tuple sharding.") + .def_static( + "iota_tile", xla::ValueOrThrowWrapper(IotaTileHelper), + nb::arg("dims"), + nb::arg("reshape_dims") = absl::Span(), + nb::arg("transpose_perm") = absl::Span(), + nb::arg("subgroup_types") = absl::Span()) + .def_static("manual", [] { return HloSharding::Manual(); }) + .def_static("replicate", [] { return HloSharding::Replicate(); }) + .def_static("unknown", [] { return HloSharding::Unknown(); }) + .def_static( + "subgroup_with_device_ordering", + xla::ValueOrThrowWrapper(SubgroupWithTileAssignmentHelper), + nb::arg("tile_assignment"), + nb::arg("subgroup_types") = absl::Span()) + .def("__eq__", [](const xla::HloSharding& a, + const xla::HloSharding& b) { return a == b; }) + .def("__hash__", + [](const xla::HloSharding& self) { return absl::HashOf(self); }) + 
.def("is_replicated", &xla::HloSharding::IsReplicated) + .def("is_manual", &xla::HloSharding::IsManual) + .def("is_unknown", &xla::HloSharding::IsUnknown) + .def("is_tiled", &xla::HloSharding::IsTiled) + .def("is_maximal", &xla::HloSharding::IsTileMaximal) + .def("tile", [](const xla::HloSharding& self, + xla::Shape shape) { return self.TileShape(shape); }) + // tile_assignment.array() is computed using an internal cache, + // which is why nb::lock_self() is required. It may be preferable to move + // this locking into the TileAssignment class if we find it to race with + // non-Python users of that class. + .def( + "tuple_elements", + [](const xla::HloSharding& self) { return self.tuple_elements(); }, + nb::lock_self()) + .def( + "num_devices", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_elements(); + }, + nb::lock_self()) + .def( + "num_dimensions", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_dimensions(); + }, + nb::lock_self()) + .def( + "tile_assignment_dimensions", + [](const xla::HloSharding& self) { + absl::Span span = + self.tile_assignment().dimensions(); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def( + "tile_assignment_devices", + [](const xla::HloSharding& self) { + auto span = + absl::MakeConstSpan(self.tile_assignment().array().data(), + self.tile_assignment().num_elements()); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def("replicate_on_last_tile_dim", + &xla::HloSharding::ReplicateOnLastTileDim) + .def("subgroup_types", &xla::HloSharding::subgroup_types) + .def("__repr__", + [](const xla::HloSharding& self) { return self.ToString(); }) + .def("to_proto", &xla::HloSharding::ToProto); + + nb::class_ frontend_attributes(m, "FrontendAttributes"); + frontend_attributes.def(nb::init<>()) + .def("__setitem__", + [](FrontendAttributes* attr, std::string key, std::string value) { + (*attr->mutable_map())[key] = value; + }); + + nb::enum_(m, "PrecisionConfig_Precision") + .value("DEFAULT", PrecisionConfig::DEFAULT) + .value("HIGH", PrecisionConfig::HIGH) + .value("HIGHEST", PrecisionConfig::HIGHEST); + + nb::enum_(m, "ResultAccuracy_Mode") + .value("DEFAULT", ResultAccuracy::DEFAULT) + .value("HIGHEST", ResultAccuracy::HIGHEST); + + nb::enum_(m, "FftType") + .value("FFT", FftType::FFT) + .value("IFFT", FftType::IFFT) + .value("RFFT", FftType::RFFT) + .value("IRFFT", FftType::IRFFT); + + // Hlo Module Passes + nb::class_ hlo_pass_interface(m, "HloPassInterface"); + hlo_pass_interface.def_prop_ro("name", &HloPassInterface::name) + .def("is_pass_pipeline", &HloPassInterface::IsPassPipeline) + .def("run", + [](HloPassInterface& pass, HloModule* module) -> bool { + return xla::ValueOrThrow(pass.Run(module)); + }) + .def("run_on_module_group", + [](HloPassInterface& pass, HloModuleGroup* module_group) -> bool { + return xla::ValueOrThrow(pass.RunOnModuleGroup(module_group)); + }); + + nb::class_(m, "HloDCE").def(nb::init<>()); + nb::class_(m, "CallInliner").def(nb::init<>()); + nb::class_(m, "FlattenCallGraph") + .def(nb::init<>()); + nb::class_(m, "TupleSimplifier") + .def(nb::init<>()); +} // NOLINT(readability/fn_size) +} // namespace xla diff --git a/jaxlib/xla/xla_compiler.h b/jaxlib/xla/xla_compiler.h new file mode 100644 index 000000000000..ca5bc762a7d8 --- /dev/null +++ b/jaxlib/xla/xla_compiler.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_XLA_COMPILER_H_ +#define JAXLIB_XLA_XLA_COMPILER_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildXlaCompilerSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_XLA_COMPILER_H_ diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi new file mode 100644 index 000000000000..7bfb2b1f675b --- /dev/null +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -0,0 +1,1041 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import enum +import inspect +import types +import typing +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import numpy as np + +from . import config as config +from . import guard_lib as guard_lib +from . import ifrt_programs as ifrt_programs +from . import ifrt_proxy as ifrt_proxy +from . import jax_jit as jax_jit +from . import mlir as mlir +from . import ops as ops +from . import pmap_lib as pmap_lib +from . import profiler as profiler +from . import pytree as pytree +from . 
import transfer_guard_lib as transfer_guard_lib + +custom_call_targets = Any +hlo_sharding_util = Any + +_LiteralSlice = Any +_Status = Any +_Dtype = Any +_XlaOpMetadata = Any + +_T = TypeVar("_T") + +class XlaRuntimeError(RuntimeError): + pass + +class PrimitiveType(enum.IntEnum): + PRIMITIVE_TYPE_INVALID: PrimitiveType + PRED: PrimitiveType + S2: PrimitiveType + S4: PrimitiveType + S8: PrimitiveType + S16: PrimitiveType + S32: PrimitiveType + S64: PrimitiveType + U2: PrimitiveType + U4: PrimitiveType + U8: PrimitiveType + U16: PrimitiveType + U32: PrimitiveType + U64: PrimitiveType + F4E2M1FN: PrimitiveType + F8E3M4: PrimitiveType + F8E4M3: PrimitiveType + F8E4M3FN: PrimitiveType + F8E4M3B11FNUZ: PrimitiveType + F8E4M3FNUZ: PrimitiveType + F8E5M2: PrimitiveType + F8E5M2FNUZ: PrimitiveType + F8E8M0FNU: PrimitiveType + BF16: PrimitiveType + F16: PrimitiveType + F32: PrimitiveType + F64: PrimitiveType + C64: PrimitiveType + C128: PrimitiveType + TUPLE: PrimitiveType + OPAQUE_TYPE: PrimitiveType + TOKEN: PrimitiveType + +# === BEGIN xla_compiler.cc + +class ArrayCopySemantics(enum.IntEnum): + ALWAYS_COPY: ArrayCopySemantics + REUSE_INPUT: ArrayCopySemantics + DONATE_INPUT: ArrayCopySemantics + +class Layout: + @overload + def __init__(self, minor_to_major: Tuple[int, ...]): ... + @overload + def __init__(self, minor_to_major: Tuple[int, ...], + tiling: Tuple[Tuple[int, ...], ...], + element_size_in_bits: int): ... + def minor_to_major(self) -> Tuple[int, ...]: ... + def tiling(self) -> Sequence[Tuple[int, ...]]: ... + def element_size_in_bits(self) -> int: ... + def to_string(self) -> str: ... + def __eq__(self, other: Layout) -> bool: ... + def __ne__(self, other: Layout) -> bool: ... + def __hash__(self) -> int: ... + +class Shape: + def __init__(self, s: str): ... + @staticmethod + def tuple_shape(shapes: Sequence[Shape]) -> Shape: ... + @staticmethod + def array_shape( + type: Union[np.dtype, PrimitiveType], + dims_seq: Any = ..., + layout_seq: Any = ..., + dynamic_dimensions: Optional[List[bool]] = ..., + ) -> Shape: ... + @staticmethod + def token_shape() -> Shape: ... + @staticmethod + def scalar_shape(type: Union[np.dtype, PrimitiveType]) -> Shape: ... + def dimensions(self) -> Tuple[int, ...]: ... + def layout(self) -> Layout: ... + def xla_element_type(self) -> PrimitiveType: ... + def element_type(self) -> np.dtype: ... + def numpy_dtype(self) -> np.dtype: ... + def is_tuple(self) -> bool: ... + def is_array(self) -> bool: ... + def is_token(self) -> bool: ... + def is_static(self) -> bool: ... + def is_dynamic(self) -> bool: ... + def is_dynamic_dimension(self, dimension: int) -> bool: ... + def set_dynamic_dimension(self, dimension: int, is_dynamic: bool) -> None: ... + def rank(self) -> int: ... + def to_serialized_proto(self) -> bytes: ... + def tuple_shapes(self) -> List[Shape]: ... + def leaf_count(self) -> int: ... + def with_major_to_minor_layout_if_absent(self) -> Shape: ... + def __eq__(self, other: Shape) -> bool: ... + def __ne__(self, other: Shape) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + +class ProgramShape: + def __init__(self, params: Sequence[Shape], result: Shape) -> None: ... + def parameter_shapes(self) -> List[Shape]: ... + def result_shape(self) -> Shape: ... + def __repr__(self) -> str: ... + +class ShapeIndex: + def __init__(self, indices: List[int]) -> ShapeIndex: ... + def __eq__(self, other: Shape) -> bool: ... + def __ne__(self, other: Shape) -> bool: ... + def __hash__(self) -> int: ... 
+ def __repr__(self) -> str: ... + +class Literal: + def __init__(self, shape: Shape) -> Literal: ... + def __repr__(self) -> str: ... + def __array__( + self, dtype: Optional[np.dtype] = None, copy: Optional[bool] = None + ) -> np.ndarray: ... + def shape(self) -> Shape: ... + +class XlaComputation: + def __init__(self, serialized_hlo_module_proto: bytes) -> None: ... + def get_hlo_module(self) -> HloModule: ... + def program_shape(self) -> ProgramShape: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + def as_hlo_text(self, print_large_constants: bool = False) -> str: ... + def as_hlo_dot_graph(self) -> str: ... + def hash(self) -> int: ... + def as_hlo_module(self) -> HloModule: ... + +class HloPrintOptions: + def __init__(self) -> None: ... + @staticmethod + def short_parsable() -> HloPrintOptions: ... + @staticmethod + def canonical() -> HloPrintOptions: ... + @staticmethod + def fingerprint() -> HloPrintOptions: ... + print_large_constants: bool + print_metadata: bool + print_backend_config: bool + print_result_shape: bool + print_operand_shape: bool + print_operand_names: bool + print_ids: bool + print_extra_attributes: bool + print_program_shape: bool + print_percent: bool + print_control_dependencies: bool + compact_operands: bool + include_layout_in_shapes: bool + canonicalize_instruction_names: bool + canonicalize_computations: bool + indent_amount: int + is_in_nested_computation: bool + +class HloComputation: + def render_html(self) -> None: ... + +class HloModule: + spmd_output_sharding: Optional[OpSharding] + spmd_parameters_shardings: Optional[List[OpSharding]] + @property + def name(self) -> str: ... + def to_string(self, options: HloPrintOptions = ...) -> str: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + @staticmethod + def from_serialized_hlo_module_proto( + serialized_hlo_module_proto: bytes, + ) -> HloModule: ... + def computations(self) -> List[HloComputation]: ... + +class HloModuleGroup: + def __init__(self, name: str, modules: List[HloModule]) -> None: ... + @property + def name(self) -> str: ... + def to_string(self) -> str: ... + def to_modules(self) -> List[HloModule]: ... + +def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... +def hlo_module_from_text(hlo_module_text: str) -> HloModule: ... +def hlo_module_cost_analysis( + client: Client, module: HloModule +) -> Dict[str, float]: ... + +class XlaOp: ... + +class XlaBuilder: + def __init__(self, name: str) -> None: ... + def Build(self, root: Optional[XlaOp] = ...) -> XlaComputation: ... + def GetShape(self, __op: XlaOp) -> Shape: ... + build = Build + def clear_op_metadata(self) -> None: ... + get_shape = GetShape + def get_program_shape(self, root: Optional[XlaOp] = ...) -> ProgramShape: ... + def is_constant(self, __op: XlaOp) -> bool: ... + def set_op_metadata(self, metadata: _XlaOpMetadata) -> None: ... + def set_sharding(self, sharding: OpSharding_Type) -> None: ... + def clear_sharding(self) -> None: ... + def setup_alias( + self, + __output_index: Sequence[int], + __param_number: int, + __param_index: Sequence[int], + ) -> None: ... + +class DeviceAssignment: + @staticmethod + def create(array: np.ndarray) -> DeviceAssignment: ... + def replica_count(self) -> int: ... + def computation_count(self) -> int: ... + def __repr__(self) -> str: ... + def serialize(self) -> bytes: ... + +class CompileOptions: + @staticmethod + def ParseFromString(s: bytes) -> CompileOptions: ... + def __init__(self) -> None: ... + def SerializeAsString(self) -> bytes: ... 
+ argument_layouts: Optional[List[Shape]] + parameter_is_tupled_arguments: bool + executable_build_options: ExecutableBuildOptions + tuple_arguments: bool + num_replicas: int + num_partitions: int + profile_version: int + device_assignment: Optional[DeviceAssignment] + compile_portable_executable: bool + env_option_overrides: List[Tuple[str, str]] + +def register_custom_call_target( + fn_name: str, capsule: Any, platform: str, api_version: int = ..., +) -> _Status: ... +def register_custom_call_partitioner( + name: str, + prop_user_sharding: Callable, + partition: Callable, + infer_sharding_from_operands: Callable, + can_side_effecting_have_replicated_sharding: bool = ..., + c_api: Optional[Any] = ..., +) -> None: ... +def encode_inspect_sharding_callback(handler: Any) -> bytes: ... +def register_custom_call_as_batch_partitionable( + target_name: str, + c_api: Optional[Any] = ..., +) -> None: ... + +def register_custom_type_id(type_name: str, type_id: Any) -> None: ... + +class AutotuneCacheMode(enum.IntEnum): + UNSPECIFIED: AutotuneCacheMode + UPDATE: AutotuneCacheMode + READ: AutotuneCacheMode + +class DebugOptions: + def __repr__(self) -> str: ... + xla_cpu_enable_fast_math: bool + xla_cpu_fast_math_honor_infs: bool + xla_cpu_fast_math_honor_nans: bool + xla_cpu_fast_math_honor_division: bool + xla_cpu_fast_math_honor_functions: bool + xla_gpu_enable_fast_min_max: bool + xla_backend_optimization_level: int + xla_cpu_enable_xprof_traceme: bool + xla_llvm_disable_expensive_passes: bool + xla_test_all_input_layouts: bool + xla_disable_hlo_passes: str + xla_enable_hlo_passes_only: str + xla_force_host_platform_device_count: int + xla_dump_to: str + xla_dump_hlo_module_re: str + xla_dump_hlo_pass_re: str + xla_dump_hlo_as_text: bool + xla_dump_hlo_as_proto: bool + xla_dump_hlo_as_dot: bool + xla_dump_hlo_as_url: bool + xla_dump_hlo_as_html: bool + xla_dump_fusion_visualization: bool + xla_dump_hlo_snapshots: bool + xla_dump_max_hlo_modules: bool + xla_dump_module_metadata: bool + xla_dump_compress_protos: bool + xla_dump_hlo_as_long_text: bool + xla_dump_disable_metadata: bool + xla_dump_hlo_pipeline_re: str + xla_gpu_cuda_data_dir: str + xla_detailed_logging: bool + xla_enable_dumping: bool + xla_gpu_dump_autotune_results_to: str + xla_gpu_load_autotune_results_from: str + xla_gpu_dump_autotune_logs_to: str + xla_gpu_kernel_cache_file: str + xla_gpu_enable_llvm_module_compilation_parallelism: bool + xla_gpu_per_fusion_autotune_cache_dir: str + xla_gpu_experimental_autotune_cache_mode: AutotuneCacheMode + +class CompiledMemoryStats: + generated_code_size_in_bytes: int + argument_size_in_bytes: int + output_size_in_bytes: int + alias_size_in_bytes: int + temp_size_in_bytes: int + host_generated_code_size_in_bytes: int + host_argument_size_in_bytes: int + host_output_size_in_bytes: int + host_alias_size_in_bytes: int + host_temp_size_in_bytes: int + serialized_hlo_proto: bytes + def __str__(self) -> str: ... + +class ExecutableBuildOptions: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + result_layout: Optional[Shape] + fdo_profile: Optional[bytes] + num_replicas: int + num_partitions: int + debug_options: DebugOptions + device_assignment: Optional[DeviceAssignment] + use_spmd_partitioning: bool + use_auto_spmd_partitioning: bool + auto_spmd_partitioning_mesh_shape: List[int] + auto_spmd_partitioning_mesh_ids: List[int] + use_shardy_partitioner: bool + def compilation_environments_from_serialized_proto(self, serialized_proto: bytes) -> None: ... 
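The CompileOptions, DebugOptions, and ExecutableBuildOptions stubs above map onto the nanobind properties defined earlier in xla_compiler.cc. Below is a minimal, illustrative sketch of driving them from Python; the import path jaxlib.xla.xla_extension is an assumption based on the new file location, and the dump directory and pass names are placeholders, not values this patch prescribes.

import jaxlib.xla.xla_extension as xe  # assumed import path for this build

options = xe.CompileOptions()
options.num_replicas = 1    # convenience property forwarded to executable_build_options
options.num_partitions = 1

build_options = options.executable_build_options
build_options.use_spmd_partitioning = False

debug = build_options.debug_options
debug.xla_dump_to = "/tmp/xla_dump"             # placeholder directory
debug.xla_disable_hlo_passes = "pass-a,pass-b"  # comma-separated pass names (placeholders)

# CompileOptions round-trips through its serialized proto form.
restored = xe.CompileOptions.ParseFromString(options.SerializeAsString())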
+ +class PrecisionConfig_Precision(enum.IntEnum): + DEFAULT: int + HIGH: int + HIGHEST: int + + +class ResultAccuracy_Mode(enum.IntEnum): + DEFAULT: int + HIGHEST: int + TOLERANCE: int + +class ResultAccuracy: + mode: ResultAccuracy_Mode + atol: float + rtol: float + ulps: int + +class OpSharding_Type(enum.IntEnum): + REPLICATED: int + MAXIMAL: int + TUPLE: int + OTHER: int + MANUAL: int + UNKNOWN: int + +class OpSharding_ShardGroupType(enum.IntEnum): + AS: int + LIKE: int + +class OpSharding: + Type: typing.Type[OpSharding_Type] + type: OpSharding_Type + replicate_on_last_tile_dim: bool + last_tile_dims: Sequence[Type] + tile_assignment_dimensions: Sequence[int] + tile_assignment_devices: Sequence[int] + iota_reshape_dims: Sequence[int] + iota_transpose_perm: Sequence[int] + tuple_shardings: Sequence[OpSharding] + is_shard_group: bool + shard_group_id: int + ShardGroupType: typing.Type[OpSharding_ShardGroupType] + shard_group_type: OpSharding_ShardGroupType + def ParseFromString(self, s: bytes) -> None: ... + def SerializeToString(self) -> bytes: ... + def clone(self) -> OpSharding: ... + +class HloSharding: + @staticmethod + def from_proto(proto: OpSharding) -> HloSharding: ... + @staticmethod + def from_string(sharding: str) -> HloSharding: ... + @staticmethod + def tuple_sharding( + shape: Shape, shardings: Sequence[HloSharding] + ) -> HloSharding: ... + @staticmethod + def iota_tile( + dims: Sequence[int], + reshape_dims: Sequence[int], + transpose_perm: Sequence[int], + subgroup_types: Sequence[OpSharding.Type], + ) -> HloSharding: ... + @staticmethod + def replicate() -> HloSharding: ... + @staticmethod + def manual() -> HloSharding: ... + @staticmethod + def unknown() -> HloSharding: ... + @staticmethod + def subgroup_with_device_ordering( + tile_assignment: np.ndarray, + subgroup_types: Sequence[OpSharding.Type]) -> HloSharding: ... + def __eq__(self, other: HloSharding) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + def tile(self, shape: Shape) -> Shape: ... + def is_replicated(self) -> bool: ... + def is_manual(self) -> bool: ... + def is_unknown(self) -> bool: ... + def is_tiled(self) -> bool: ... + def is_maximal(self) -> bool: ... + def tuple_elements(self) -> List[HloSharding]: ... + def num_devices(self) -> int: ... + def num_dimensions(self) -> int: ... + def tile_assignment_dimensions(self) -> Sequence[int]: ... + def tile_assignment_devices(self) -> Sequence[int]: ... + def subgroup_types(self) -> Sequence[OpSharding.Type]: ... + def replicate_on_last_tile_dim(self) -> bool: ... + def to_proto(self) -> OpSharding: ... + +class FftType(enum.IntEnum): + FFT: FftType + IFFT: FftType + RFFT: FftType + IRFFT: FftType + +# === END xla_compiler.cc + +class Device: + id: int + host_id: int + process_index: int + platform: str + device_kind: str + client: Client + local_hardware_id: int | None + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def transfer_to_infeed(self, literal: _LiteralSlice): ... + def transfer_from_outfeed(self, shape: Shape): ... + def memory(self, kind: str) -> Memory: ... + def default_memory(self) -> Memory: ... + def addressable_memories(self) -> List[Memory]: ... + def live_buffers(self) -> List[Any]: ... + def memory_stats(self) -> Optional[Dict[str, int]]: ... + def get_stream_for_external_ready_events(self) -> int: ... + def __getattr__(self, name: str) -> Any: ... + +class Memory: + process_index: int + platform: str + kind: str + def __repr__(self) -> str: ... + def __str__(self) -> str: ... 
+ def addressable_by_devices(self) -> List[Device]: ... + +class PjRtLayout: + def __str__(self) -> str: ... + def __eq__(self, other: PjRtLayout) -> bool: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, _: Any): ... + def _xla_layout(self) -> Layout: ... + +class GpuAllocatorConfig: + class Kind(enum.IntEnum): + DEFAULT: int + PLATFORM: int + BFC: int + CUDA_ASYNC: int + + def __init__( + self, + kind: Kind = ..., + memory_fraction: float = ..., + preallocate: bool = ..., + collective_memory_size: int = ..., + ) -> None: ... + +class HostBufferSemantics(enum.IntEnum): + IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics + IMMUTABLE_UNTIL_TRANSFER_COMPLETES: HostBufferSemantics + ZERO_COPY: HostBufferSemantics + +class Client: + platform: str + _raw_platform: str + platform_version: str + runtime_type: str + def device_count(self) -> int: ... + def local_device_count(self) -> int: ... + def devices(self) -> List[Device]: ... + def local_devices(self) -> List[Device]: ... + def _get_all_devices(self) -> List[Device]: ... + def device_from_local_hardware_id(self, int) -> Device: ... + def live_buffers(self) -> List[Any]: ... + def live_arrays(self) -> List[ArrayImpl]: ... + def live_executables(self) -> List[LoadedExecutable]: ... + def host_id(self) -> int: ... + def process_index(self) -> int: ... + def buffer_from_pyval( + self, + argument: Any, + device: Optional[Device] = ..., + force_copy: bool = ..., + host_buffer_semantics: HostBufferSemantics = ..., + ) -> ArrayImpl: ... + def compile( + self, + computation: Union[str, bytes], + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... + def compile_ifrt_program( + self, + program: ifrt_programs.Program, + program_options: ifrt_programs.CompileOptions, + ) -> LoadedExecutable: ... + def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... + def deserialize_executable( + self, + serialized: bytes, + options: Optional[CompileOptions], + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... + def heap_profile(self) -> bytes: ... + def make_python_callback_from_host_send_and_recv( + self, + callable: Callable, + operand_shapes: Sequence[Shape], + result_shapes: Sequence[Shape], + send_channel_ids: Sequence[int], + recv_channel_ids: Sequence[int], + serializer: Optional[Callable] = ..., + ) -> Any: ... + def get_default_layout( + self, dtype: np.dtype, shard_shape: Sequence[int], device: Device + ) -> PjRtLayout: ... + def __getattr__(self, name: str) -> Any: ... + +class CpuCollectives: ... + +def make_gloo_tcp_collectives( + distributed_client: Optional[DistributedRuntimeClient] = ..., + hostname: Optional[str] = ..., + interface: Optional[str] = ..., +) -> CpuCollectives: ... + +class MpiCollectives(CpuCollectives): + def Init(self): ... + def Finalize(self): ... + +def make_mpi_collectives() -> MpiCollectives: ... + +def get_tfrt_cpu_client( + asynchronous: bool = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: Optional[CpuCollectives] = ..., + num_devices: int | None = ..., +) -> Client: ... 
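As a hedged sketch of the Client surface declared above, assuming a jaxlib build that exposes this extension as jaxlib.xla.xla_extension and that a CPU backend is available:

import numpy as np
import jaxlib.xla.xla_extension as xe  # assumed import path

# Create a single-process CPU client with two virtual devices.
client = xe.get_tfrt_cpu_client(asynchronous=True, num_devices=2)
print(client.platform, client.device_count())

# Copy a host array onto the first local device; the result is an ArrayImpl.
device = client.local_devices()[0]
buf = client.buffer_from_pyval(np.arange(8, dtype=np.float32), device)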
+def get_gpu_client( + asynchronous: bool = ..., + allocator_config: GpuAllocatorConfig = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., + allowed_devices: Optional[Any] = ..., + platform_name: Optional[str] = ..., + mock: Optional[bool] = ..., + mock_gpu_topology: Optional[str] = ..., +) -> Client: ... +def get_mock_gpu_client( + asynchronous: bool = ..., + allocator_config: GpuAllocatorConfig = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + allowed_devices: Optional[Any] = ..., + platform_name: Optional[str] = ..., +) -> Client: ... +def get_c_api_client( + platform_name: str, + options: Dict[str, Union[str, int, List[int], float, bool]], + distributed_client: Optional[DistributedRuntimeClient] = ..., +) -> Client: ... +def get_default_c_api_topology( + platform_name: str, + topology_name: str, + options: Dict[str, Union[str, int, List[int], float]], +) -> DeviceTopology: ... +def get_c_api_topology( + c_api: Any, + topology_name: str, + options: Dict[str, Union[str, int, List[int], float]], +) -> DeviceTopology: ... +def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... +def load_pjrt_plugin(platform_name: str, library_path: Optional[str], c_api: Optional[Any]) -> _Status: ... +def pjrt_plugin_loaded(plugin_name: str) -> bool: ... +def pjrt_plugin_initialized(plugin_name: str) -> bool: ... +def initialize_pjrt_plugin(platform_name: str) -> _Status: ... + +Array = Any +ArrayImpl = Any + +# TODO(phawkins): this type is problematic because it is not a subtype of +# jax.Array, and pytype notices. +# class ArrayImpl: +# def __init__(self, +# aval: Any, +# sharding: Any, +# arrays: Sequence[ArrayImpl], +# committed: bool, +# _skip_checks: bool = ...): ... +# def block_until_ready(self) -> ArrayImpl: ... +# def is_deleted(self) -> bool: ... +# def is_ready(self) -> bool: ... +# def delete(self): ... +# def unsafe_buffer_pointer(self) -> Any: ... +# def clone(self) -> ArrayImpl: ... +# def _copy_single_device_array_to_host_async(self): ... +# def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: ... +# def on_device_size_in_bytes(self) -> int: ... +# def _fully_replicated_shard(self) -> ArrayImpl: ... +# __cuda_array_interface__: Dict[str, Any] +# dtype: np.dtype +# shape: Tuple[int, ...] +# _arrays: Any +# _npy_value: Any +# traceback: Traceback +# _HAS_DYNAMIC_ATTRIBUTES: bool = ... + +def batched_copy_array_to_devices_with_sharding( + arrays: Sequence[ArrayImpl], + devices: Sequence[List[Device]], + sharding: Sequence[Any], + array_copy_semantics: Sequence[ArrayCopySemantics], +) -> Sequence[ArrayImpl]: ... + +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... + +def batched_device_put( + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: List[Device], + committed: bool = True, +) -> ArrayImpl: ... + +def reorder_shards( + x: ArrayImpl, + dst_sharding: Any, + array_copy_semantics: ArrayCopySemantics, +) -> ArrayImpl: ... + +def check_and_canonicalize_memory_kind( + memory_kind: Optional[str], device_list: DeviceList +) -> Optional[str]: ... +def array_result_handler( + aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... +) -> Callable: ... + +class Token: + def block_until_ready(self): ... + +class ShardedToken: + def block_until_ready(self): ... + def get_token(self, device_id: int): ... + +class ExecuteResults: + def __len__(self) -> int: ... 
+ def disassemble_into_single_device_arrays(self) -> List[List[ArrayImpl]]: ... + def disassemble_prefix_into_single_device_arrays( + self, n: int + ) -> List[List[ArrayImpl]]: ... + def consume_with_handlers(self, handlers: List[Callable]) -> List[Any]: ... + def consume_token(self) -> ShardedToken: ... + +class LoadedExecutable: + client: Client + def local_devices(self) -> List[Device]: ... + def size_of_generated_code_in_bytes(self) -> int: ... + def delete(self) -> None: ... + def execute(self, arguments: Sequence[ArrayImpl]) -> List[ArrayImpl]: ... + def execute_with_token( + self, arguments: Sequence[ArrayImpl] + ) -> Tuple[List[ArrayImpl], Token]: ... + def execute_sharded( + self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ... + ) -> ExecuteResults: ... + def hlo_modules(self) -> List[HloModule]: ... + def get_output_memory_kinds(self) -> List[List[str]]: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def get_output_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_layouts(self) -> List[Layout]: ... + def get_output_layouts(self) -> List[Layout]: ... + def keep_alive(self) -> None: ... + def cost_analysis(self) -> Dict[str, Any]: ... + traceback: Traceback + fingerprint: Optional[bytes] + +class Executable: + def hlo_modules(self) -> List[HloModule]: ... + def get_output_memory_kinds(self) -> List[List[str]]: ... + def get_output_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_layouts(self) -> List[Layout]: ... + def get_output_layouts(self) -> List[Layout]: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def serialize(self) -> str: ... + def cost_analysis(self) -> Dict[str, Any]: ... + +class DeviceTopology: + platform: str + platform_version: str + def _make_compile_only_devices(self) -> List[Device]: ... + def serialize(self) -> bytes: ... + def __getattr__(self, name: str) -> Any: ... + +def buffer_to_dlpack_managed_tensor( + buffer: ArrayImpl, stream: int | None = None +) -> Any: ... +@overload +def dlpack_managed_tensor_to_buffer( + tensor: Any, device: Device, stream: int | None +) -> ArrayImpl: ... +@overload +def dlpack_managed_tensor_to_buffer( # Legacy overload + tensor: Any, + cpu_backend: Optional[Client] = ..., + gpu_backend: Optional[Client] = ..., +) -> ArrayImpl: ... + +def cuda_array_interface_to_buffer( + cai: Dict[str, Union[ + str, int, None, + Tuple[int, ...], Tuple[int, bool], + List[Tuple[str, str]], + List[Tuple[str, str, Tuple[int, ...]]]] + ], + gpu_backend: Optional[Client] = ..., + device_id: int | None = None, +) -> ArrayImpl: ... + +# === BEGIN py_traceback.cc + +class Frame: + file_name: str + function_name: str + function_line_start: int + line_num: int + def __init__(self, + file_name: str, + function_name: str, + function_line_start: int, + line_num: int): ... + def __repr__(self) -> str: ... + +class Traceback: + enabled: ClassVar[bool] + @staticmethod + def get_traceback() -> Traceback: ... + @staticmethod + def traceback_from_frames(frames: Sequence[Frame]) -> Any: ... + frames: Sequence[Frame] + def __str__(self) -> str: ... + def as_python_traceback(self) -> Any: ... + def raw_frames(self) -> Tuple[List[types.CodeType], List[int]]: ... + @staticmethod + def code_addr2line(code: types.CodeType, lasti: int) -> int: ... 
+ @staticmethod + def code_addr2location( + code: types.CodeType, lasti: int + ) -> Tuple[int, int, int, int]: ... + +def replace_thread_exc_traceback(traceback: Any): ... + +# === END py_traceback.cc + +class DistributedRuntimeService: + def shutdown(self) -> None: ... + +class DistributedRuntimeClient: + def connect(self) -> _Status: ... + def shutdown(self) -> _Status: ... + def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> _Status: ... + def blocking_key_value_get_bytes( + self, key: str, timeout_in_ms: int + ) -> _Status: ... + def key_value_try_get(self, key: str) -> _Status: ... + def key_value_try_get_bytes(self, key: str) -> _Status: ... + def key_value_dir_get(self, key: str) -> _Status: ... + def key_value_dir_get_bytes(self, key: str) -> _Status: ... + def key_value_set(self, key: str, value: str, + allow_overwrite: bool = False) -> _Status: ... + def key_value_set_bytes(self, key: str, value: bytes, + allow_overwrite: bool = False) -> _Status: ... + def key_value_delete(self, key: str) -> _Status: ... + def wait_at_barrier( + self, barrier_id: str, timeout_in_ms: int, process_ids: Optional[List[int]] + ) -> _Status: ... + def get_live_nodes(self, process_ids: List[int]) -> _Status: ... + +def get_distributed_runtime_service( + address: str, + num_nodes: int, + heartbeat_interval: Optional[int] = ..., + max_missing_heartbeats: Optional[int] = ..., + cluster_register_timeout: Optional[int] = ..., + shutdown_timeout: Optional[int] = ..., +) -> DistributedRuntimeService: ... +def get_distributed_runtime_client( + address: str, + node_id: int, + rpc_timeout: Optional[int] = ..., + init_timeout: Optional[int] = ..., + shutdown_timeout: Optional[int] = ..., + heartbeat_interval: Optional[int] = ..., + max_missing_heartbeats: Optional[int] = ..., + missed_heartbeat_callback: Optional[Any] = ..., + shutdown_on_destruction: Optional[bool] = ..., + use_compression: Optional[bool] = ..., +) -> DistributedRuntimeClient: ... + +class PreemptionSyncManager: + def initialize(self, client: DistributedRuntimeClient) -> _Status: ... + def reached_sync_point(self, step_counter: int) -> bool: ... + +def create_preemption_sync_manager() -> PreemptionSyncManager: ... +def collect_garbage() -> None: ... +def is_optimized_build() -> bool: ... +def json_to_pprof_profile(json: str) -> bytes: ... +def pprof_profile_to_json(proto: bytes) -> str: ... + +class PmapFunction: + def __call__(self, *args, **kwargs) -> Any: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + __signature__: inspect.Signature + def _cache_size(self) -> int: ... + def _cache_clear(self) -> None: ... + +class DeviceList: + def __init__(self, device_assignment: Tuple[Device, ...]): ... + def __hash__(self) -> int: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __len__(self) -> int: ... + def __getitem__(self, index: Any) -> Any: ... + def __iter__(self) -> Iterator[Device]: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + @property + def is_fully_addressable(self) -> bool: ... + @property + def addressable_device_list(self) -> DeviceList: ... + @property + def default_memory_kind(self) -> Optional[str]: ... + @property + def memory_kinds(self) -> Tuple[str, ...]: ... + +class Sharding: ... 
+ +class NamedSharding(Sharding): + def __init__( + self, + mesh: Any, + spec: Any, + *, + memory_kind: Optional[str] = None, + _logical_device_ids: tuple[int, ...] | None = None, + ): ... + mesh: Any + spec: Any + _memory_kind: Optional[str] + _internal_device_list: DeviceList + _logical_device_ids: tuple[int, ...] | None + +class SingleDeviceSharding(Sharding): + def __init__(self, device: Device, *, memory_kind: Optional[str] = None): ... + _device: Device + _memory_kind: Optional[str] + _internal_device_list: DeviceList + +class PmapSharding(Sharding): + def __init__( + self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec + ): ... + devices: List[Any] + sharding_spec: pmap_lib.ShardingSpec + _internal_device_list: DeviceList + +class GSPMDSharding(Sharding): + def __init__( + self, + devices: Sequence[Device], + op_sharding: Union[OpSharding, HloSharding], + *, + memory_kind: Optional[str] = None, + _device_list: Optional[DeviceList] = None, + ): ... + _devices: Tuple[Device, ...] + _hlo_sharding: HloSharding + _memory_kind: Optional[str] + _internal_device_list: DeviceList + +class PjitFunction: + def __call__(self, *args, **kwargs) -> Any: ... + +class PjitFunctionCache: + def __init__(self, capacity: int = ...): ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + def size(self) -> int: ... + def capacity(self) -> int: ... + def clear(self): ... + @staticmethod + def clear_all(): ... + +def pjit( + function_name: str, + fun: Optional[Callable], + cache_miss: Callable, + static_argnums: Sequence[int], + static_argnames: Sequence[str], + global_cache_key: Any, + pytree_registry: pytree.PyTreeRegistry, + shard_arg_fallback: Callable, + cache: Optional[PjitFunctionCache] = ..., +) -> PjitFunction: ... + +class HloPassInterface: + @property + def name(self) -> str: ... + def is_pass_pipeline(self) -> bool: ... + def run(self, module: HloModule) -> bool: ... + def run_on_module_group(self, module_group: HloModuleGroup) -> bool: ... + +class HloDCE(HloPassInterface): + def __init__(self) -> None: ... + +class CallInliner(HloPassInterface): + def __init__(self) -> None: ... + +class FlattenCallGraph(HloPassInterface): + def __init__(self) -> None: ... + +class TupleSimplifer(HloPassInterface): + def __init__(self) -> None: ... + +class WeakrefLRUCacheInfo: + @property + def hits(self) -> int: ... + @property + def misses(self) -> int: ... + @property + def maxsize(self) -> int: ... + @property + def currsize(self) -> int: ... + +class WeakrefLRUCache: + def __call__(self, weakref_key: Any, *args, **kwargs) -> Any: ... + def cache_keys(self) -> list[Any]: ... + def cache_info(self) -> WeakrefLRUCacheInfo: ... + def cache_clear(self): ... + +def is_asan() -> bool: ... +def is_msan() -> bool: ... +def is_tsan() -> bool: ... +def is_sanitized() -> bool: ... + +class TransferConnection: + + def address(self) -> str: ... + + def _pull_flat(self, uuid, backend, avals_flat) -> list[Any]: ... + +class TransferServer: + def _await_pull_flat(self, uuid, args: list[ArrayImpl]): ... + + def connect(self, address: str) -> TransferConnection: ... + +def start_transfer_server(client: Client, address: str = "", transport_addresses: list[str] = [], max_num_parallel_copies: int = 0, transfer_size: int = 0) -> TransferServer: ... 
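Editor's note: a minimal sketch of the coordination-service surface stubbed above (`get_distributed_runtime_service`, `get_distributed_runtime_client`, and the key/value methods). A single process plays both roles here purely to show the call order; the address and key are hypothetical, and real programs reach this machinery through `jax.distributed.initialize`.

```python
from jaxlib.xla import xla_extension as xe  # assumed import path for this layout

ADDRESS = "localhost:12345"  # hypothetical coordinator address

service = xe.get_distributed_runtime_service(ADDRESS, num_nodes=1)
client = xe.get_distributed_runtime_client(ADDRESS, node_id=0)
client.connect()

# Key/value rendezvous, e.g. for exchanging collective bootstrap ids.
client.key_value_set("greeting", "hello")
value = client.blocking_key_value_get("greeting", timeout_in_ms=5_000)

client.shutdown()
service.shutdown()
```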
diff --git a/jaxlib/xla/xla_extension/config.pyi b/jaxlib/xla/xla_extension/config.pyi new file mode 100644 index 000000000000..535554559180 --- /dev/null +++ b/jaxlib/xla/xla_extension/config.pyi @@ -0,0 +1,32 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Generic, TypeVar + +unset: object + +_T = TypeVar('_T') + +class Config(Generic[_T]): + def __init__(self, value: _T, include_in_jit_key: bool = False): ... + + @property + def value(self) -> _T: ... + + def get_local(self) -> Any: ... + def get_global(self) -> _T: ... + def set_local(self, value: Any) -> None: ... + def swap_local(self, value: Any) -> Any: ... + def set_global(self, value: _T) -> None: ... diff --git a/jaxlib/xla/xla_extension/guard_lib.pyi b/jaxlib/xla/xla_extension/guard_lib.pyi new file mode 100644 index 000000000000..cfa8b0c5fa5e --- /dev/null +++ b/jaxlib/xla/xla_extension/guard_lib.pyi @@ -0,0 +1,46 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, List, Optional + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class GarbageCollectionGuardLevel: + ALLOW: Any + LOG: Any + FATAL: Any + +class GuardState: + host_to_device: Optional[TransferGuardLevel] + device_to_device: Optional[TransferGuardLevel] + device_to_host: Optional[TransferGuardLevel] + + explicit_device_put: bool + explicit_device_get: bool + + garbage_collect_array: Optional[GarbageCollectionGuardLevel] + +def global_state() -> GuardState: ... +def thread_local_state() -> GuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> List[str]: ... diff --git a/jaxlib/xla/xla_extension/ifrt_programs.pyi b/jaxlib/xla/xla_extension/ifrt_programs.pyi new file mode 100644 index 000000000000..bcee365e5732 --- /dev/null +++ b/jaxlib/xla/xla_extension/ifrt_programs.pyi @@ -0,0 +1,43 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Sequence, Union + +from jax.jaxlib.xla import xla_extension + +class Program: ... + +class CompileOptions: ... + +def make_hlo_program(mlir_module: Union[str, bytes]) -> Program: ... + +def make_colocated_python_program( + name : str, + picked_function: bytes, + devices: Sequence[xla_extension.Device] | xla_extension.DeviceList, + input_avals: Sequence[Any], + output_avals: Sequence[Any], +) -> Program: ... + +def make_plugin_program(data: Union[str, bytes]) -> Program: ... + +def make_colocated_python_compile_options() -> CompileOptions: ... + +def make_xla_compile_options( + compile_options: xla_extension.CompileOptions, + host_callbacks: Sequence[Any] +) -> CompileOptions: ... + +def make_plugin_compile_options() -> CompileOptions: ... diff --git a/jaxlib/xla/xla_extension/ifrt_proxy.pyi b/jaxlib/xla/xla_extension/ifrt_proxy.pyi new file mode 100644 index 000000000000..3b5de7aa97c9 --- /dev/null +++ b/jaxlib/xla/xla_extension/ifrt_proxy.pyi @@ -0,0 +1,33 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Optional, Callable + +from jax.jaxlib.xla import xla_extension + +_Status = Any +Client = xla_extension.Client + + +class ClientConnectionOptions: + on_disconnect: Optional[Callable[[_Status], None]] = None + on_connection_update: Optional[Callable[[str], None]] = None + connection_timeout_in_seconds: Optional[int] = None + + +def get_client( + proxy_server_address: str, + options: ClientConnectionOptions +) -> Client: ... diff --git a/jaxlib/xla/xla_extension/jax_jit.pyi b/jaxlib/xla/xla_extension/jax_jit.pyi new file mode 100644 index 000000000000..1f78d283333c --- /dev/null +++ b/jaxlib/xla/xla_extension/jax_jit.pyi @@ -0,0 +1,76 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from typing import Any, Callable, Optional, Sequence, Tuple + +import numpy as np +from jax.jaxlib.xla import xla_extension + +from . import pytree + +Client = xla_extension.Client +Device = xla_extension.Device + + +class JitState: + disable_jit: Optional[bool] + enable_x64: Optional[bool] + default_device: Optional[Any] + extra_jit_context: Optional[Any] + post_hook: Optional[Callable[..., Any]] + +def global_state() -> JitState: ... +def thread_local_state() -> JitState: ... + +def get_enable_x64() -> bool: ... +def set_thread_local_state_initialization_callback( + function: Callable[[], None]): ... + +def swap_thread_local_state_disable_jit( + value: Optional[bool]) -> Optional[bool]: ... + +class ArgSignature: + dtype: np.dtype + shape: Tuple[int, ...] + weak_type: bool + +def _ArgSignatureOfValue( + __arg: Any, + __jax_enable_x64: bool) -> ArgSignature: ... + +def _is_float0(__arg: Any) -> bool: ... + + +class ArgumentSignature: + static_args: Sequence[Any] + static_arg_names: Sequence[str] + dynamic_arg_names: Sequence[str] + dynamic_arg_treedefs: Sequence[pytree.PyTreeDef] + + def __eq__(self, value, /): ... + def __ne__(self, value, /): ... + def __hash__(self, /): ... + def __str__(self): ... + def __repr__(self): ... + + +def parse_arguments( + positional_args: Sequence[Any], + keyword_args: Sequence[Any], + kwnames: Tuple[str, ...], + static_argnums: Sequence[int], + static_argnames: Sequence[str], + pytree_registry: pytree.PyTreeRegistry, +) -> tuple[ArgumentSignature, Sequence[Any]]: ... diff --git a/jaxlib/xla/xla_extension/mlir.pyi b/jaxlib/xla/xla_extension/mlir.pyi new file mode 100644 index 000000000000..961f01a0352c --- /dev/null +++ b/jaxlib/xla/xla_extension/mlir.pyi @@ -0,0 +1,35 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Union +from . import XlaComputation + +def hlo_to_stablehlo(computation: bytes) -> bytes: ... +def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ... +def mlir_module_to_xla_computation( + mlir_module: Union[bytes, str], + use_tuple_args: bool = ..., + return_tuple: bool = ..., +) -> XlaComputation: ... +def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> bytes: ... +def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> bytes: ... +def serialize_portable_artifact(mlir_module: str, target: str) -> bytes: ... +def deserialize_portable_artifact(mlir_module: bytes) -> str: ... +def refine_polymorphic_shapes( + mlir_module: Union[bytes, str], + enable_shape_assertions: bool = ..., + validate_static_shapes: bool = ..., + enable_shardy: bool = ..., +) -> bytes: ... 
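Editor's note: a sketch of the text/`XlaComputation` round trip declared above. The StableHLO module is a hand-written example (an assumption, not taken from this change), and only functions stubbed in `mlir.pyi` are used; the import path is assumed for this layout.

```python
from jaxlib.xla.xla_extension import mlir  # assumed import path for this layout

STABLEHLO = """
module @add_one {
  func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
    %cst = stablehlo.constant dense<1.0> : tensor<f32>
    %0 = stablehlo.add %arg0, %cst : tensor<f32>
    return %0 : tensor<f32>
  }
}
"""

computation = mlir.mlir_module_to_xla_computation(STABLEHLO)       # -> XlaComputation
roundtrip_text = mlir.xla_computation_to_mlir_module(computation)  # -> MLIR text
```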
diff --git a/jaxlib/xla/xla_extension/ops.pyi b/jaxlib/xla/xla_extension/ops.pyi new file mode 100644 index 000000000000..ff55de3a5cdc --- /dev/null +++ b/jaxlib/xla/xla_extension/ops.pyi @@ -0,0 +1,465 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import enum +from typing import Any, Optional, Sequence, overload + +from jax.jaxlib.xla import xla_extension + +FftType = xla_extension.FftType +XlaBuilder = xla_extension.XlaBuilder +XlaComputation = xla_extension.XlaComputation +XlaOp = xla_extension.XlaOp +PrecisionConfig_Precision = xla_extension.PrecisionConfig_Precision +PrimitiveType = xla_extension.PrimitiveType +Shape = xla_extension.Shape +ShapeIndex = xla_extension.ShapeIndex +ResultAccuracy = xla_extension.ResultAccuracy + +_ChannelHandle = Any +_ConvDimensionNumbers = Any +_DotDimensionNumbers = Any +_Layout = Any +_LiteralSlice = Any +_GatherDimensionNumbers = Any +_PaddingConfig = Any +_ReplicaGroup = Any +_ScatterDimensionNumbers = Any + +class TriangularSolveOptions_Transpose(enum.IntEnum): + TRANSPOSE_INVALID: int + NO_TRANSPOSE: int + TRANSPOSE: int + ADJOINT: int + +class RandomAlgorithm(enum.IntEnum): + RNG_DEFAULT: int + RNG_THREE_FRY: int + RNG_PHILOX: int + +class CustomCallSchedule(enum.IntEnum): + SCHEDULE_NONE: int + SCHEDULE_LATEST: int + SCHEDULE_EARLIEST: int + +# TODO(b/189822916): Remove this enum when all clients are migrated to the +# status-returning API. +class CustomCallApiVersion(enum.IntEnum): + API_VERSION_ORIGINAL: int + API_VERSION_STATUS_RETURNING: int + API_VERSION_STATUS_RETURNING_UNIFIED: int + API_VERSION_TYPED_FFI: int + +def AfterAll(builder: XlaBuilder, tokens: Sequence[XlaOp]) -> XlaOp: ... +def AllGather( + operand: XlaOp, + all_gather_dimension: int, + shard_count: int, + replica_groups: Sequence[_ReplicaGroup] = ..., + channel_id: Optional[_ChannelHandle] = ..., + shape_with_layout: Optional[_Layout] = ..., + use_global_device_ids: Optional[bool] = ...) -> XlaOp: ... +def AllReduce( + operand: XlaOp, + computation: XlaComputation, + replica_groups: Sequence[_ReplicaGroup] = ..., + channel_id: Optional[_ChannelHandle] = ..., + shape_with_layout: Optional[_Layout] = ...) -> XlaOp: ... +def ApproxTopK( + builder: XlaBuilder, + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + top_k: int, + reduction_dim: int, + comparator: XlaComputation, + recall_target: Optional[float], + aggregate_to_topk: Optional[bool], + reduction_input_size_override: Optional[int]) -> XlaOp: ... +def ApproxTopKFallback( + builder: XlaBuilder, + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + top_k: int, + reduction_dim: int, + comparator: XlaComputation, + recall_target: Optional[float], + aggregate_to_topk: Optional[bool], + reduction_input_size_override: Optional[int]) -> XlaOp: ... 
+def ApproxTopKReductionOutputSize( + input_size: int, + rank: int, + top_k: int, + recall_target: float, + aggregate_to_topk: Optional[bool] = ..., + input_size_override: Optional[int] = ...) -> tuple[int, int]: ... +def ReduceScatter( + operand: XlaOp, + computation: XlaComputation, + scatter_dimension: int, + shard_count: int, + replica_groups: Sequence[_ReplicaGroup] = ..., + channel_id: Optional[_ChannelHandle] = ..., + layout: Optional[_Layout] = ..., + use_global_device_ids: Optional[bool] = ...) -> XlaOp: ... +def AllToAll( + operand: XlaOp, + split_dimension: int, + concat_dimension: int, + split_count: int, + replica_groups: Sequence[_ReplicaGroup] = ..., + layout: Optional[_Layout] = ..., + channel_id: Optional[_ChannelHandle] = ...) -> XlaOp: ... +def BitcastConvertType(operand: XlaOp, + new_element_type: PrimitiveType) -> XlaOp: ... +def Broadcast(operand: XlaOp, sizes: Sequence[int]) -> XlaOp: ... +def BroadcastInDim(operand: XlaOp, + shape: Sequence[int], + broadcast_dimensions: Sequence[int]) -> XlaOp: ... +def Call(builder: XlaBuilder, + computation: XlaComputation, + operands: Sequence[XlaOp]) -> XlaOp: ... +def Cholesky(a: XlaOp, lower: bool = ...) -> XlaOp: ... +def Clamp(min: XlaOp, operand: XlaOp, max: XlaOp) -> XlaOp: ... +def Collapse(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ... +def CollectivePermute( + operand: XlaOp, + source_target_pairs: Sequence[tuple[int, int]], + channel_id: Optional[_ChannelHandle] = ..., + inplace: bool = ...) -> XlaOp: ... +def ConcatInDim(builder: XlaBuilder, + operands: Sequence[XlaOp], + dimension: int) -> XlaOp: ... +@overload +def Conditional(branch_index: XlaOp, + branch_computations: Sequence[XlaComputation], + branch_operands: Sequence[XlaOp]) -> XlaOp: ... +@overload +def Conditional( + predicate: XlaOp, + true_operand: XlaOp, + true_computation: XlaComputation, + false_operand: XlaOp, + false_computation: XlaComputation) -> XlaOp: ... + +def Constant(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ... +def ConstantLiteral(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ... +def ConvGeneralDilated( + lhs: XlaOp, + rhs: XlaOp, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: _ConvDimensionNumbers, + feature_group_count: int = ..., + batch_group_count: int = ..., + precision_config: Optional[PrecisionConfig_Precision] = ..., + preferred_element_type: Optional[PrimitiveType] = ..., + window_reversal: Optional[Sequence[bool]] = ...) -> XlaOp: ... +def ConvertElementType( + operand: XlaOp, + new_element_type: PrimitiveType) -> XlaOp: ... +def CreateToken(builder: XlaBuilder) -> XlaOp: ... +def CrossReplicaSum( + operand: XlaOp, + replica_groups: Sequence[_ReplicaGroup] = ...) -> XlaOp: ... +def CustomCall( + builder: XlaBuilder, + call_target_name: bytes, + operands: Sequence[XlaOp], + shape: Shape, + opaque: bytes = ..., + has_side_effect: bool = ..., + schedule: CustomCallSchedule = ..., + api_version: CustomCallApiVersion = ...) -> XlaOp: ... +def CustomCallWithLayout( + builder: XlaBuilder, + call_target_name: bytes, + operands: Sequence[XlaOp], + shape_with_layout: Shape, + operand_shapes_with_layout: Sequence[Shape], + opaque: bytes = ..., + has_side_effect: bool = ..., + schedule: CustomCallSchedule = ..., + api_version: CustomCallApiVersion = ...) -> XlaOp: ... 
+def CustomCallWithAliasing( + builder: XlaBuilder, + call_target_name: bytes, + operands: Sequence[XlaOp], + shape_with_layout: Shape, + operand_shapes_with_layout: Sequence[Shape], + opaque: bytes = ..., + has_side_effect: bool = ..., + output_operand_aliasing: Sequence[tuple[ShapeIndex, tuple[int, ShapeIndex]]] = ..., + literal: _LiteralSlice = ..., + schedule: CustomCallSchedule = ..., + api_version: CustomCallApiVersion = ...) -> XlaOp: ... +def Dot( + lhs: XlaOp, + rhs: XlaOp, + precision_config: Optional[PrecisionConfig_Precision] = ..., + preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ... +def DotGeneral( + lhs: XlaOp, + rhs: XlaOp, + dimensions_numbers: _DotDimensionNumbers, + precision_config: Optional[PrecisionConfig_Precision] = ..., + preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ... +def DynamicReshape( + operand: XlaOp, + dim_sizes: Sequence[XlaOp], + new_size_bounds: Sequence[int], + dims_are_dynamic: Sequence[bool]) -> XlaOp: ... +def DynamicSlice( + operand: XlaOp, + start_indices: Sequence[XlaOp], + slice_sizes: Sequence[int]) -> XlaOp: ... +def DynamicUpdateSlice( + operand: XlaOp, + update: XlaOp, + start_indices: Sequence[XlaOp]) -> XlaOp: ... +def Eigh( + a: XlaOp, + lower: bool = ..., + max_iter: int = ..., + epsilon: float = ..., + sort_eigenvalues: bool = ...) -> tuple[XlaOp, XlaOp]: ... +def Fft( + operand: XlaOp, + fft_type: FftType, + fft_length: Sequence[int]) -> XlaOp: ... +def Gather( + a: XlaOp, + start_indices: XlaOp, + dimension_numbers: _GatherDimensionNumbers, + slice_sizes: Sequence[int], + indices_are_sorted: bool = ...) -> XlaOp: ... +def GetDimensionSize(operand: XlaOp, index: int) -> XlaOp: ... +def GetTupleElement(tuple_data: XlaOp, index: int) -> XlaOp: ... +def InfeedWithToken( + token: XlaOp, + shape: Shape, + config: Optional[str] = ...) -> XlaOp: ... +@overload +def Iota(builder: XlaBuilder, shape: Shape, iota_dimension: int) -> XlaOp: ... +@overload +def Iota(builder: XlaBuilder, type: PrimitiveType, size: int) -> XlaOp: ... +def LU(a: XlaOp) -> tuple[XlaOp, XlaOp, XlaOp]: ... +def Map( + builder: XlaBuilder, + operands: Sequence[XlaOp], + computation: XlaComputation, + dimensions: Sequence[int], + static_operands: Sequence[XlaOp] = ...) -> XlaOp: ... +def MultiCollectivePermute( + operands: Sequence[XlaOp], + source_target_pairs: Sequence[tuple[int, int]], + channel_id: Optional[_ChannelHandle] = ..., + inplace: bool = ...) -> XlaOp: ... +def NextAfter(__from: XlaOp, to: XlaOp) -> XlaOp: ... +def OutfeedWithToken( + operand: XlaOp, + token: XlaOp, + shape_with_layout: Shape, + outfeed_config: Optional[str] = ...) -> XlaOp: ... +def Pad( + operand: XlaOp, + padding_value: XlaOp, + padding_config: _PaddingConfig) -> XlaOp: ... +def Parameter( + builder: XlaBuilder, + parameter_number: int, + shape: Shape, + name: str = ..., + replicated_at_leaf_buffers: Sequence[bool] = ...) -> XlaOp: ... +def ProductOfElementaryHouseholderReflectors(a: XlaOp, taus: XlaOp) -> XlaOp: ... +def QR(a: XlaOp, full_matrices: bool) -> tuple[XlaOp, XlaOp]: ... +def QrDecomposition(a: XlaOp) -> tuple[XlaOp, XlaOp]: ... +def Reduce( + builder: XlaBuilder, + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + computation: XlaComputation, + dimensions_to_reduce: Sequence[int]) -> XlaOp: ... +def ReducePrecision( + operand: XlaOp, + exponent_bits: int, + mantissa_bits: int) -> XlaOp: ... 
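Editor's note: the op wrappers in this file build an `XlaComputation` through an `XlaBuilder`. The sketch below uses only ops declared here; `XlaBuilder("...")`, `Shape.array_shape(...)`, and `builder.build(...)` are assumed to behave as they do in `xla_client`, and JAX itself lowers via StableHLO rather than these wrappers.

```python
import numpy as np
from jaxlib.xla import xla_extension as xe    # assumed import path for this layout
from jaxlib.xla.xla_extension import ops      # the module stubbed in this file

builder = xe.XlaBuilder("add_one")                        # assumed, as in xla_client
shape = xe.Shape.array_shape(np.dtype("float32"), (4,))   # assumed helper, as in xla_client
x = ops.Parameter(builder, 0, shape, name="x")
one = ops.Constant(builder, np.float32(1.0))
root = ops.Add(x, one)                                    # scalar implicitly broadcasts
computation = builder.build(root)                         # assumed build() -> XlaComputation
```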
+@overload +def ReduceWindowWithGeneralPadding( + operand: XlaOp, + init_value: XlaOp, + computation: XlaComputation, + window_dimensions: Sequence[int], + window_strides: Sequence[int], + base_dilations: Sequence[int], + window_dilations: Sequence[int], + padding: Sequence[tuple[int, int]]) -> XlaOp: ... +@overload +def ReduceWindowWithGeneralPadding( + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + computation: XlaComputation, + window_dimensions: Sequence[int], + window_strides: Sequence[int], + base_dilations: Sequence[int], + window_dilations: Sequence[int], + padding: Sequence[tuple[int, int]]) -> XlaOp: ... +def ReplicaId(builder: XlaBuilder) -> XlaOp: ... +def Reshape(operand: XlaOp, new_sizes: Sequence[int]) -> XlaOp: ... +def Rev(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ... +def RngBitGenerator( + algorithm: RandomAlgorithm, + initial_state: XlaOp, + shape: Shape) -> XlaOp: ... +def RngNormal(mu: XlaOp, sigma: XlaOp, shape: Shape) -> XlaOp: ... +def RngUniform(a: XlaOp, b: XlaOp, shape: Shape) -> XlaOp: ... +@overload +def Scatter( + input: XlaOp, + scatter_indices: XlaOp, + updates: XlaOp, + update_computation: XlaComputation, + dimension_numbers: _ScatterDimensionNumbers, + indices_are_sorted: bool = ..., + unique_indices: bool = ...) -> XlaOp: ... +@overload +def Scatter( + inputs: Sequence[XlaOp], + scatter_indices: XlaOp, + updates: Sequence[XlaOp], + update_computation: XlaComputation, + dimension_numbers: _ScatterDimensionNumbers, + indices_are_sorted: bool = ..., + unique_indices: bool = ...) -> XlaOp: ... +def Select(pred: XlaOp, on_true: XlaOp, on_false: XlaOp) -> XlaOp: ... +def SelectAndScatterWithGeneralPadding( + operand: XlaOp, + select: XlaComputation, + window_dimensions: Sequence[int], + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + source: XlaOp, + init_value: XlaOp, + scatter: XlaComputation) -> XlaOp: ... +def Slice( + operand: XlaOp, + start_indices: Sequence[int], + limit_indices: Sequence[int], + strides: Sequence[int]) -> XlaOp: ... +def SliceInDim( + operand: XlaOp, + start_index: int, + limit_index: int, + stride: int, + dimno: int) -> XlaOp: ... +def Sort( + builder: XlaBuilder, + operands: Sequence[XlaOp], + comparator: Optional[XlaComputation] = ..., + dimension: int = ..., + is_stable: bool = ...) -> XlaOp: ... +def SVD( + a: XlaOp, + max_iter: int = ..., + epsilon: float = ...) -> tuple[XlaOp, XlaOp, XlaOp]: ... +def TopK(input: XlaOp, k: int) -> XlaOp: ... +def Transpose(operand: XlaOp, permutation: Sequence[int]) -> XlaOp: ... +def TriangularSolve( + a: XlaOp, + b: XlaOp, + left_side: bool, + lower: bool, + unit_diagonal: bool, + transpose_a: TriangularSolveOptions_Transpose) -> XlaOp: ... +def Tuple(builder: XlaBuilder, elements: Sequence[XlaOp]) -> XlaOp: ... +def While( + condition: XlaComputation, + body: XlaComputation, + init: XlaOp) -> XlaOp: ... + + +def Igamma(a: XlaOp, x: XlaOp) -> XlaOp: ... +def Igammac(a: XlaOp, x: XlaOp) -> XlaOp: ... +def IgammaGradA(a: XlaOp, x: XlaOp) -> XlaOp: ... +def RandomGammaGrad(a: XlaOp, x: XlaOp) -> XlaOp: ... +def RegularizedIncompleteBeta(a: XlaOp, b: XlaOp, x: XlaOp) -> XlaOp: ... +def Zeta(a: XlaOp, q: XlaOp) -> XlaOp: ... + +def Eq(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Ne(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Ge(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... 
+def Gt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Lt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Le(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Add(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Sub(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Mul(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Div(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Rem(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Max(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Min(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def And(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Or(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Xor(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def ShiftLeft(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def ShiftRightArithmetic(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def ShiftRightLogical(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Atan2(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Pow(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Complex(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... + +def Not(__arg: XlaOp) -> XlaOp: ... +def PopulationCount(__arg: XlaOp) -> XlaOp: ... +def Clz(__arg: XlaOp) -> XlaOp: ... +def Abs(__arg: XlaOp) -> XlaOp: ... +def Exp(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Expm1(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Floor(__arg: XlaOp) -> XlaOp: ... +def Ceil(__arg: XlaOp) -> XlaOp: ... +def Round(__arg: XlaOp) -> XlaOp: ... +def Log(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Log1p(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Sign(__arg: XlaOp) -> XlaOp: ... +def Cos(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def OptimizationBarrier(__arg: XlaOp) -> XlaOp: ... +def Sin(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Tan(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Tanh(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def IsFinite(__arg: XlaOp) -> XlaOp: ... +def Neg(__arg: XlaOp) -> XlaOp: ... +def Sqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Rsqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Cbrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Square(__arg: XlaOp) -> XlaOp: ... +def Reciprocal(__arg: XlaOp) -> XlaOp: ... +def Erfc(__arg: XlaOp) -> XlaOp: ... +def Erf(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def ErfInv(__arg: XlaOp) -> XlaOp: ... +def Lgamma(__arg: XlaOp) -> XlaOp: ... +def Digamma(__arg: XlaOp) -> XlaOp: ... +def BesselI0e(__arg: XlaOp) -> XlaOp: ... +def BesselI1e(__arg: XlaOp) -> XlaOp: ... +def Acos(__arg: XlaOp) -> XlaOp: ... +def Asin(__arg: XlaOp) -> XlaOp: ... 
+def Atan(__arg: XlaOp) -> XlaOp: ... +def Acosh(__arg: XlaOp) -> XlaOp: ... +def Asinh(__arg: XlaOp) -> XlaOp: ... +def Atanh(__arg: XlaOp) -> XlaOp: ... +def Cosh(__arg: XlaOp) -> XlaOp: ... +def Sinh(__arg: XlaOp) -> XlaOp: ... +def Real(__arg: XlaOp) -> XlaOp: ... +def Imag(__arg: XlaOp) -> XlaOp: ... +def Conj(__arg: XlaOp) -> XlaOp: ... diff --git a/jaxlib/xla/xla_extension/pmap_lib.pyi b/jaxlib/xla/xla_extension/pmap_lib.pyi new file mode 100644 index 000000000000..8733d6c27b21 --- /dev/null +++ b/jaxlib/xla/xla_extension/pmap_lib.pyi @@ -0,0 +1,83 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import inspect +from typing import Any, Callable, Sequence, Iterable, Tuple + +from . import pytree + +_AvalDimSharding = Any +_MeshDimAssignment = Any + +class NoSharding: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class Chunked: + @property + def chunks(self) -> Sequence[int]: ... + def __init__(self, __chunks: Sequence[int]) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class Unstacked: + @property + def size(self) -> int: ... + def __init__(self, __sz: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class ShardedAxis: + @property + def axis(self) -> int: ... + def __init__(self, __axis: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: ShardedAxis) -> bool: ... + +class Replicated: + @property + def replicas(self) -> int: ... + def __init__(self, __replicas: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Replicated) -> bool: ... + +class ShardingSpec: + def __init__(self, + sharding: Iterable[_AvalDimSharding], + mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ... + @property + def sharding(self) -> Tuple[_AvalDimSharding, ...]: ... + @property + def mesh_mapping(self) -> Tuple[_MeshDimAssignment]: ... + def __eq__(self, __other: ShardingSpec) -> bool: ... + def __hash__(self) -> int: ... + + _HAS_DYNAMIC_ATTRIBUTES = True + +class PmapFunction: + def __call__(self, *args, **kwargs) -> Any: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + __signature__: inspect.Signature + def _cache_size(self) -> int: ... + def _cache_clear(self) -> None: ... + def _debug_cache_keys(self) -> str: ... + +def pmap(fun: Callable[..., Any], + cache_miss: Callable[..., Any], + static_argnums: Sequence[int], + shard_arg_fallback: Callable[..., Any], + pytree_registry: pytree.PyTreeRegistry) -> PmapFunction: ... 
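Editor's note: the sharding helper classes above are what `jax.pmap` uses to describe how each array axis maps onto devices. The sketch constructs a `ShardingSpec` for an `(8, 4)` array whose first axis is split into two chunks over a two-device mesh; the import path is assumed for this layout, and `pmap` builds these internally in normal use.

```python
from jaxlib.xla.xla_extension import pmap_lib  # assumed import path for this layout

spec = pmap_lib.ShardingSpec(
    sharding=(pmap_lib.Chunked([2]), pmap_lib.NoSharding()),  # split axis 0 in 2, leave axis 1 whole
    mesh_mapping=(pmap_lib.ShardedAxis(0),),                  # chunked axis maps to mesh axis 0
)
assert spec.sharding[0] == pmap_lib.Chunked([2])
```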
diff --git a/jaxlib/xla/xla_extension/profiler.pyi b/jaxlib/xla/xla_extension/profiler.pyi new file mode 100644 index 000000000000..95749f61978a --- /dev/null +++ b/jaxlib/xla/xla_extension/profiler.pyi @@ -0,0 +1,59 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from types import TracebackType +from typing import Any, Optional, Type, Union, List, Tuple + +_Status = Any + +class ProfilerServer: ... +def start_server(port: int) -> ProfilerServer: ... + +def register_plugin_profiler(c_api: Any) -> None: ... + +def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ... +def get_instructins_profile(tensorboard_dir: str) -> List[Tuple[str, float]]: ... +def get_fdo_profile( + xspace: bytes, as_textproto: bool = ... +) -> Union[bytes, str]: ... + +class ProfilerSession: + def __init__(self, options: Optional[ProfileOptions] = ...) -> None: ... + def stop(self) -> bytes: ... + def export(self, xspace: bytes, tensorboard_dir: str) -> _Status:... + +class ProfileOptions: + include_dataset_ops: bool + host_tracer_level: int + python_tracer_level: int + enable_hlo_proto: bool + start_timestamp_ns: int + duration_ms: int + repository_path: str + raise_error_on_start_failure: bool + +def aggregate_profiled_instructions(profiles: List[bytes], percentile: int) -> str: ... + +class TraceMe: + def __init__(self, name: str, **kwargs: Any) -> None: ... + def __enter__(self) -> TraceMe: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]:... + def set_metadata(self, **kwargs): ... + @staticmethod + def is_enabled() -> bool: ... diff --git a/jaxlib/xla/xla_extension/pytree.pyi b/jaxlib/xla/xla_extension/pytree.pyi new file mode 100644 index 000000000000..bfbad5de89d5 --- /dev/null +++ b/jaxlib/xla/xla_extension/pytree.pyi @@ -0,0 +1,158 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import ( + Any, + Callable, + Hashable, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, +) + +_T = TypeVar("_T") + +version: int + +class PyTreeRegistry: + def __init__( + self, + *, + enable_none: bool = ..., + enable_tuple: bool = ..., + enable_namedtuple: bool = ..., + enable_list: bool = ..., + enable_dict: bool = ... + ): ... 
+ def flatten( + self, + tree: Any, + leaf_predicate: Optional[Callable[[Any], bool]] = ..., + ) -> Tuple[List[Any], PyTreeDef]: ... + def flatten_one_level( + self, tree: Any + ) -> Optional[Tuple[Iterable[Any], Any]]: ... + def flatten_one_level_with_keys( + self, tree: Any + ) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ... + def flatten_with_path( + self, + tree: Any, + leaf_predicate: Optional[Callable[[Any], bool]] = ..., + ) -> Tuple[List[Tuple[_KeyPath, Any]], PyTreeDef]: ... + def register_node( + self, + __type: Type[_T], + to_iterable: Callable[[_T], Tuple[_Children, _AuxData]], + from_iterable: Callable[[_AuxData, _Children], _T], + to_iterable_with_keys: ( + Callable[[_T], Tuple[_KeyLeafPairs, _AuxData]] | None + ) = ..., + ) -> Any: ... + def register_dataclass_node( + self, __type: Type[_T], meta_fields: List[str], data_fields: List[str] + ) -> Any: ... + +def default_registry() -> PyTreeRegistry: ... +def tuple(registry: PyTreeRegistry, arg0: Sequence[PyTreeDef]) -> PyTreeDef: ... +def all_leaves(registry: PyTreeRegistry, arg0: Iterable[Any]) -> bool: ... + +class SequenceKey(Hashable): + idx: int + __match_args__: tuple = ... + def __init__(self, idx: int): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class DictKey(Hashable): + key: Hashable + __match_args__: tuple = ... + def __init__(self, key: Hashable): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class GetAttrKey(Hashable): + name: str + __match_args__: tuple = ... + def __init__(self, name: str): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class FlattenedIndexKey(Hashable): + key: int + __match_args__: tuple = ... + def __init__(self, key: int): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class PyTreeDef: + def unflatten(self, __leaves: Iterable[Any]) -> Any: ... + def flatten_up_to(self, __xs: Any) -> List[Any]: ... + def compose(self, __inner: PyTreeDef) -> PyTreeDef: ... + def walk( + self, + __f_node: Callable[[Any, Any], Any], + __f_leaf: Optional[Callable[[_T], Any]], + leaves: Iterable[Any], + ) -> Any: ... + def from_iterable_tree(self, __xs: Any): ... + def node_data(self) -> Optional[Tuple[Type, Any]]: ... + def children(self) -> List[PyTreeDef]: ... + @staticmethod + def make_from_node_data_and_children( + registry: PyTreeRegistry, + node_data: Optional[Tuple[Type, Any]], + children: Iterable[PyTreeDef], + ) -> PyTreeDef: ... + + num_leaves: int + num_nodes: int + def __repr__(self) -> str: ... + def __eq__(self, __other: PyTreeDef) -> bool: ... + def __ne__(self, __other: PyTreeDef) -> bool: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def serialize_using_proto(self) -> bytes: ... + @staticmethod + def deserialize_using_proto( + registry: PyTreeRegistry, data: bytes + ) -> PyTreeDef: ... 
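Editor's note: a sketch of the flatten/unflatten contract captured by `PyTreeRegistry` and `PyTreeDef` above. Dict leaves come back in sorted-key order; in JAX this registry is reached through `jax.tree_util` rather than imported directly, and the import path is assumed for this layout.

```python
from jaxlib.xla.xla_extension import pytree  # assumed import path for this layout

registry = pytree.default_registry()
leaves, treedef = registry.flatten({"b": (2, 3), "a": 1})
assert leaves == [1, 2, 3]          # dict keys are visited in sorted order
assert treedef.num_leaves == 3
assert treedef.unflatten(leaves) == {"a": 1, "b": (2, 3)}
```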
+ +_Children = TypeVar("_Children", bound=Iterable[Any]) +_KeyLeafPair = TypeVar("_KeyLeafPair", bound=Tuple[Any, Any]) +_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[_KeyLeafPair]) +_KeyPath = TypeVar("_KeyPath", bound=Tuple[Any, ...]) +_AuxData = TypeVar("_AuxData", bound=Hashable) diff --git a/jaxlib/xla/xla_extension/sdy.pyi b/jaxlib/xla/xla_extension/sdy.pyi new file mode 100644 index 000000000000..34714e5c0219 --- /dev/null +++ b/jaxlib/xla/xla_extension/sdy.pyi @@ -0,0 +1,32 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from mlir import ir + +def sdy_round_trip_export_pipeline( + module: ir.module +) -> str: ... + +def sdy_round_trip_import_shardings( + module: ir.module +) -> str: ... + +def get_mesh( + module: ir.module +) -> tuple[tuple[str, int], ...]: ... + +def lowered_with_shardy( + module: ir.module +) -> bool: ... diff --git a/jaxlib/xla/xla_extension/transfer_guard_lib.pyi b/jaxlib/xla/xla_extension/transfer_guard_lib.pyi new file mode 100644 index 000000000000..091e1e10a742 --- /dev/null +++ b/jaxlib/xla/xla_extension/transfer_guard_lib.pyi @@ -0,0 +1,39 @@ +# Copyright 2022 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, List, Optional + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class TransferGuardState: + host_to_device: Optional[TransferGuardLevel] + device_to_device: Optional[TransferGuardLevel] + device_to_host: Optional[TransferGuardLevel] + + explicit_device_put: bool + explicit_device_get: bool + +def global_state() -> TransferGuardState: ... +def thread_local_state() -> TransferGuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> List[str]: ... diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py new file mode 100644 index 000000000000..01b01ecf704e --- /dev/null +++ b/jaxlib/xla_client.py @@ -0,0 +1,18 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxlib.xla.xla_client import * # noqa: F403 +from jaxlib.xla.xla_client import _version # noqa: F401 +from jaxlib.xla.xla_client import _xla # noqa: F401 diff --git a/jaxlib/xla_extension.py b/jaxlib/xla_extension.py new file mode 100644 index 000000000000..e4fc7e96a1ab --- /dev/null +++ b/jaxlib/xla_extension.py @@ -0,0 +1,17 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxlib.xla.xla_extension import * # noqa: F403 +from jaxlib.xla.xla_extension import sdy # noqa: F401 diff --git a/pyproject.toml b/pyproject.toml index a1b9e7dd446a..be29e16beb9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,11 @@ module = [ "jax.experimental.jax2tf.tests.back_compat_testdata", "jax.experimental.jax2tf.tests.flax_models", "jax_cuda12_plugin.*", - "jaxlib.*", + "jaxlib.cpu_feature_guard", + "jaxlib.cuda.*", "jaxlib.mlir.*", + "jaxlib.utils", + "jaxlib.xla_extension.utils", "jraph.*", "libtpu.*", "matplotlib.*", diff --git a/setup.py b/setup.py index 80f45285ba61..823354adb70d 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ project_name = 'jax' -_current_jaxlib_version = '0.5.1' +_current_jaxlib_version = '0.6.0' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.5.1' +_latest_jaxlib_version_on_pypi = '0.6.0' -_libtpu_version = '0.0.10.*' +_libtpu_version = '0.0.13.*' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( @@ -38,6 +38,13 @@ def load_version_module(pkg_path): _cmdclass = _version_module._get_cmdclass(project_name) _minimum_jaxlib_version = _version_module._minimum_jaxlib_version +# If this is a pre-release ("rc" wheels), append "rc0" to +# _minimum_jaxlib_version and _current_jaxlib_version so that we are able to +# install the rc wheels. 
+if _version_module._is_prerelease(): + _minimum_jaxlib_version += "rc0" + _current_jaxlib_version += "rc0" + with open('README.md', encoding='utf-8') as f: _long_description = f.read() @@ -55,7 +62,7 @@ def load_version_module(pkg_path): python_requires='>=3.10', install_requires=[ f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', - 'ml_dtypes>=0.4.0', + 'ml_dtypes>=0.5.0', 'numpy>=1.25', "numpy>=1.26.0; python_version>='3.12'", 'opt_einsum', @@ -81,32 +88,25 @@ def load_version_module(pkg_path): ], 'cuda': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin[with-cuda]>={_current_jaxlib_version},<={_jax_version}", ], 'cuda12': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", - ], - - # Deprecated alias for cuda12, kept to avoid breaking users who wrote - # cuda12_pip in their CI. - 'cuda12_pip': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin[with-cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Target that does not depend on the CUDA pip wheels, for those who want # to use a preinstalled CUDA. - 'cuda12_local': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin=={_current_jaxlib_version}", + 'cuda12-local': [ + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin>={_current_jaxlib_version},<={_jax_version}", ], # ROCm support for ROCm 6.0 and above. 'rocm': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-rocm60-plugin>={_current_jaxlib_version},<={_jax_version}", ], diff --git a/tests/BUILD b/tests/BUILD index 0ffa68ed8eb3..c46ea7556a44 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -34,7 +34,7 @@ jax_generate_backend_suites() jax_multiplatform_test( name = "api_test", srcs = ["api_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], shard_count = 10, deps = [ "//jax:experimental", @@ -44,7 +44,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "debug_info_test", srcs = ["debug_info_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], deps = [ "//jax:experimental", "//jax:pallas", @@ -60,10 +60,13 @@ jax_multiplatform_test( srcs = ["device_test.py"], ) -jax_multiplatform_test( +jax_py_test( name = "dynamic_api_test", srcs = ["dynamic_api_test.py"], - shard_count = 2, + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -80,6 +83,15 @@ jax_py_test( ] + py_deps("absl/testing"), ) +jax_py_test( + name = "array_extensibility_test", + srcs = ["array_extensibility_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), +) + jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], @@ -88,11 +100,8 @@ jax_multiplatform_test( "gpu", ], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", ], - env = { - "PYTHONWARNINGS": "default", # TODO(b/394123878): protobuf, via TensorFlow, issues a Python warning under Python 3.12+ sometimes. 
- }, tags = ["multiaccelerator"], deps = py_deps("tensorflow_core"), ) @@ -128,9 +137,20 @@ jax_multiplatform_test( srcs = ["debug_nans_test.py"], ) +jax_py_test( + name = "distributed_initialize_test", + srcs = ["distributed_initialize_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("portpicker"), +) + jax_multiplatform_test( name = "distributed_test", srcs = ["distributed_test.py"], + enable_backends = ["gpu"], + deps = py_deps("portpicker"), ) jax_py_test( @@ -170,7 +190,7 @@ jax_multiplatform_test( name = "ffi_test", srcs = ["ffi_test.py"], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", ], # TODO(dfm): Remove after removal of jex.ffi imports. deps = ["//jax:extend"], @@ -186,7 +206,7 @@ jax_multiplatform_test( ], # Times out on TPU with asan/tsan. }, shard_count = { - "tpu": 20, + "tpu": 10, "cpu": 20, "gpu": 10, }, @@ -219,11 +239,12 @@ jax_multiplatform_test( jax_multiplatform_test( name = "lobpcg_test", srcs = ["lobpcg_test.py"], - env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, + # Set LOBPCG_EMIT_DEBUG_PLOTS=1 to debug + # checkLobpcgMonotonicity and checkApproxEigs tests + # using matplotlib plots + # env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, shard_count = { - "cpu": 48, - "gpu": 48, - "tpu": 48, + "cpu": 8, }, deps = [ "//jax:experimental_sparse", @@ -234,9 +255,9 @@ jax_multiplatform_test( name = "svd_test", srcs = ["svd_test.py"], shard_count = { - "cpu": 10, + "cpu": 20, "gpu": 10, - "tpu": 40, + "tpu": 15, }, ) @@ -254,17 +275,14 @@ jax_multiplatform_test( srcs = ["memories_test.py"], enable_configs = [ "cpu", - "gpu_p100x2", - "tpu_v3_2x2", - "tpu_v4_2x2", - "tpu_v5p_2x2", - "tpu_v5e_4x2", + "gpu_h100x2", + "tpu_v3_x4", + "tpu_v4_x4", + "tpu_v5p_x4", + "tpu_v5e_x8", "gpu_p100x2_shardy", - "tpu_v5e_4x2_shardy", + "tpu_v5e_x8_shardy", ], - shard_count = { - "tpu": 5, - }, deps = [ "//jax:experimental", ], @@ -279,14 +297,13 @@ jax_multiplatform_test( }, enable_configs = [ "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", - "tpu_v3_2x2", - "gpu_p100x2", + "tpu_v3_x4_shardy", + "tpu_v3_x4", + "gpu_h100x2", ], shard_count = { - "cpu": 5, - "gpu": 5, - "tpu": 5, + "cpu": 3, + "tpu": 4, }, tags = ["multiaccelerator"], deps = [ @@ -301,7 +318,8 @@ jax_multiplatform_test( "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, enable_configs = [ - "tpu_v3_2x2_shardy", + "tpu_v3_x4_shardy", + "tpu_v3_x4", ], tags = ["multiaccelerator"], deps = [ @@ -313,10 +331,10 @@ jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], enable_configs = [ - "tpu_v3_2x2", - "tpu_v5e_4x2", - "tpu_v4_2x2", - "tpu_v3_2x2_shardy", + "tpu_v3_x4", + "tpu_v5e_x8", + "tpu_v4_x4", + "tpu_v3_x4_shardy", ], deps = [ "//jax:experimental", @@ -376,7 +394,7 @@ jax_multiplatform_test( "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, enable_configs = [ - "tpu_v3_2x2", + "tpu_v3_x4", ], tags = ["multiaccelerator"], deps = [ @@ -399,8 +417,8 @@ jax_multiplatform_test( srcs = ["image_test.py"], shard_count = { "cpu": 10, - "gpu": 20, - "tpu": 10, + "gpu": 10, + "tpu": 8, }, tags = ["noasan"], # Linking TF causes a linker OOM. 
deps = py_deps("pil") + py_deps("tensorflow_core"), @@ -409,8 +427,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], - deps = [ - ], ) jax_multiplatform_test( @@ -444,7 +460,7 @@ jax_multiplatform_test( srcs = ["jet_test.py"], shard_count = { "cpu": 10, - "gpu": 10, + "gpu": 4, }, deps = [ "//jax:jet", @@ -457,8 +473,8 @@ jax_multiplatform_test( srcs = ["lax_control_flow_test.py"], shard_count = { "cpu": 30, - "gpu": 40, - "tpu": 30, + "gpu": 30, + "tpu": 20, }, ) @@ -476,6 +492,7 @@ jax_multiplatform_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], backend_tags = { + "tpu": ["notsan"], # Test times out. "cpu": ["notsan"], # Test times out. }, shard_count = { @@ -522,11 +539,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "lax_numpy_einsum_test", srcs = ["lax_numpy_einsum_test.py"], - shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, - }, ) jax_multiplatform_test( @@ -534,8 +546,8 @@ jax_multiplatform_test( srcs = ["lax_numpy_ufuncs_test.py"], shard_count = { "cpu": 10, - "gpu": 10, - "tpu": 10, + "gpu": 5, + "tpu": 5, }, ) @@ -548,9 +560,9 @@ jax_multiplatform_test( name = "lax_scipy_test", srcs = ["lax_scipy_test.py"], shard_count = { - "cpu": 20, + "cpu": 30, "gpu": 20, - "tpu": 20, + "tpu": 8, }, deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) @@ -563,8 +575,8 @@ jax_multiplatform_test( }, shard_count = { "cpu": 10, - "gpu": 10, - "tpu": 10, + "gpu": 5, + "tpu": 5, }, ) @@ -573,11 +585,14 @@ jax_multiplatform_test( srcs = ["lax_scipy_special_functions_test.py"], backend_tags = { "gpu": ["noasan"], # Times out. - "cpu": ["noasan"], # Times out. + "cpu": [ + "noasan", + "notsan", + ], # Times out. }, shard_count = { "cpu": 20, - "gpu": 20, + "gpu": 30, "tpu": 20, }, deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), @@ -587,9 +602,9 @@ jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, + "cpu": 20, + "gpu": 8, + "tpu": 8, }, deps = [ "//jax:internal_test_util", @@ -630,7 +645,7 @@ jax_multiplatform_test( srcs = ["lax_autodiff_test.py"], shard_count = { "cpu": 40, - "gpu": 40, + "gpu": 30, "tpu": 20, }, ) @@ -705,7 +720,7 @@ jax_multiplatform_test( "cpu", ], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", "gpu_p100x2_shardy", "gpu_p100x2_pjrt_c_api", ], @@ -745,8 +760,8 @@ jax_multiplatform_test( name = "multibackend_test", srcs = ["multibackend_test.py"], enable_configs = [ - "tpu_v3_2x2", - "gpu_p100x2", + "tpu_v3_x4", + "gpu_h100x2", ], ) @@ -799,11 +814,11 @@ jax_multiplatform_test( }, enable_configs = [ "gpu_v100", - "tpu_v3_2x2", + "tpu_v3_x4", ], shard_count = { "cpu": 30, - "gpu": 30, + "gpu": 10, "tpu": 30, }, tags = ["multiaccelerator"], @@ -818,7 +833,7 @@ jax_multiplatform_test( # No implementation of nonsymmetric Eigendecomposition. enable_backends = ["cpu"], shard_count = { - "cpu": 10, + "cpu": 5, }, # This test ends up calling Fortran code that initializes some memory and # passes it to C code. 
MSan is not able to detect that the memory was @@ -879,29 +894,12 @@ jax_multiplatform_test( "notsan", # Times out ], }, - shard_count = 10, + shard_count = 8, ) jax_multiplatform_test( name = "random_test", srcs = ["random_test.py"], - backend_tags = { - "cpu": [ - "notsan", # Times out - "nomsan", # Times out - ], - "tpu": [ - "optonly", - "nomsan", # Times out - "notsan", # Times out - ], - }, - shard_count = { - "cpu": 30, - "gpu": 30, - "tpu": 40, - }, - tags = ["noasan"], # Times out ) jax_multiplatform_test( @@ -923,7 +921,7 @@ jax_multiplatform_test( }, shard_count = { "cpu": 40, - "gpu": 40, + "gpu": 50, "tpu": 40, }, tags = ["noasan"], # Times out @@ -934,25 +932,7 @@ jax_multiplatform_test( name = "random_test_with_custom_prng", srcs = ["random_test.py"], args = ["--jax_enable_custom_prng=true"], - backend_tags = { - "cpu": [ - "noasan", # Times out under asan/msan/tsan. - "nomsan", - "notsan", - ], - "tpu": [ - "noasan", # Times out under asan/msan/tsan. - "nomsan", - "notsan", - "optonly", - ], - }, main = "random_test.py", - shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, - }, ) jax_multiplatform_test( @@ -1021,9 +1001,9 @@ jax_multiplatform_test( "tpu": ["nomsan"], # Times out }, shard_count = { - "cpu": 40, - "gpu": 30, - "tpu": 40, + "cpu": 50, + "gpu": 50, + "tpu": 50, }, tags = [ "noasan", @@ -1050,8 +1030,8 @@ jax_multiplatform_test( }, shard_count = { "cpu": 50, - "gpu": 50, - "tpu": 50, + "gpu": 30, + "tpu": 20, }, tags = [ "noasan", @@ -1101,20 +1081,6 @@ jax_multiplatform_test( ] + py_deps("scipy"), ) -jax_multiplatform_test( - name = "sparse_nm_test", - srcs = ["sparse_nm_test.py"], - enable_backends = [], - enable_configs = [ - "gpu_a100", - "gpu_h100", - ], - deps = [ - "//jax:experimental_sparse", - "//jax:pallas_gpu", - ], -) - jax_multiplatform_test( name = "sparsify_test", srcs = ["sparsify_test.py"], @@ -1122,10 +1088,11 @@ jax_multiplatform_test( backend_tags = { "cpu": [ "noasan", # Times out under asan - "notsan", # Times out under asan + "notsan", # Times out under tsan ], "tpu": [ - "noasan", # Times out under asan. 
+ "noasan", # Times out under asan + "notsan", # Times out under tsan ], }, shard_count = { @@ -1147,7 +1114,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "checkify_test", srcs = ["checkify_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], shard_count = { "gpu": 2, "tpu": 4, @@ -1159,13 +1126,14 @@ jax_multiplatform_test( srcs = ["error_check_test.py"], ) +jax_multiplatform_test( + name = "jax_numpy_error_test", + srcs = ["jax_numpy_error_test.py"], +) + jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], - shard_count = { - "cpu": 5, - "gpu": 5, - }, deps = ["//jax:stax"], ) @@ -1294,7 +1262,11 @@ jax_multiplatform_test( jax_multiplatform_test( name = "ann_test", srcs = ["ann_test.py"], - shard_count = 10, + shard_count = { + "cpu": 5, + "gpu": 5, + "tpu": 10, + }, ) jax_py_test( @@ -1317,9 +1289,13 @@ jax_multiplatform_test( srcs = ["garbage_collection_guard_test.py"], ) -jax_multiplatform_test( +jax_py_test( name = "name_stack_test", srcs = ["name_stack_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1331,9 +1307,9 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", ], tags = ["multiaccelerator"], ) @@ -1344,11 +1320,11 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", "gpu_a100_shardy", - "tpu_v3_2x2_shardy", + "tpu_v3_x4_shardy", ], ) @@ -1359,10 +1335,10 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", - "tpu_v3_2x2_shardy", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", + "tpu_v3_x4_shardy", "gpu_p100x2_shardy", ], tags = ["multiaccelerator"], @@ -1380,9 +1356,9 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", ], ) @@ -1417,8 +1393,6 @@ jax_multiplatform_test( name = "for_loop_test", srcs = ["for_loop_test.py"], shard_count = { - "cpu": 20, - "gpu": 10, "tpu": 20, }, ) @@ -1436,10 +1410,6 @@ jax_multiplatform_test( enable_configs = [ "gpu_p100x2_shardy", ], - shard_count = { - "gpu": 10, - "tpu": 10, - }, tags = [ "multiaccelerator", ], @@ -1453,7 +1423,7 @@ jax_multiplatform_test( srcs = ["shard_map_test.py"], enable_configs = [ "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", + "tpu_v3_x4_shardy", ], shard_count = { "cpu": 50, @@ -1543,14 +1513,11 @@ jax_multiplatform_test( jax_multiplatform_test( name = "export_test", srcs = ["export_test.py"], - disable_configs = [ - "cpu_shardy", # TODO(b/355263220): enable once export is supported. 
- ], enable_configs = [ "cpu_shardy", "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", - "tpu_v3_2x2", + "tpu_v3_x4_shardy", + "tpu_v3_x4", ], tags = [], ) @@ -1566,9 +1533,9 @@ jax_multiplatform_test( "cpu_x32", ], shard_count = { - "cpu": 4, - "gpu": 6, - "tpu": 4, + "cpu": 30, + "gpu": 20, + "tpu": 25, }, tags = [ "noasan", # Times out @@ -1616,9 +1583,6 @@ jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], enable_backends = ["gpu"], - shard_count = { - "gpu": 4, - }, tags = ["multiaccelerator"], ) @@ -1628,6 +1592,20 @@ jax_multiplatform_test( deps = ["//jax:experimental"], ) +jax_multiplatform_test( + name = "unary_ops_accuracy_test", + srcs = ["unary_ops_accuracy_test.py"], + disable_configs = [ + "tpu_pjrt_c_api", + ], + enable_backends = [ + "tpu", + ], + deps = [ + "//jax:experimental", + ], +) + jax_py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], diff --git a/tests/api_test.py b/tests/api_test.py index aece7b19fdfb..2d1055516074 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -52,6 +52,7 @@ from jax._src import config from jax._src import core from jax._src import custom_derivatives +from jax._src import deprecations from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import xla_bridge @@ -3120,7 +3121,6 @@ def test_error_for_invalid_dtype(self): def test_vmap_preserves_docstr(self): def superfun(a): """Does things with stuff.""" - pass self.assertRegex(api.vmap(superfun).__doc__, "\n".join([ "Vectorized version of superfun.*", @@ -4313,6 +4313,21 @@ def g(x, y): for i in range(3): # Loop verifies we exercise both Python and C++ dispatch self.assertEqual(2 * i, g(2, i), msg=i) + def test_make_jaxpr_static_argnums_order(self): + # https://github.com/jax-ml/jax/issues/28065 + def f(a, b, c): + x = a + c + y = b * c + z = x - y + return z + + for static_argnums in [(1, 0), (0, 1)]: + val = jax.jit(f, static_argnums=static_argnums)(1, 2, 3) + self.assertEqual(val, -2) + jaxpr = jax.make_jaxpr(f, static_argnums=static_argnums)(1, 2, 3) + self.assertEqual(jaxpr.eqns[0].invars[0].val, 1) + self.assertEqual(jaxpr.eqns[1].invars[0].val, 2) + def test_fastpath_cache_confusion(self): # https://github.com/jax-ml/jax/issues/12542 @jax.jit @@ -4424,6 +4439,7 @@ def test_grad_conj_symbolic_zeros(self): out = jax.grad(f)(3.0) # doesn't crash self.assertAllClose(out, 1., check_dtypes=False) + @jtu.thread_unsafe_test() def test_cache_clear_pmap(self): @jax.pmap def f(i): @@ -4466,64 +4482,214 @@ def add(x): self.assertEqual(tracing_add_count, 2) @jtu.thread_unsafe_test() # logging is not thread-safe - def test_cache_miss_explanations(self): - @jax.jit - def f(x, y): - return jnp.sin(x) * y['hi'] + def test_cache_miss_explanations_skip_internals(self): + if is_persistent_cache_enabled(): + self.skipTest('With persistent cache, we see the cache misses') + with config.explain_cache_misses(True): + with self.assertNoLogs(level='WARNING'): + for i in range(2): + jnp.sin(jnp.arange(i + 1, dtype=np.float32)) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_first_miss(self): + @jax.jit + def f(x): return x x = jnp.float32(1.) 
- y = {'hi': jnp.arange(3., dtype='float32')} expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - # print on first miss, not on hit + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(x) + f(x) + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("TRACING CACHE MISS", msg) + self.assertIn("never seen function", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_in_tree(self): + @jax.jit + def f(*args, **kwargs): return args[0] + + f(0., 1., y=(2., 2.1)) + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + # Same number of leaves but different trees + f(0., (1., 1.1), y=2.) + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different input pytree", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_arg_passed_as_kwarg(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + + f(0., 1.) + + # kwarg change + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(0., y=1.) + + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different number of args and kwargs, but same total number", msg) + self.assertIn("now 1 args and kwargs with keys ['y']", msg) + self.assertIn("before 1 args and kwargs with keys []", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_static_argnums(self): + @partial(jax.jit, static_argnums=(0, 2)) + def f(x, y, z): + return y + + f(1., 2., "foo") + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(1., 2., "bar") + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different value of static args", msg) + self.assertIn("now 1.0, 'bar' and before 1.0, 'foo'", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_static_argnames(self): + @partial(jax.jit, static_argnames="foo") + def f(*, foo): + return 1 + + f(foo="foo") + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(foo="bar") + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different value of static kwargs", msg) + self.assertIn("now {foo: 'bar'} and before {foo: 'foo'}", msg) + self.assertNotIn('explanation unavailable!', msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_dtype(self): + @jax.jit + def f(x, y): return x + f(np.float32(0), np.float32(1)) + with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(x, y) - f(x, y) + f(np.float32(0), np.int32(1)) + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different input types", msg) + self.assertIn("at y, now i32[] and before f32[]", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_weak_type(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + + y = jnp.arange(4, dtype="float32") + f(jnp.float32(0.), y) + # weak type change (assuming no x64) + if config.enable_x64.value: + self.skipTest("Work only for 32 bit mode") + with 
config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(0., y) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('TRACING CACHE MISS', msg) - self.assertIn('never seen function', msg) + self.assertIn("different input types", msg) + self.assertIn("at x, now f32[]{weak_type=True} and before f32[]{weak_type=False}", msg) + self.assertIn("https://docs.jax.dev/en/latest/type_promotion.html#weak-types", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_shape(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + f(np.float32(0), np.arange(1, dtype=np.float32)) - # shape change - y_ = {'hi': jnp.arange(4, dtype='float32')} with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(x, y_) + f(np.float32(0), np.arange(2, dtype=np.float32)) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn('seen f32[3], but now given f32[4]', msg) + self.assertIn("different input types", msg) + self.assertIn("at y, now f32[2] and before f32[1]", msg) + self.assertNotIn("explanation unavailable!", msg) - # weak type change (assuming no x64) - if not config.enable_x64.value: - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1., y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('weak_type=True', msg) - self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg) + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_shape_explain_closest(self): + @jax.jit + def f(x): return x + f(np.ones((1, 2), dtype=np.float32)) + f(np.ones((10, 20, 30), dtype=np.float32)) + f(np.ones((1, 2, 3), dtype=np.float32)) - # kwarg change with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(1, y=y) + f(np.ones((10, 2, 30), dtype=np.float32)) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) + self.assertIn("key with different input types", msg) + self.assertIn("at x, now f32[10,2,30] and before f32[10,20,30]", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_tracing_config(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + f(0., 1.) # tracing config change with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - with jax.numpy_rank_promotion('warn'): - f(x, y) - # depending on the backend, we may or may not get persistent cache warnings + with self.assertLogs(level="WARNING") as cm: + with jax.numpy_rank_promotion("warn"): + with jax.default_matmul_precision("high"): + f(0., 1.) 
+ + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertTrue(1 <= len(cm.output) <= expected_log_len) msg = cm.output[0] - self.assertIn("tracing context doesn't match", msg) + self.assertIn("key with different tracing context", msg) + self.assertIn("now warn and before", msg) + self.assertIn("now high and before", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_multiple_changes(self): + @jax.jit + def f(x): return jnp.sin(x) + + call_1 = f(np.arange(4, dtype=np.float32)) + with jax.numpy_rank_promotion("warn"): + call_2 = f(np.arange(8, dtype=np.float32)) + + with config.explain_cache_misses(True): + with self.assertLogs(level='WARNING') as cm: + # Matches call_2 in shape but not context, and call_1 in context but + # not in shape. + f(np.arange(8, dtype=np.float32)) + + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("key with different input types", msg) + self.assertIn("at x, now f32[8] and before f32[4]", msg) + self.assertIn("key with different tracing context", msg) + self.assertNotIn("explanation unavailable!", msg) @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_new_function_in_loop(self): @@ -4547,28 +4713,6 @@ def f(x, y): _, msg = cm.output self.assertIn('another function defined on the same line', msg) - @jtu.thread_unsafe_test() # logging is not thread-safe - def test_cache_miss_explanations_unpacks_transforms(self): - # Tests that the explain_tracing_cache_miss() function does not throw an - # error when unpacking `transforms` with a length greater than 3. - @jax.jit - def f(key): - return jax.random.truncated_normal(key, 1, 1, dtype=jax.numpy.float32) - - with config.explain_cache_misses(True): - with self.assertLogs(level="WARNING") as cm: - f(jax.random.key(seed=123)) - - if is_persistent_cache_enabled(): - # 5 warnings from tracing cache, 5-10 from persistent cache depending on - # the backend - self.assertTrue(10 <= len(cm.output) <= 15) - self.assertTrue(any("TRACING CACHE MISS" in msg for msg in cm.output)) - else: - self.assertLen(cm.output, 5) - for msg in cm.output: - self.assertIn("TRACING CACHE MISS", msg) - def test_cache_miss_explanations_no_source_info(self): # ``operator.add`` is a built-in function and does not have source info. 
with config.explain_cache_misses(True): @@ -4687,6 +4831,8 @@ def f(inputs): @jtu.run_on_devices("cpu") def test_inner_jit_forwarding_happens(self): + if not config.dynamic_shapes.value: + self.skipTest("Only works for dynamic shapes") jaxpr = jax.make_jaxpr(lambda: jax.jit(lambda x: x)(3))() self.assertLen(jaxpr.jaxpr.outvars, 1) self.assertIsInstance(jaxpr.jaxpr.outvars[0], core.Literal) @@ -4695,6 +4841,8 @@ def test_inner_jit_forwarding_happens(self): @parameterized.parameters(range(8)) @jtu.run_on_devices("cpu") def test_inner_jit_forwarding_correctness(self, num_input_fwd): + if not config.dynamic_shapes.value: + self.skipTest("Only works for dynamic shapes") num_args = 8 rng = np.random.RandomState(0) @@ -4776,7 +4924,7 @@ def sin_of_sin(x): def test_deferred_primal_with_direct_linearize(self): def my_sin_lin(nzs, x): nz, = nzs - return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) + return (my_sin_p.bind(x, accuracy=None), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) my_sin_p = core.Primitive("my_sin_p") my_sin_p.def_impl(lax.sin) @@ -4823,8 +4971,8 @@ def f(x): sin_impl = lax.sin_p.impl cos_impl = lax.cos_p.impl try: - lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x)) - lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: sin_calls.append(1) or sin_impl(x, **kwargs)) + lax.cos_p.def_impl(lambda x, **kwargs: cos_calls.append(1) or cos_impl(x, **kwargs)) f_lin(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -5019,7 +5167,7 @@ def g(x): # Make sure that introducing constants in vmap works. constant_introducing_p = core.Primitive('introduce_constant') - constant_introducing_p.def_abstract_eval(core.raise_to_shaped) + constant_introducing_p.def_abstract_eval(lambda x: x) def _constant_introducing_batcher(xs, ds): (x,), (d,) = xs, ds return (x + np.arange(x.size, dtype=x.dtype).reshape(x.shape)), d @@ -5117,7 +5265,7 @@ def f(x, y): called = [] sin_impl = lax.sin_p.impl try: - lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: called.append(1) or sin_impl(x, **kwargs)) api.grad(g)(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -5901,6 +6049,7 @@ def test_remat_of_scan(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) + print("debug jaxpr: ", str(jaxpr)) self.assertIn(' sin ', str(jaxpr)) self.assertIn(' cos ', str(jaxpr)) @@ -8188,6 +8337,23 @@ def f_jvp(primals, tangents): ): f(0.5, 0.1, z=1.0) + def test_symbolic_zero_custom_jvp_vmap_doesnt_instantiate(self): + @jax.custom_jvp + def f(x, y): + return y + + def f_jvp(primals, tangents): + (x, y), (x_dot, y_dot) = primals, tangents + assert type(y_dot) is custom_derivatives_public.SymbolicZero + return y, y_dot + + f.defjvp(f_jvp, symbolic_zeros=True) + + def g(x): + return f(x, f(x, 1.)) + + jax.jvp(jax.vmap(g), (jnp.ones(3),), (jnp.ones(3),)) # don't crash + class CustomVJPTest(jtu.JaxTestCase): @@ -10032,6 +10198,38 @@ def tp(r, t): return t / r self.assertAllClose(transpose_unary(f1, x)(x), jax.jit(transpose_unary(f1, x))(x)) + def test_linear_call_type_mismatch(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return None + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. 
+ f1 = lambda x: f(x, y) + with self.assertRaisesRegex(TypeError, "transpose output pytree"): + transpose_unary(f1, x)(x) + + def test_linear_call_recursion(self): + def f(x): + def fn(_, x): return x + def tp(_, t): return f(t) + return jax.custom_derivatives.linear_call(fn, tp, None, x) + jax.jit(f)(0.1) + + def test_linear_call_grad(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.array(6.) + y = jnp.array(3.) + self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_ref)(x, y)) + def test_basic(self): def f(x, y): @custom_transpose(jnp.ones(2)) @@ -10334,6 +10532,19 @@ def tp(r, t): return 2 * fn(r, t) self.assertAllClose(f_(x), g_(x)) self.assertAllClose(f_t(x), g_t(x)) + def test_jit_signature_deprecation(self): + fun = lambda x: x + if deprecations.is_accelerated('jax-jit-positional-args'): + with self.assertRaisesRegex(TypeError, r'jit\(\) got some positional-only arguments passed as keyword arguments.*'): + jax.jit(fun=fun) + with self.assertRaisesRegex(TypeError, r'jit\(\) takes 1 positional argument but 2 were given.*'): + jax.jit(fun, None) + else: + with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing fun by keyword is deprecated.*'): + jax.jit(fun=fun) + with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing optional arguments by position is deprecated.*'): + jax.jit(fun, None) + def test_cond(self): def f(x, y): @custom_transpose(jnp.ones(2)) @@ -10382,6 +10593,28 @@ def cond_wrap(f): self.assertAllClose(f_(x), g_(x)) self.assertAllClose(f_t(x), g_t(x)) + def test_compose_custom_jvp(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + dx, = tangents + return f(x), g(x, dx) + + @custom_transpose + def g(x, dx): + return jnp.cos(x) * dx + + @g.def_transpose + def gt(x, t): + return jnp.cos(x) * t + + with config.use_direct_linearize(True): + self.assertAllClose(jax.grad(f)(0.5), jnp.cos(0.5)) + class CustomDceTest(jtu.JaxTestCase): @@ -11504,5 +11737,82 @@ def wsc_as_noop(ctx, operand, *args, **kwargs): self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir) +class InputSavedVJPTest(jtu.JaxTestCase): + + def test_basic(self): + def f(x, y): + return x * y + + primals = 2., 3. + y, f_vjp = api.si_vjp(f, [True, True], *primals) + arg_cts = f_vjp(1., *primals) + self.assertAllClose(y, 6.) + self.assertAllClose(arg_cts, (3., 2.)) + + def test_basic_pass_through_jit(self): + def f(x, y): + return x * y + + @jax.jit + def g(): + primals = 2., 3. + y, f_vjp = api.si_vjp(f, [True, True], *primals) + return y, f_vjp + + @jax.jit + def h(f_vjp): + return f_vjp(1., 2., 3.) + + y, f_vjp = g() + arg_cts = h(f_vjp) + self.assertAllClose(y, 6.) 
+ self.assertAllClose(arg_cts, (3., 2.)) + + def test_basic_unused(self): + f = jnp.sin + primals = 3., + y, f_vjp = api.si_vjp(f, [True], *primals) + x_ct, = f_vjp(1., *primals) + self.assertAllClose(y, jnp.sin(3.)) + self.assertAllClose(x_ct, jnp.cos(3.)) + + with self.assertRaisesRegex(Exception, "not used by the backward pass: x"): + _ = api.si_vjp(f, [True], *primals, allow_unused=False) + + def test_basic_opaque(self): + f = jnp.sin + primals = 3., + with self.assertRaisesRegex(Exception, "the backward pass requires opaque"): + _ = api.si_vjp(f, [True], *primals, allow_opaque=False) + + def test_basic_pytree_error(self): + def f(x): + return [x['hi'] * x['bye']] + + y, f_vjp = api.si_vjp(f, [True], {'hi': 2., 'bye': 3.}) + arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.}) + self.assertAllClose(y, [6.]) + self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.}) + + with self.assertRaisesRegex(ValueError, "but the structures differ"): + f_vjp(1., {'hi': 2.}) + + def test_fsdp(self): + # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" + def f2(x, w): + x = 1. * x + x = x @ w + x = 2. * x + return x + + x = jnp.ones((3, 4)) + w = jnp.ones((4, 4)) + y, f2_sivjp = api.si_vjp(f2, [False, True], x, w) + y_grad = jnp.ones_like(y) + x_grad, w_grad = f2_sivjp(y_grad, w) + self.assertAllClose(x_grad, 2. * y_grad @ w.T) + self.assertAllClose(w_grad, 2. * x.T @ y_grad) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 2f8d4d1c666f..7534cf6f8acd 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -10,6 +10,24 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # Returns wrong zero sign array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 
0 and x2_i is -infinity) -> +0] + # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted @@ -19,3 +37,40 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_clip # JAX raises a ValueError rather than the expected IndexError for out-of-bound axis array_api_tests/test_manipulation_functions.py::test_expand_dims + +# Doesn't promote to uint64 +array_api_tests/test_statistical_functions.py::test_cumulative_prod + +# TODO(jakevdp): fix the following failures: + +# Returns NaN rather than inf +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is +0) -> -infinity] + +# Returns -1.0 rather than 0.0 
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] \ No newline at end of file diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 250eeb810872..8e4ba275fdd3 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -26,6 +26,7 @@ import jax.numpy as jnp from jax._src import config, test_util as jtu from jax._src.dtypes import _default_types, canonicalize_dtype +from jax._src import xla_bridge as xb ARRAY_API_NAMESPACE = jnp @@ -275,14 +276,18 @@ def build_dtype_dict(self, dtypes): def test_capabilities_info(self): capabilities = self.info.capabilities() - assert capabilities["boolean indexing"] + assert not capabilities["boolean indexing"] assert not capabilities["data-dependent shapes"] + assert capabilities["max dimensions"] == 64 def test_default_device_info(self): assert self.info.default_device() is None def test_devices_info(self): - assert self.info.devices() == jax.devices() + devices = set(self.info.devices()) + assert None in devices + for backend in xb.backends(): + assert devices.issuperset(jax.devices(backend)) def test_default_dtypes_info(self): _default_dtypes = { diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py new file mode 100644 index 000000000000..45f12ac06473 --- /dev/null +++ b/tests/array_extensibility_test.py @@ -0,0 +1,577 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
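For context, the file added here exercises the __jax_array__ protocol: an object that returns a jax.Array from a __jax_array__ method can be passed to jnp.asarray and to many of the jnp APIs enumerated in NUMPY_APIS below. A minimal self-contained sketch, illustrative only and not part of the patch (the Scaled class is invented for the example):

import jax
import jax.numpy as jnp

class Scaled:
  """Illustrative array-like; conversion to jax.Array goes through __jax_array__."""

  def __init__(self, value, scale):
    self.value = value
    self.scale = scale

  def __jax_array__(self) -> jax.Array:
    # Called when jnp.asarray (or a jnp function coercing its inputs) sees this object.
    return jnp.asarray(self.value) * self.scale

x = Scaled([0.0, 1.0, 2.0], scale=2.0)
print(jnp.asarray(x))  # Array([0., 2., 4.], dtype=float32)
print(jnp.sin(x))      # jnp.sin is among the APIs the new test covers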
+ +import functools +from typing import Any, Callable, NamedTuple + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike +from jax._src import config +from jax._src import test_util as jtu + + +config.parse_flags_with_absl() + + +class JaxArrayWrapper: + """Class that provides a __jax_array__ method.""" + x: ArrayLike + + def __init__(self, x: ArrayLike): + self.x = x + + def __jax_array__(self) -> jax.Array: + return jnp.asarray(self.x) + + +class DuckTypedArrayWithErroringJaxArray: + """Duck-typed array that provides a __jax_array__ method which fails.""" + shape = (2, 3) + dtype = np.dtype('float32') + + def __jax_array__(self): + raise ValueError("jax array was called.") + + +class NumPyAPI(NamedTuple): + fun: Callable[..., Any] + args: list[jax.ShapeDtypeStruct] + kwargs: dict[str, Any] + skip_on_devices: list[str] | None + + def name(self): + return self.fun.__name__ + + def make_args(self, rng): + rng = jtu.rand_default(rng) + return jax.tree.map(lambda arg: rng(arg.shape, arg.dtype), self.args) + + def with_skip_on_devices(self, disabled_devices: list[str]) -> 'NumPyAPI': + return self._replace(skip_on_devices=disabled_devices) + + @classmethod + def sig(cls, fun: Callable[..., Any], *args: Any, **kwargs: Any) -> 'NumPyAPI': + return cls(fun, args, kwargs, None) + + +class ShapeDtype: + """Shortcut for specifying ShapeDtypeStruct.""" + def __init__(self, dtype): + self.dtype = jax.dtypes.canonicalize_dtype(dtype) + def __getitem__(self, shape) -> jax.ShapeDtypeStruct: + if isinstance(shape, int): + shape = (shape,) + return jax.ShapeDtypeStruct(shape, self.dtype) + +Bool = ShapeDtype(bool) +Int = ShapeDtype(int) +UInt = ShapeDtype('uint32') +Uint8 = ShapeDtype('uint8') +Float = ShapeDtype(float) +Complex = ShapeDtype(complex) + + +# NumPy namespace objects skipped in the enumeration below, mainly because +# they are not functions or do not take arrays as positional arguments. 
+SKIPPED_APIS = [ + 'apply_along_axis', + 'apply_over_axes', + 'arange', + 'array_str', + 'array_repr', + 'astype', + 'bartlett', + 'bfloat16', + 'blackman', + 'block', + 'bool', + 'bool_', + 'broadcast_shapes', + 'c_', + 'can_cast', + 'cdouble', + 'character', + 'complex128', + 'complex64', + 'complex_', + 'complexfloating', + 'csingle', + 'diag_indices', + 'double', + 'dtype', + 'e', + 'einsum', + 'einsum_path', + 'euler_gamma', + 'empty', + 'eye', + 'finfo', + 'flexible', + 'float_', + 'float16', + 'float32', + 'float4_e2m1fn', + 'float64', + 'float8_e3m4', + 'float8_e4m3', + 'float8_e4m3b11fnuz', + 'float8_e4m3fn', + 'float8_e4m3fnuz', + 'float8_e5m2', + 'float8_e5m2fnuz', + 'float8_e8m0fnu', + 'floating', + 'from_dlpack', + 'frombuffer', + 'fromfile', + 'fromfunction', + 'fromiter', + 'frompyfunc', + 'fromstring', + 'full', + 'generic', + 'geomspace', + 'get_printoptions', + 'gradient', + 'hamming', + 'hanning', + 'identity', + 'iinfo', + 'index_exp', + 'indices', + 'inexact', + 'inf', + 'int16', + 'int2', + 'int32', + 'int4', + 'int64', + 'int8', + 'int_', + 'integer', + 'isdtype', + 'issubdtype' + 'iterable' + 'kaiser' + 'kron' + 'ix_', + 'linalg', + 'linspace', + 'load', + 'logspace', + 'mask_indices', + 'mgrid', + 'nan', + 'ndarray', + 'newaxis', + 'number', + 'object_', + 'ogrid', + 'ones', + 'pi', + 'printoptions', + 'promote_types' + 'r_', + 'result_type', + 's_', + 'save', + 'savez', + 'set_printoptions', + 'signedinteger', + 'single', + 'tri', + 'tril_indices', + 'triu_indices', + 'ufunc', + 'uint', + 'uint16', + 'uint2', + 'uint32', + 'uint4', + 'uint64', + 'uint8', + 'unsignedinteger', + 'vectorize', + 'zeros', +] + +# TODO(jakevdp): commented APIs are ones which do not yet support +# __jax_array__ on inputs. We should fix these! +NUMPY_APIS = [ + NumPyAPI.sig(jnp.abs, Float[5]), + NumPyAPI.sig(jnp.absolute, Float[5]), + NumPyAPI.sig(jnp.acos, Float[5]), + NumPyAPI.sig(jnp.acosh, Float[5]), + NumPyAPI.sig(jnp.add, Float[5], Float[5]), + NumPyAPI.sig(jnp.all, Bool[5]), + NumPyAPI.sig(jnp.allclose, Float[5], Float[5]), + NumPyAPI.sig(jnp.amax, Float[5]), + NumPyAPI.sig(jnp.amin, Float[5]), + NumPyAPI.sig(jnp.angle, Float[5]), + NumPyAPI.sig(jnp.any, Float[5]), + NumPyAPI.sig(jnp.append, Float[10], Float[()]), + NumPyAPI.sig(jnp.arccos, Float[5]), + NumPyAPI.sig(jnp.arccosh, Float[5]), + NumPyAPI.sig(jnp.arcsin, Float[5]), + NumPyAPI.sig(jnp.arcsinh, Float[5]), + NumPyAPI.sig(jnp.arctan, Float[5]), + NumPyAPI.sig(jnp.arctan2, Float[5], Float[5]), + NumPyAPI.sig(jnp.arctanh, Float[5]), + NumPyAPI.sig(jnp.argmax, Float[10]), + NumPyAPI.sig(jnp.argmin, Float[10]), + NumPyAPI.sig(jnp.argpartition, Float[10], kth=5), + NumPyAPI.sig(jnp.argsort, Float[10]), + NumPyAPI.sig(jnp.argwhere, Float[10]), + NumPyAPI.sig(jnp.around, Float[5]), + NumPyAPI.sig(jnp.array, Float[5]), + NumPyAPI.sig(jnp.array_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.array_equiv, Float[5], Float[5]), + NumPyAPI.sig(jnp.array_split, Float[9], indices_or_sections=3), + NumPyAPI.sig(jnp.asarray, Float[5]), + NumPyAPI.sig(jnp.asin, Float[5]), + NumPyAPI.sig(jnp.asinh, Float[5]), + NumPyAPI.sig(jnp.atan, Float[5]), + NumPyAPI.sig(jnp.atan2, Float[5], Float[5]), + NumPyAPI.sig(jnp.atanh, Float[5]), + NumPyAPI.sig(jnp.atleast_1d, Float[5]), + NumPyAPI.sig(jnp.atleast_2d, Float[5]), + NumPyAPI.sig(jnp.atleast_3d, Float[5]), + NumPyAPI.sig(jnp.average, Float[10]), + NumPyAPI.sig(jnp.bincount, Int[10]), + NumPyAPI.sig(jnp.bitwise_and, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_count, Int[5]), + 
NumPyAPI.sig(jnp.bitwise_invert, Int[5]), + NumPyAPI.sig(jnp.bitwise_left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_not, Int[5]), + NumPyAPI.sig(jnp.bitwise_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_right_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.broadcast_arrays, Float[5]), + NumPyAPI.sig(jnp.broadcast_to, Float[()], shape=(10,)), + NumPyAPI.sig(jnp.cbrt, Float[5]), + NumPyAPI.sig(jnp.ceil, Float[5]), + NumPyAPI.sig(jnp.choose, Int[3], [Float[3], Float[3], Float[3]], mode='clip'), + NumPyAPI.sig(jnp.clip, Float[5]), + NumPyAPI.sig(jnp.column_stack, [Float[5], Float[5], Float[5]]), + NumPyAPI.sig(jnp.compress, Float[10], Bool[10]), + NumPyAPI.sig(jnp.concat, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.concatenate, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.conj, Float[5]), + NumPyAPI.sig(jnp.conjugate, Float[5]), + NumPyAPI.sig(jnp.convolve, Float[7], Float[3]), + NumPyAPI.sig(jnp.copy, Float[5]), + NumPyAPI.sig(jnp.copysign, Float[5], Float[5]), + NumPyAPI.sig(jnp.corrcoef, Float[7], Float[7]), + NumPyAPI.sig(jnp.correlate, Float[7], Float[3]), + NumPyAPI.sig(jnp.cos, Float[5]), + NumPyAPI.sig(jnp.cosh, Float[5]), + NumPyAPI.sig(jnp.count_nonzero, Float[10]), + NumPyAPI.sig(jnp.cov, Float[10]), + NumPyAPI.sig(jnp.cross, Float[3], Float[3]), + NumPyAPI.sig(jnp.cumprod, Float[5]), + NumPyAPI.sig(jnp.cumsum, Float[5]), + NumPyAPI.sig(jnp.cumulative_prod, Float[5]), + NumPyAPI.sig(jnp.cumulative_sum, Float[5]), + NumPyAPI.sig(jnp.deg2rad, Float[5]), + NumPyAPI.sig(jnp.degrees, Float[5]), + NumPyAPI.sig(jnp.delete, Float[5], Int[()]), + NumPyAPI.sig(jnp.diag, Float[5]), + NumPyAPI.sig(jnp.diag_indices_from, Float[5, 5]), + NumPyAPI.sig(jnp.diagflat, Float[5]), + NumPyAPI.sig(jnp.diagonal, Float[5, 5]), + NumPyAPI.sig(jnp.diff, Float[5]), + NumPyAPI.sig(jnp.digitize, Float[5], Float[5]), + NumPyAPI.sig(jnp.divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.divmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.dot, Float[5], Float[5]), + NumPyAPI.sig(jnp.dsplit, Float[3, 5, 6], indices_or_sections=2), + NumPyAPI.sig(jnp.dstack, [Float[3, 5, 1], Float[3, 5, 3]]), + NumPyAPI.sig(jnp.ediff1d, Float[5]), + NumPyAPI.sig(jnp.empty_like, Float[5]), + NumPyAPI.sig(jnp.equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.exp, Float[5]), + NumPyAPI.sig(jnp.exp2, Float[5]), + NumPyAPI.sig(jnp.expand_dims, Float[5], axis=0), + NumPyAPI.sig(jnp.expm1, Float[5]), + NumPyAPI.sig(jnp.extract, Bool[5], Float[5]), + NumPyAPI.sig(jnp.fabs, Float[5]), + NumPyAPI.sig(jnp.fft.fft, Float[5]), + NumPyAPI.sig(jnp.fft.fft2, Float[5, 5]), + NumPyAPI.sig(jnp.fft.ifft, Float[5]), + NumPyAPI.sig(jnp.fft.ifft2, Float[5, 5]), + NumPyAPI.sig(jnp.fill_diagonal, Float[5, 5], Float[()], inplace=False), + NumPyAPI.sig(jnp.fix, Float[5]), + NumPyAPI.sig(jnp.flatnonzero, Float[5]), + NumPyAPI.sig(jnp.flip, Float[5]), + NumPyAPI.sig(jnp.fliplr, Float[5, 5]), + NumPyAPI.sig(jnp.flipud, Float[5, 5]), + NumPyAPI.sig(jnp.float_power, Float[5], Float[5]), + NumPyAPI.sig(jnp.floor, Float[5]), + NumPyAPI.sig(jnp.floor_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmax, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmin, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.frexp, Float[5]), + NumPyAPI.sig(jnp.full_like, Float[5], Float[()]), + NumPyAPI.sig(jnp.gcd, Int[5], Int[5]), + NumPyAPI.sig(jnp.greater, Float[5], Float[5]), + NumPyAPI.sig(jnp.greater_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.heaviside, Float[5], Float[5]), + NumPyAPI.sig(jnp.histogram, 
Float[5]), + NumPyAPI.sig(jnp.histogram2d, Float[5], Float[5]), + NumPyAPI.sig(jnp.histogram_bin_edges, Float[5]), + NumPyAPI.sig(jnp.histogramdd, Float[5, 3]), + NumPyAPI.sig(jnp.hsplit, Float[3, 6], indices_or_sections=2), + NumPyAPI.sig(jnp.hstack, (Float[5], Float[5])), + NumPyAPI.sig(jnp.hypot, Float[5], Float[5]), + NumPyAPI.sig(jnp.i0, Float[5]), + NumPyAPI.sig(jnp.imag, Complex[5]), + NumPyAPI.sig(jnp.inner, Float[5], Float[5]), + NumPyAPI.sig(jnp.insert, Float[5], Int[()], Float[2]), + NumPyAPI.sig(jnp.interp, Float[10], Float[5], Float[5]), + NumPyAPI.sig(jnp.intersect1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.invert, Int[5]), + NumPyAPI.sig(jnp.isclose, Float[5], Float[5]), + NumPyAPI.sig(jnp.iscomplex, Float[5]), + NumPyAPI.sig(jnp.iscomplexobj, Complex[5]), + NumPyAPI.sig(jnp.isfinite, Float[5]), + NumPyAPI.sig(jnp.isin, Int[5], Int[10]), + NumPyAPI.sig(jnp.isinf, Float[5]), + NumPyAPI.sig(jnp.isnan, Float[5]), + NumPyAPI.sig(jnp.isneginf, Float[5]), + NumPyAPI.sig(jnp.isposinf, Float[5]), + NumPyAPI.sig(jnp.isreal, Float[5]), + NumPyAPI.sig(jnp.isrealobj, Float[5]), + NumPyAPI.sig(jnp.isscalar, Float[()]), + NumPyAPI.sig(jnp.lcm, Int[5], Int[5]), + NumPyAPI.sig(jnp.ldexp, Float[5], Int[5]), + NumPyAPI.sig(jnp.left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.less, Float[5], Float[5]), + NumPyAPI.sig(jnp.less_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.lexsort, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.log, Float[5]), + NumPyAPI.sig(jnp.log10, Float[5]), + NumPyAPI.sig(jnp.log1p, Float[5]), + NumPyAPI.sig(jnp.log2, Float[5]), + NumPyAPI.sig(jnp.logaddexp, Float[5], Float[5]), + NumPyAPI.sig(jnp.logaddexp2, Float[5], Float[5]), + NumPyAPI.sig(jnp.logical_and, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_not, Int[5]), + NumPyAPI.sig(jnp.logical_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.matmul, Float[5, 5], Float[5]), + NumPyAPI.sig(jnp.matrix_transpose, Float[5, 6]), + NumPyAPI.sig(jnp.matvec, Float[5, 5], Float[5]), + NumPyAPI.sig(jnp.max, Float[5]), + NumPyAPI.sig(jnp.maximum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mean, Float[5]), + NumPyAPI.sig(jnp.median, Float[5]), + NumPyAPI.sig(jnp.meshgrid, Float[5], Float[5]), + NumPyAPI.sig(jnp.min, Float[5]), + NumPyAPI.sig(jnp.minimum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mod, Float[5], Float[5]), + NumPyAPI.sig(jnp.modf, Float[5]), + NumPyAPI.sig(jnp.moveaxis, Float[5, 3], source=0, destination=1), + NumPyAPI.sig(jnp.multiply, Float[5], Float[5]), + NumPyAPI.sig(jnp.nan_to_num, Float[5]), + NumPyAPI.sig(jnp.nanargmax, Float[5]), + NumPyAPI.sig(jnp.nanargmin, Float[5]), + NumPyAPI.sig(jnp.nancumprod, Float[5]), + NumPyAPI.sig(jnp.nancumsum, Float[5]), + NumPyAPI.sig(jnp.nanmax, Float[5]), + NumPyAPI.sig(jnp.nanmean, Float[5]), + NumPyAPI.sig(jnp.nanmedian, Float[5]), + NumPyAPI.sig(jnp.nanmin, Float[5]), + NumPyAPI.sig(jnp.nanpercentile, Float[5], q=75), + NumPyAPI.sig(jnp.nanprod, Float[5]), + NumPyAPI.sig(jnp.nanquantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.nanstd, Float[5]), + NumPyAPI.sig(jnp.nansum, Float[5]), + NumPyAPI.sig(jnp.nanvar, Float[5]), + NumPyAPI.sig(jnp.ndim, Float[5]), + NumPyAPI.sig(jnp.negative, Float[5]), + NumPyAPI.sig(jnp.nextafter, Float[5], Float[5]), + NumPyAPI.sig(jnp.nonzero, Float[5]), + NumPyAPI.sig(jnp.not_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.ones_like, Float[5]), + NumPyAPI.sig(jnp.outer, Float[5], Float[5]), + NumPyAPI.sig(jnp.packbits, Int[5]), + NumPyAPI.sig(jnp.pad, Float[5], pad_width=2), + NumPyAPI.sig(jnp.partition, Float[5], kth=3), 
+ NumPyAPI.sig(jnp.percentile, Float[5], q=75), + NumPyAPI.sig(jnp.permute_dims, Float[3, 5], axes=(1, 0)), + NumPyAPI.sig(jnp.piecewise, Float[5], [Bool[5], Bool[5]], funclist=[jnp.sin, jnp.cos]), + NumPyAPI.sig(jnp.place, Float[5], Bool[5], Float[3], inplace=False), + NumPyAPI.sig(jnp.poly, Float[5]), + NumPyAPI.sig(jnp.polyadd, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyder, Float[5]), + NumPyAPI.sig(jnp.polydiv, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyfit, Float[5], Float[5], deg=2), + NumPyAPI.sig(jnp.polyint, Float[5]), + NumPyAPI.sig(jnp.polymul, Float[5], Float[5]), + NumPyAPI.sig(jnp.polysub, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyval, Float[5], Float[10]), + NumPyAPI.sig(jnp.positive, Float[5]), + NumPyAPI.sig(jnp.pow, Float[5], Float[5]), + NumPyAPI.sig(jnp.power, Float[5], Float[5]), + NumPyAPI.sig(jnp.prod, Float[5]), + NumPyAPI.sig(jnp.ptp, Float[5]), + NumPyAPI.sig(jnp.put, Float[5], Int[()], Float[()], inplace=False), + NumPyAPI.sig(jnp.put_along_axis, Float[5], Int[1], Float[1], axis=0, inplace=False), + NumPyAPI.sig(jnp.quantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.rad2deg, Float[5]), + NumPyAPI.sig(jnp.radians, Float[5]), + NumPyAPI.sig(jnp.ravel, Float[5]), + NumPyAPI.sig(jnp.ravel_multi_index, [Uint8[5], Uint8[5]], dims=(8, 9)), + NumPyAPI.sig(jnp.real, Complex[5]), + NumPyAPI.sig(jnp.reciprocal, Float[5]), + NumPyAPI.sig(jnp.remainder, Float[5], Float[5]), + NumPyAPI.sig(jnp.repeat, Float[5], repeats=np.array([2, 3, 1, 5, 4])), + NumPyAPI.sig(jnp.reshape, Float[6], shape=(2, 3)), + NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)), + NumPyAPI.sig(jnp.right_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.rint, Float[5]), + NumPyAPI.sig(jnp.roll, Float[5], Int[1]), + NumPyAPI.sig(jnp.rollaxis, Float[5, 4], axis=1), + NumPyAPI.sig(jnp.roots, Float[5]).with_skip_on_devices(['tpu']), + NumPyAPI.sig(jnp.rot90, Float[5, 3]), + NumPyAPI.sig(jnp.round, Float[5]), + NumPyAPI.sig(jnp.searchsorted, Float[5], Float[5]), + NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[()]), + NumPyAPI.sig(jnp.setdiff1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.setxor1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.shape, Float[5, 3]), + NumPyAPI.sig(jnp.sign, Float[5]), + NumPyAPI.sig(jnp.signbit, Float[5]), + NumPyAPI.sig(jnp.sin, Float[5]), + NumPyAPI.sig(jnp.sinc, Float[5]), + NumPyAPI.sig(jnp.sinh, Float[5]), + NumPyAPI.sig(jnp.size, Float[5]), + NumPyAPI.sig(jnp.sort, Float[5]), + NumPyAPI.sig(jnp.sort_complex, Complex[5]), + NumPyAPI.sig(jnp.spacing, Float[5]), + NumPyAPI.sig(jnp.split, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.sqrt, Float[5]), + NumPyAPI.sig(jnp.square, Float[5]), + NumPyAPI.sig(jnp.squeeze, Float[5]), + NumPyAPI.sig(jnp.stack, [Float[2, 3], Float[2, 3]], axis=1), + NumPyAPI.sig(jnp.std, Float[5]), + NumPyAPI.sig(jnp.subtract, Float[5], Float[5]), + NumPyAPI.sig(jnp.sum, Float[5]), + NumPyAPI.sig(jnp.swapaxes, Float[3, 5], axis1=1, axis2=0), + NumPyAPI.sig(jnp.take, Float[5], Int[2]), + NumPyAPI.sig(jnp.take_along_axis, Float[5], Int[2], axis=0), + NumPyAPI.sig(jnp.tan, Float[5]), + NumPyAPI.sig(jnp.tanh, Float[5]), + NumPyAPI.sig(jnp.tensordot, Float[2, 3, 4], Float[3, 4, 5]), + NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), + NumPyAPI.sig(jnp.trace, Float[5, 5]), + NumPyAPI.sig(jnp.transpose, Float[5, 6]), + NumPyAPI.sig(jnp.trapezoid, Float[5]), + NumPyAPI.sig(jnp.tril, Float[5, 6]), + NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.trim_zeros, Float[5]), + NumPyAPI.sig(jnp.triu, Float[5, 6]), + 
NumPyAPI.sig(jnp.triu_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.true_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.trunc, Float[5]), + NumPyAPI.sig(jnp.union1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.unique, Int[10]), + NumPyAPI.sig(jnp.unique_all, Int[10]), + NumPyAPI.sig(jnp.unique_counts, Int[10]), + NumPyAPI.sig(jnp.unique_inverse, Int[10]), + NumPyAPI.sig(jnp.unique_values, Int[10]), + NumPyAPI.sig(jnp.unpackbits, Uint8[8]), + NumPyAPI.sig(jnp.unravel_index, Int[5], shape=(2, 3)), + NumPyAPI.sig(jnp.unstack, Float[5]), + NumPyAPI.sig(jnp.unwrap, Float[5]), + NumPyAPI.sig(jnp.vander, Float[5]), + NumPyAPI.sig(jnp.var, Float[5]), + NumPyAPI.sig(jnp.vdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecmat, Float[5], Float[5, 3]), + NumPyAPI.sig(jnp.vsplit, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.vstack, [Float[5], Float[2, 5]]), + NumPyAPI.sig(jnp.where, Bool[5], Float[5], Float[5]), + NumPyAPI.sig(jnp.zeros_like, Float[5]), +] + + +class JaxArrayTests(jtu.JaxTestCase): + @parameterized.named_parameters( + {'testcase_name': api.name(), 'api': api} for api in NUMPY_APIS) + def test_numpy_api_supports_jax_array(self, api): + if api.skip_on_devices and jtu.test_device_matches(api.skip_on_devices): + self.skipTest(f'{api.name()} not supported on {api.skip_on_devices}') + fun = api.fun + args = api.make_args(self.rng()) + wrapped_args = jax.tree.map(JaxArrayWrapper, args) + kwargs = api.kwargs + + expected = fun(*args, **kwargs) + wrapped = fun(*wrapped_args, **kwargs) + + self.assertAllClose(wrapped, expected, atol=0, rtol=0) + + @parameterized.named_parameters( + {'testcase_name': func.__name__, 'func': func} + for func in [jnp.zeros_like, jnp.ones_like, jnp.empty_like, jnp.full_like] + ) + def test_array_creation_from_duck_typed_array(self, func): + # Ensure that jnp.*_like prefers shape/dtype over __jax_array__ when + # both methods are available. 
+ if func is jnp.full_like: + func = functools.partial(func, fill_value=2.0) + obj = DuckTypedArrayWithErroringJaxArray() + + # The test relies on this failing + with self.assertRaises(ValueError): + jnp.asarray(obj) + + result = func(obj) + self.assertIsInstance(result, jax.Array) + self.assertEqual(result.shape, obj.shape) + self.assertEqual(result.dtype, obj.dtype) + + @parameterized.named_parameters( + {"testcase_name": "subscript-form", "args": ("jk,k->j", Float[5, 3], Float[3])}, + {"testcase_name": "index-form", "args": (Float[5, 3], (0, 1), Float[3], (1,), (0,))}, + ) + def test_einsum(self, args): + rng = jtu.rand_default(self.rng()) + def make_arg(arg): + if isinstance(arg, jax.ShapeDtypeStruct): + return rng(arg.shape, arg.dtype) + return arg + args = jax.tree.map(make_arg, args) + + def wrap_array(arg): + if isinstance(arg, (jax.Array, np.ndarray)): + return JaxArrayWrapper(arg) + return arg + wrapped_args = jax.tree.map(wrap_array, args) + + expected = jnp.einsum(*args) + actual = jnp.einsum(*wrapped_args) + + self.assertAllClose(actual, expected, atol=0, rtol=0) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 80a4d8ef5a25..adfd34627e76 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -95,6 +95,10 @@ def setUp(self): message="Calling from_dlpack with a DLPack tensor", category=DeprecationWarning, ) + @jtu.ignore_warning( + message="jax.dlpack.to_dlpack was deprecated.*", + category=DeprecationWarning, + ) def testJaxRoundTrip(self, shape, dtype, copy, use_stream): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) @@ -107,6 +111,8 @@ def _check_copy(x: jax.Array, y: jax.Array, expect_copy): x = jax.device_put(np, jax.devices("cpu")[0]) device = jax.devices("gpu")[0] y = jax.device_put(x, device) + # TODO(parkers): Remove after setting 'stream' properly below. + jax.block_until_ready(y) dl_device = y.__dlpack_device__() if use_stream: stream = tuple(y.devices())[0].get_stream_for_external_ready_events() @@ -149,6 +155,8 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu): raise unittest.SkipTest("Skipping GPU test case on CPU") device = jax.devices("gpu" if gpu else "cpu")[0] x = jax.device_put(np, device) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(x) y = jax.dlpack.from_dlpack(x) self.assertEqual(y.devices(), {device}) self.assertAllClose(np.astype(x.dtype), y) @@ -188,6 +196,10 @@ def testTensorFlowToJax(self, shape, dtype): dtype=dlpack_dtypes, ) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.ignore_warning( + message="jax.dlpack.to_dlpack was deprecated.*", + category=DeprecationWarning, + ) def testJaxToTensorFlow(self, shape, dtype): if (not config.enable_x64.value and dtype in [jnp.int64, jnp.uint64, jnp.float64]): @@ -198,6 +210,8 @@ def testJaxToTensorFlow(self, shape, dtype): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) x = jnp.array(np) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(x) # TODO(b/171320191): this line works around a missing context initialization # bug in TensorFlow. _ = tf.add(1, 1) @@ -319,6 +333,8 @@ def testJaxToCuPy(self, shape, dtype): rng = jtu.rand_default(self.rng()) x = rng(shape, dtype) y = jnp.array(x) + # TODO(parkers): Remove after setting 'stream' properly. 
+ jax.block_until_ready(y) z = cupy.asarray(y) self.assertEqual(y.__cuda_array_interface__["data"][0], z.__cuda_array_interface__["data"][0]) @@ -354,6 +370,8 @@ def testCaiToJax(self, shape, dtype): device = jax.devices('cuda')[-1] with jax.default_device(device): y = jnp.array(x, dtype=dtype) + # TODO(parkers): Remove after setting 'stream' properly below. + jax.block_until_ready(y) self.assertEqual(y.dtype, dtype) # Using a jax array CAI provider support to construct an object diff --git a/tests/array_test.py b/tests/array_test.py index cc8990828ded..1780213bcc61 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -29,9 +29,10 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip -from jax._src.mesh import AxisType +from jax._src.mesh import AxisType, AbstractMesh from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import ( _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, @@ -368,8 +369,6 @@ def test_different_devices_in_arrays_than_sharding(self): array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) def test_duplicated_devices_in_arrays(self): - if xc._version <= 274: - self.skipTest('Test requires jaxlib version 275') shape = (8, 2) mesh = jtu.create_mesh((1, 2), ('x', 'y')) # Sharding device ids = {0, 1} @@ -657,12 +656,15 @@ def f(x): output_shardings._to_xla_hlo_sharding(x_dummy.ndim), s._to_xla_hlo_sharding(x_dummy.ndim))) - # TODO(skyewm): remove this test when we can remove the workaround manual - # defragment API - @jtu.skip_on_devices('cpu') # defragment not implemented for TFRT CPU + # TODO(b/399879011): GPU is the only platform that has an implementation for + # this, which exists in py_client.cc. Ideally, this would be replaced with + # some kind of auto-defrag-on-OOM. + @jtu.run_on_devices('gpu') def test_defragment(self): + # Since the GPU implementation is in py_client.cc, it cannot be exposed via + # the PjRt C API. if xb.using_pjrt_c_api(): - self.skipTest("Manual defragment not exposed via PJRT C API") + self.skipTest('Manual defragment not exposed via PJRT C API') # Create a few arrays global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',)) @@ -675,7 +677,7 @@ def test_defragment(self): # Delete one of them arr2.delete() - # Defragment + # Defragment. xb.get_backend().defragment() # Sanity check remaining arrays @@ -897,7 +899,7 @@ def test_op_sharding_indices(self, pspec): shape = (8, 4) mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) - ops = jax.sharding.GSPMDSharding( + ops = GSPMDSharding( list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape))) self.assertDictEqual( ops.devices_indices_map(shape), mps.devices_indices_map(shape)) @@ -973,7 +975,7 @@ def test_gspmd_sharding_repr(self): op.tile_assignment_dimensions = [4, 1, 2] op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7] op.replicate_on_last_tile_dim = True - s = jax.sharding.GSPMDSharding(jax.devices(), op) + s = GSPMDSharding(jax.devices(), op) # memory kind also appears in the repr but only for TPU. 
self.assertIn( 'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 ' @@ -981,7 +983,7 @@ def test_gspmd_sharding_repr(self): op2 = xc.OpSharding() op2.type = xc.OpSharding.Type.REPLICATED - s2 = jax.sharding.GSPMDSharding(jax.devices(), op2) + s2 = GSPMDSharding(jax.devices(), op2) # memory kind also appears in the repr but only for TPU. self.assertIn('GSPMDSharding({replicated}', repr(s2)) @@ -1006,7 +1008,7 @@ def test_positional_sharding_op_sharding_lowering( mps = jax.sharding.NamedSharding(mesh, pspec) devices = jax.local_devices()[:8] # Taking up to 8 devices - devices_sharding = jax.sharding.PositionalSharding(devices) + devices_sharding = PositionalSharding(devices) devices_sharding = devices_sharding.reshape(shape).replicate(axes) if transpose: devices_sharding = devices_sharding.T @@ -1108,7 +1110,7 @@ def test_devices_sharding_respects_init_mesh_shape(self): mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, P('x', 'y')) - devices_sharding = jax.sharding.PositionalSharding(mesh.devices) + devices_sharding = PositionalSharding(mesh.devices) op1 = mps._to_xla_hlo_sharding(len(value_shape)) op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape)) @@ -1127,7 +1129,7 @@ def test_pmap_sharding_repr(self): def test_positional_sharding_repr(self): if jax.device_count() < 2: self.skipTest('Test needs >= 2 devices.') - s = jax.sharding.PositionalSharding(jax.devices()).reshape(jax.device_count(), 1) + s = PositionalSharding(jax.devices()).reshape(jax.device_count(), 1) repr(s) # doesn't crash str(s) # doesn't crash @@ -1198,9 +1200,9 @@ def test_are_shardings_equivalent(self): op1 = xc.OpSharding() op1.type = xc.OpSharding.Type.REPLICATED - s6 = jax.sharding.GSPMDSharding([jax.devices()[0]], op1) + s6 = GSPMDSharding([jax.devices()[0]], op1) - s7 = jax.sharding.GSPMDSharding(jax.devices(), op1) + s7 = GSPMDSharding(jax.devices(), op1) # The OpSharding is replicated but the Sharding itself are on different # devices. 
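# The hunks above and below exercise GSPMDSharding and PositionalSharding through their
# internal names rather than the jax.sharding aliases, which (per the deprecation-warning
# test later in this file) now warn. A minimal sketch of the non-deprecated way to express
# the same fully replicated placement with NamedSharding; the mesh construction here is an
# illustrative assumption, not taken from the tests:
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
replicated = NamedSharding(mesh, P())  # P() replicates across every mesh axis
x = jax.device_put(jnp.arange(8), replicated)
assert x.sharding.is_fully_replicated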
@@ -1210,7 +1212,7 @@ def test_are_shardings_equivalent(self): op2.type = xc.OpSharding.Type.OTHER op2.tile_assignment_devices = [0, 1] op2.tile_assignment_dimensions = [2, 1] - s8 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op2) + s8 = GSPMDSharding(list(mesh2.devices.flat), op2) self.assertTrue(s1.is_equivalent_to(s6, 2)) self.assertTrue(s5.is_equivalent_to(s8, 2)) @@ -1223,7 +1225,7 @@ def test_are_shardings_equivalent(self): op3.tile_assignment_devices = [0, 1] op3.tile_assignment_dimensions = [1, 1, 2] op3.replicate_on_last_tile_dim = True - s10 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op3) + s10 = GSPMDSharding(list(mesh2.devices.flat), op3) self.assertTrue(s9.is_equivalent_to(s10, 2)) @@ -1301,6 +1303,18 @@ def f(x): with self.assertRaisesRegex(TypeError, msg): jax.jit(f)(x) + def test_make_array_from_single_device_arrays_tuple(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (8, 8) + s = jax.sharding.NamedSharding(mesh, P('x', 'y')) + inp_data = np.arange(math.prod(shape)).reshape(shape) + + arrays = tuple( + jax.device_put(inp_data[index], d) + for d, index in s.addressable_devices_indices_map(shape).items()) + + jax.make_array_from_single_device_arrays(shape, s, arrays) # doesn't crash + def test_make_array_from_single_device_arrays_bad_inputs(self): x = jnp.arange(10) mesh = jtu.create_mesh((2,), ('x',)) @@ -1363,6 +1377,16 @@ def test_mesh_axis_types_mismatch(self): jax.sharding.AbstractMesh((2, 1), ('x', 'y'), axis_types=jax.sharding.AxisType.Auto) + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types=("explicit",)) + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types="explicit") + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2, 2), ('x', 'y'), + axis_types=("explicit", AxisType.Explicit)) + def test_make_mesh_axis_types(self): Auto, Explicit, Manual = AxisType.Auto, AxisType.Explicit, AxisType.Manual @@ -1378,6 +1402,9 @@ def test_make_mesh_axis_types(self): self.assertDictEqual( mesh._axis_types_dict, {AxisType.Auto: ('y',), AxisType.Explicit: ('x',), AxisType.Manual: ('z',)}) + self.assertEqual(mesh.explicit_axes, ('x',)) + self.assertEqual(mesh.auto_axes, ('y',)) + self.assertEqual(mesh.manual_axes, ('z',)) mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), axis_types=(Explicit, Explicit, Manual)) @@ -1402,6 +1429,28 @@ def test_make_mesh_axis_types(self): self.assertNotEqual(mesh1, mesh2) self.assertNotEqual(hash(mesh1), hash(mesh2)) + def test_memory_kind_with_abstract_mesh(self): + if jaxlib_extension_version < 326: + self.skipTest('Requires jaxlib_extension_version >= 326') + + abstract_mesh = AbstractMesh((2,), ('x',)) + ns = NamedSharding(abstract_mesh, P(), memory_kind='pinned_host') + self.assertEqual(ns.memory_kind, 'pinned_host') + + ns = NamedSharding(abstract_mesh, P()) + self.assertIsNone(ns.memory_kind) + + with self.assertRaisesRegex( + ValueError, 'Got invalid memory kind'): + NamedSharding(abstract_mesh, P(), memory_kind='weird_device') + + def test_pos_gspmd_sharding_warnings(self): + with self.assertWarns(DeprecationWarning): + jax.sharding.PositionalSharding(jax.devices()) + + with self.assertWarns(DeprecationWarning): + jax.sharding.GSPMDSharding.get_replicated(jax.devices()) + @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyShardingTest(jtu.JaxTestCase): @@ -1415,9 +1464,9 @@ def test_long_axis_names(self): SdyArraySharding( 
mesh.shape_tuple, [SdyDimSharding( - ('sequence', 'data'), True), - SdyDimSharding(('model',), True), - SdyDimSharding([], True)])) + ('sequence', 'data'), False), + SdyDimSharding(('model',), False), + SdyDimSharding([], False)])) with ir.Context() as ctx: dialects.sdy.register_dialect(ctx) self.assertEqual( @@ -1434,9 +1483,9 @@ def test_unconstrained(self): sdy_sharding, SdyArraySharding( mesh.shape_tuple, - [SdyDimSharding([], True), - SdyDimSharding([], False), - SdyDimSharding(('x',), True)])) + [SdyDimSharding([], False), + SdyDimSharding([], True), + SdyDimSharding(('x',), False)])) with ir.Context() as ctx: dialects.sdy.register_dialect(ctx) self.assertEqual( diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 2334a7b98f91..8cf64790311b 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -15,6 +15,7 @@ from __future__ import annotations from dataclasses import dataclass +import itertools as it from absl.testing import absltest from absl.testing import parameterized @@ -28,7 +29,7 @@ from jax._src.util import safe_zip, safe_map from jax.experimental import attrs -from jax.experimental.attrs import jax_setattr, jax_getattr +from jax.experimental.attrs import jax_setattr, jax_getattr, jax_appendattr config.parse_flags_with_absl() @@ -66,6 +67,19 @@ def double_it() -> None: double_it() self.assertEqual(thing.x, 16.0) + def test_setattr_doesnt_leak(self): + thing = Thing(1.0) + + @jax.jit + def f(x): + jax_setattr(thing, 'x', x) + raise Exception + + try: f(1.) + except: pass + self.assertNotIsInstance(thing.x, jax.core.Tracer) + + @parameterized.parameters([True, False]) def test_jit_basic_tree(self, jit: bool): thing = Thing((1.0, 2.0)) @@ -260,6 +274,26 @@ def body(_, __): double_it_10() self.assertAllClose(thing.x, 1024., check_dtypes=False) + @parameterized.parameters([True, False]) + def test_scan_basic_pytree(self, jit): + class Thing: ... + thing = Thing() + thing.x = (1.0, 1.0) + + def double_it_10(): + def body(_, __): + cur_x, _ = jax_getattr(thing ,"x") + jax_setattr(thing, "x", (cur_x * 2.0, 3.0)) + return None, None + _, _ = jax.lax.scan(body, None, None, length=10) + + if jit: + double_it_10 = jax.jit(double_it_10) + + double_it_10() + self.assertAllClose(thing.x[0], 1024., check_dtypes=False) + self.assertAllClose(thing.x[1], 3., check_dtypes=False) + def test_scan_basic_consts_and_args(self): thing = Thing(1.0) @@ -360,6 +394,227 @@ def body(i, _): return i + 1, None _, _ = jax.lax.scan(body, 0, None, length=3) # don't crash + @parameterized.parameters([True, False]) + def test_setattr_doesnt_exist(self, jit): + class Thing: + ... + thing = Thing() + + def f(x): + assert (not jit) or tracing_is_ok + jax_setattr(thing, 'x', x) + + if jit: + f = jax.jit(f) + + tracing_is_ok = True + self.assertFalse(hasattr(thing, 'x')) + f(1.0) + self.assertEqual(thing.x, 1.0) + f(2.0) + self.assertEqual(thing.x, 2.0) + + tracing_is_ok = False + f(3.0) + self.assertEqual(thing.x, 3.0) + + del thing.x + f(4.0) + self.assertEqual(thing.x, 4.0) + + tracing_is_ok = True + f(5) + self.assertEqual(thing.x, 5) + + def test_setattr_doesnt_exist_doesnt_leave_sentinel_around(self): + class Thing: + ... + thing = Thing() + + def f(x): + jax_setattr(thing, 'x', x) + + jax.make_jaxpr(f)(3.) + self.assertFalse(hasattr(thing, 'x')) + tracing_ok = True + f(0.0) + self.assertAllClose(thing.x, 0.) + tracing_ok = False + f(1.0) + self.assertAllClose(thing.x, 1.) 
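# A minimal sketch of the jax.experimental.attrs API exercised by the surrounding tests:
# jax_setattr/jax_getattr treat object attributes as jit side state, and jax_appendattr
# accumulates appended values into a growing array attribute (created on first append,
# as the uninitialized test cases below check). The Counter class and its attribute
# names are hypothetical stand-ins for the tests' Thing:
import jax
import jax.numpy as jnp
from jax.experimental.attrs import jax_setattr, jax_getattr, jax_appendattr

class Counter:
    pass

counter = Counter()
counter.total = jnp.float32(0.0)

@jax.jit
def step(x):
    jax_setattr(counter, 'total', jax_getattr(counter, 'total') + x)
    jax_appendattr(counter, 'history', x)  # appends one scalar per call

step(jnp.float32(1.0))
step(jnp.float32(2.0))
# counter.total is now 3.0 and counter.history is the length-2 array [1., 2.]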
+ + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_basic(self, jit, initialized): + class Thing: + ... + thing = Thing() + + if initialized: + thing.x = jnp.arange(0.) + + def f(x): + assert (not jit) or tracing_ok + jax_appendattr(thing, 'x', x) + jax_appendattr(thing, 'x', x + 1) + + if jit: + f = jax.jit(f) + + tracing_ok = True + f(0.0) + self.assertAllClose(thing.x, jnp.array([0., 1.])) + tracing_ok = False + f(2.0) + self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3.])) + f(4.0) + self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.])) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_constant(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.arange(0.) + + def f(): + assert (not jit) or tracing_ok + jax_appendattr(thing, 'x', 0.0) + jax_appendattr(thing, 'x', 1.0) + + if jit: + f = jax.jit(f) + + tracing_ok = True + f() + self.assertAllClose(thing.x, jnp.array([0., 1.])) + tracing_ok = False + f() + self.assertAllClose(thing.x, jnp.array([0., 1., 0., 1.])) + + @parameterized.parameters([True, False]) + def test_appendattr_getattr_errors(self, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.arange(0.) + + @jax.jit + def f(x): + jax_appendattr(thing, 'x', x) + jax_getattr(thing, 'x') + + with self.assertRaisesRegex(TypeError, "can't read/write"): + f(1.0) + + @jax.jit + def g(x): + jax_setattr(thing, 'x', x) + jax_appendattr(thing, 'x', x) + + with self.assertRaisesRegex(TypeError, "can't append"): + g(1.0) + + if initialized: + self.assertNotIsInstance(thing.x, jax.core.Tracer) + else: + self.assertFalse(hasattr(thing, 'x')) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_dtype_disagreement(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.array([], 'float32') + + def f(x): + jax_appendattr(thing, 'x', x) + jax_appendattr(thing, 'x', x.astype('complex64')) + + if jit: + f = jax.jit(f) + + msg = "can only append to attr x with values of trailing shape " + msg += "float32" if initialized else "int32" + with self.assertRaisesRegex(TypeError, msg): + f(jnp.array(1, 'int32')) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_shape_disagreement(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.array([]) + + def f(x): + jax_appendattr(thing, 'x', x) + jax_appendattr(thing, 'x', jnp.stack([x, x])) + + if jit: + f = jax.jit(f) + + msg = "can only append to attr x with values of trailing shape" + with self.assertRaisesRegex(TypeError, msg): + f(1) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_scan(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.array([]) + + def f(): + def body(c, x): + jax_appendattr(thing, 'x', 2 * x) + jax_appendattr(thing, 'x', 2 * x + 1) + return c, () + _, () = jax.lax.scan(body, 0, jnp.arange(3.)) + + if jit: + f = jax.jit(f) + + f() + + self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.])) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_scan_vjp(self, jit, initialized): + class Thing: ... 
+ thing = Thing() + + if initialized: + thing.y_bar = jnp.array([]) + + def f(x): + def body(c, _): + return 0.5 * g(2 * c), () + y, _ = jax.lax.scan(body, x, (), length=5) + return y + + if jit: + f = jax.jit(f) + + @jax.custom_vjp + def g(x): + return x + + def g_fwd(x): + return g(x), None + + def g_bwd(_, y_bar): + jax_appendattr(thing, 'y_bar', y_bar) + return y_bar, + + g.defvjp(g_fwd, g_bwd) + jax.grad(f)(3.) + + self.assertAllClose(thing.y_bar, jnp.array([0.5] * 5)) + class AttrsJVPTest(jtu.JaxTestCase): @@ -500,6 +755,7 @@ def g_ref(x, x_dot, y, y_dot): self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False) self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False) + class AttrsLinTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) diff --git a/tests/batching_test.py b/tests/batching_test.py index f2a4e8c34fe3..393317bcbe77 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -1328,33 +1328,70 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]: @jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe class VmappableTest(jtu.JaxTestCase): - def test_basic(self): + @parameterized.parameters([False, True]) + def test_basic(self, jit): with temporarily_register_named_array_vmappable(): def f(x): return named_mul(x, x) + if jit: + f = jax.jit(f) x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) g = jax.vmap(f, - in_axes=NamedMapSpec('i', 0), - out_axes=NamedMapSpec('i', 1), - axis_size=3) + in_axes=NamedMapSpec('i', 0), + out_axes=NamedMapSpec('i', 1), + axis_size=3) ans = g(x) expected = NamedArray(['j', 'i'], jnp.arange(12.).reshape(3, 4).T ** 2) self.assertEqual(ans.names, expected.names) self.assertAllClose(ans.data, expected.data) - def test_basic_jit(self): - with temporarily_register_named_array_vmappable(): - def f(x): - return named_mul(x, x) - - x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) - ans = jax.jit(f)(x) - expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2) - - self.assertEqual(ans.names, expected.names) - self.assertAllClose(ans.data, expected.data) + def test_to_elt_that_binds_primitives(self): + class A: + data: Array + def __init__(self, data): + self.data = data + def to_elt(cont, _, val, spec): + return cont(val.data + 1, spec) + def from_elt(cont, size, elt, spec): + assert False + + @jax.jit + def f(): + a = A(jnp.arange(3.)) + return jax.vmap(lambda x: x - 1, axis_size=3)(a) + + try: + batching.register_vmappable(A, int, int, to_elt, from_elt, None) + ans = f() + finally: + batching.unregister_vmappable(A) + + self.assertAllClose(ans, jnp.arange(3.)) + + def test_from_elt_that_binds_primitives(self): + class A: + data: Array + def __init__(self, data): + self.data = data + def to_elt(cont, _, val, spec): + return A(cont(val.data, spec)) + def from_elt(cont, size, elt, spec): + return A(cont(size, elt.data + 1, spec)) + + @jax.jit + def f(): + a = A(jnp.arange(3.)) + return jax.vmap(lambda x: x, axis_size=3)(a).data + + try: + batching.register_vmappable(A, int, int, to_elt, from_elt, None) + ans = f() + finally: + batching.unregister_vmappable(A) + + self.assertAllClose(ans, jnp.arange(3.) + 1) def test_types_with_same_spec(self): # We register NamedArray. 
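# A minimal sketch of the register_vmappable pattern used by the vmappable tests above,
# mirroring their to_elt/from_elt signatures and try/finally registration. Box is a
# hypothetical container, and the batching import is an assumption (the internal
# jax._src.interpreters.batching module), since the test file's imports are not shown here:
import jax
import jax.numpy as jnp
from jax._src.interpreters import batching

class Box:
    def __init__(self, data):
        self.data = data

def to_elt(cont, axis_size, val, spec):
    # unwrap the container so vmap maps over the underlying array
    return cont(val.data, spec)

def from_elt(cont, axis_size, elt, spec):
    # only needed when the vmapped function returns a Box; unused in this sketch
    raise NotImplementedError

try:
    batching.register_vmappable(Box, int, int, to_elt, from_elt, None)
    out = jax.vmap(lambda x: x + 1, axis_size=3)(Box(jnp.arange(3.)))  # Array([1., 2., 3.])
finally:
    batching.unregister_vmappable(Box)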
diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 2faa4dbaf9d4..fd3e7706260a 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -83,9 +83,9 @@ def test_hash_accelerator_devices(self): self.assertEqual(dev_hash1, dev_hash2) acc_hash1 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) acc_hash2 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) self.assertEqual(acc_hash1, acc_hash2) def test_hash_platform(self): @@ -163,6 +163,8 @@ def test_different_computations(self): cache_key.get(computation2, devices, compile_options, backend), ) + # TODO(phawkins): this test flakes if test concurrency is enabled. + @jtu.thread_unsafe_test() def test_custom_partitioning_ptr_removal(self): def _partition(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) @@ -178,7 +180,8 @@ def _cp_add(x, y): _cp_add.def_partition( infer_sharding_from_operands=_infer_sharding_from_operands, - partition=_partition) + partition=_partition, + sharding_rule='i i -> i') devices = np.asarray(jax.devices()) with Mesh(devices, ('x',)) as m: diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 6a1660b28578..2f4b7d511fbe 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -24,7 +24,7 @@ from jax.experimental import checkify from jax.experimental import pjit from jax.experimental import shard_map -from jax.sharding import NamedSharding +from jax.sharding import NamedSharding, PartitionSpec as P from jax._src import array from jax._src import config from jax._src import core @@ -475,12 +475,25 @@ def f(init_val): self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "division by zero") + def test_checify_donation_no_forwarding(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @checkify.checkify + @partial(jax.jit, donate_argnums=(0,)) + def f(x: jax.Array) -> jax.Array: + checkify.check(jnp.all(x > 0), "a") + return x + + x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P())) + err, y = f(x) + err, z = f(y) # doesn't crash + @jtu.skip_on_devices("tpu") def test_while_loop_body_and_cond_error(self): def while_cond(val): i, cond_val, _ = val - _ = jnp.sin(cond_val) - return i < 2 + j = jnp.sin(cond_val) + return i + (0. 
* j) < 2 # don't let the sin value be dead code def while_body(val): i, cond_val, body_val = val diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 52d494904fe6..ada17bc61c82 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -18,7 +18,6 @@ import threading import time from typing import Sequence -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -36,8 +35,9 @@ try: import cloudpickle # noqa + HAS_CLOUDPICKLE = True except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on cloudpickle library") + HAS_CLOUDPICKLE = False def _colocated_cpu_devices( @@ -68,10 +68,14 @@ class ColocatedPythonTest(jtu.JaxTestCase): def setUp(self): super().setUp() + if not HAS_CLOUDPICKLE: + self.skipTest( + "ColocatedPythonTest depends on cloudpickle library" + ) if np.lib.NumpyVersion(np.__version__) < "2.0.0": self.skipTest( - "Serialization in Colocated Python needs StringDType, and thus" - " requires NumPy 2.0.0 or later" + "Serialization in Colocated Python needs StringDType, and thus" + " requires NumPy 2.0.0 or later" ) def testMakeColocatedPythonProgram(self): diff --git a/tests/core_test.py b/tests/core_test.py index c46d493bda54..00b3eb1d61d5 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -397,6 +397,12 @@ def setUp(self): lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._pad_jaxpr_constvars.cache_clear() + def tearDown(self): + super().tearDown() + lax_control_flow._initial_style_open_jaxpr.cache_clear() + lax_control_flow._initial_style_jaxpr.cache_clear() + lax_control_flow.common._pad_jaxpr_constvars.cache_clear() + def test_check_jaxpr_correct(self): jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr core.check_jaxpr(jaxpr) @@ -405,6 +411,7 @@ def test_check_jaxpr_cond_correct(self): jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr core.check_jaxpr(jaxpr) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_jit_invalid(self): jaxpr = make_jaxpr(jax.jit(lambda x, y: x + 1))(1., 2.).jaxpr pjit_eqn, = jaxpr.eqns @@ -414,6 +421,7 @@ def test_check_jaxpr_jit_invalid(self): '0 operands cannot call jaxpr with 2 inputs', lambda: core.check_jaxpr(jaxpr)) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_cond_invalid(self): jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond') @@ -433,6 +441,7 @@ def f(c, x): jaxpr = make_jaxpr(partial(lax.scan, f))(c, xs).jaxpr core.check_jaxpr(jaxpr) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_invalid_long(self): # jaxprs can be large, and this tests that when large ones are printed for # context in jaxpr typechecking errors, they're not printed entirely @@ -464,6 +473,7 @@ def g(x): self.assertIn('while checking jaxpr:', msg) self.assertLess(msg.count('\n'), 200) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_eqn_mismatch(self): def f(x): return jnp.sin(x) + jnp.cos(x) @@ -487,7 +497,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a", + r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) 
jaxpr = new_jaxpr() @@ -496,7 +506,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a", + r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) def test_jaxpr_dropvar_from_jit_call(self): @@ -534,6 +544,7 @@ def f(x): assert isinstance(jaxpr.eqns[-1].outvars[0], core.DropVar) core.check_jaxpr(jaxpr) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_jaxpr_undefined_eqn_invar(self): jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr cos = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cos') diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index a39b53c3ad16..d014c86c2506 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -46,6 +46,7 @@ from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.lax.control_flow import for_loop from jax._src.interpreters import mlir +from jax._src import util as util import numpy as np @@ -241,7 +242,7 @@ def my_f(x, y, z, w): dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") self.assertEqual(dbg.func_name, "my_f") - self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) + self.assertEqual(dbg.arg_names, ("x", "y", "w", "z")) self.assertIsNone(dbg.result_paths) def test_debug_info_arg_passed_as_kwarg(self): @@ -261,23 +262,29 @@ def my_f(x_tree, *, y_tree): "y_tree['w']", "y_tree['z']")) def test_debug_info_with_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, z, *, w, y): pass - dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4), + dbg = api_util.debug_info("jit", my_f, (1,), dict(y=2, z=3, w=4), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x", "z")) + self.assertEqual(dbg.arg_names, ("x", "y", "z")) def test_debug_info_with_pytrees_and_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, y, *, z, w, t): pass dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)), - dict(z=(3, 4), w=(5, 6)), + dict(z=(3, 4), w=(5, 6), t=7), + static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "z[0]", "z[1]")) + + dbg = api_util.debug_info("jit", my_f, ((1, 2),), + dict(z=(3, 4), w=(5, 6), t=7, y=3), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "y", "z[0]", "z[1]")) def test_debug_info_too_many_args(self): def my_f(x): @@ -287,15 +294,20 @@ def my_f(x): self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) def test_debug_info_no_source_info_built_in(self): - # built-in function "int" does not have an inspect.Signature + # built-in function "max" does not have an inspect.Signature dbg = api_util.debug_info("jit", max, (1,), {}) self.assertEqual(dbg.func_src_info, "max") + self.assertEqual(dbg.func_name, "max") + self.assertEqual(dbg.func_filename, None) + self.assertEqual(dbg.func_lineno, None) self.assertEqual(dbg.arg_names, ("args[0]",)) def test_debug_info_lambda(self): # built-in function "int" does not have an inspect.Signature dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {}) self.assertRegex(dbg.func_src_info, r"^ at .*debug_info_test.py:\d+") + self.assertEndsWith(dbg.func_filename, 
"debug_info_test.py") + self.assertIsNotNone(dbg.func_lineno) self.assertEqual(dbg.arg_names, ("my_arg",)) def test_debug_info_save_wrapped_fun_source_info(self): @@ -380,66 +392,6 @@ def f(x): with self.assertRaisesRegex(TypeError, err_str): jax.jit(f)(jnp.int32) - @jtu.thread_unsafe_test() # logging is not thread-safe - def test_arg_names_cache_miss_explanations(self): - @jax.jit - def f(x, y): - return jnp.sin(x) * y['hi'] - - x = jnp.float32(1.) - y = {'hi': jnp.arange(3., dtype='float32')} - - expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - - # print on first miss, not on hit - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(x, y) - f(x, y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('TRACING CACHE MISS', msg) - self.assertIn('never seen function', msg) - - # shape change - y_ = {'hi': jnp.arange(4, dtype='float32')} - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(x, y_) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn('seen f32[3], but now given f32[4]', msg) - - # weak type change (assuming no x64) - if not config.enable_x64.value: - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1., y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('weak_type=True', msg) - self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg) - - # kwarg change - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1, y=y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) - - # tracing config change - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - with jax.numpy_rank_promotion('warn'): - f(x, y) - # depending on the backend, we may or may not get persistent cache warnings - self.assertTrue(1 <= len(cm.output) <= expected_log_len) - msg = cm.output[0] - self.assertIn("tracing context doesn't match", msg) - @jtu.thread_unsafe_test() # logging is not thread-safe def test_arg_names_cache_miss_explanations_new_function_in_loop(self): @jax.jit @@ -671,7 +623,7 @@ def my_g(b, d=1): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): result_paths? 
- "traced_for=jit, fun=my_f, arg_names=a, result_paths=", + "traced_for=jit, fun=my_f, arg_names=a, result_paths=result", "traced_for=jit, fun=my_g, arg_names=b, result_paths=result", ], expected_tracer_debug_infos=[ @@ -761,6 +713,122 @@ def f(x, y, *args, **kwargs): re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ]) + def test_jit_arg_names_with_out_of_order_kwargs(self): + tracer_spy = TracerSpy() + + # The shapes are different, to differentiate them easily + a1 = (np.float32(0),) # a hashable tuple, can be static + b2 = np.arange(2, dtype=np.float32) # b2 + z3 = np.arange(3, dtype=np.float32) + y4 = (np.float32(0.), np.float32(1.), np.float32(2.), np.float32(3.)) + x5 = np.arange(5, dtype=np.float32) + u6 = np.arange(6, dtype=np.float32) + t7 = np.arange(7, dtype=np.float32) + + def my_f(a1, b2, z3, y4, x5, *, u6, t7): + assert np.shape(a1[0]) == () + assert np.shape(b2) == (2,) + assert np.shape(z3) == (3,) + assert np.shape(y4) == (4,) + assert np.shape(x5) == (5,) + assert np.shape(u6) == (6,) + assert np.shape(t7) == (7,) + tracer_spy.append(b2) + tracer_spy.append(x5) + return a1[0] + b2[0] + z3[0] + y4[0] + x5[0] + u6[0] + t7[0] + + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnums=(0,), static_argnames=("y4",)), + # Some positional args passed as keyword + a1, b2, x5=x5, y4=y4, z3=z3, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from b2", + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnames=("y4",)), + # Positional argument y4 is static and passed by kwarg + a1, b2, z3, x5=x5, y4=y4, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnames=("y4",)), + # Positional argument y4 is static (declared as static_argnames) + a1, b2, z3, y4, x5=x5, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, 
arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnums=(3,)), + # Positional argument y4 is static (declared as static_argnums) + a1, b2, z3, y4, x5=x5, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + def test_jit_result_info(self): def f(x, y, z): return {'a': x, 'b': [y]} @@ -794,7 +862,7 @@ def my_g(u, v): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']" + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, from x", @@ -1145,7 +1213,6 @@ def fn_tp(r, t): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=result[0]['c']", - "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=result['b']", "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=result['c']", ], expected_tracer_debug_infos=[ @@ -1318,17 +1385,15 @@ def the_grad(c, as_): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", - # TODO(necula): arg names, bad result paths "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", - "traced_for=checkpoint / remat, 
fun=to_remat, arg_names=,,, result_paths=,", - "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=" - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, from c", @@ -1467,7 +1532,7 @@ def my_g(u, v): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info @@ -1495,34 +1560,50 @@ def my_f(x): def test_pmap_with_arg_and_result_names(self): tracer_spy = TracerSpy() - x = np.ones((jax.device_count(),), dtype=np.float32) - def my_f(x, y, *args, a, **kwargs): - # y and kwargs[c] is dead + + # Use different shapes arguments to distinguish them in the HLO + def my_f(x0, y1, *args, b4, **kwargs): + assert np.shape(x0) == () + assert np.shape(y1) == (1,) + assert np.shape(args[0]) == (2,) + assert np.shape(args[1]) == (3,) + assert np.shape(b4) == (4,) + assert np.shape(kwargs["a5"]) == (5,) + assert np.shape(kwargs["c6"]) == (6,) + # kwargs[b5] is dead tracer_spy.append(args[1]) - s = x + a + args[1] + kwargs["d"] - return dict(u=s, v=x) + tracer_spy.append(b4) + tracer_spy.append(kwargs["c6"]) + s0 = x0 + y1[0] + b4[0] + args[1][0] + kwargs["c6"][0] + return dict(v1=jnp.broadcast_to(s0, (1,)), u0=s0) self._check_tracers_and_jaxprs( jax.pmap(my_f, static_broadcasted_argnums=(0,)), - 1., x, x, x, # x, y, args[0], args[1] - d=x, a=x, b=x, # kwargs + 1., # x0 + np.ones((jax.device_count(), 1), dtype=np.float32), # y1 + np.ones((jax.device_count(), 2), dtype=np.float32), # args[0] + np.ones((jax.device_count(), 3), dtype=np.float32), # args[1] + b4=np.ones((jax.device_count(), 4), dtype=np.float32), + a5=np.ones((jax.device_count(), 5), dtype=np.float32), + c6=np.ones((jax.device_count(), 6), dtype=np.float32), expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], result_paths=result['u0'],result['v1']", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from args[1]", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from b4", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from kwargs['c6']", ], expected_lowering_lines=[ - # TODO(necula): we did not DCE y? 
- re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"), - re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"), - re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"), + re.compile(r".*func.func public @main\(.*%arg0: tensor<1x1xf..> loc\(\"y1\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<1x2xf..> loc\(\"args\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<1x3xf..> loc\(\"args\[1\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<1x5xf..> loc\(\"kwargs\['a5'\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<1x4xf..> loc\(\"b4\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<1x6xf..> loc\(\"kwargs\['c6'\]\"\)"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u0'\]\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v1'\]\"\}"), ] ) @@ -1611,11 +1692,8 @@ def my_f(x): x, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", - # TODO(necula): arg_names and result_paths? - "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", - "traced_for=jit, fun=my_f, arg_names=x,, result_paths=," - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ @@ -1697,7 +1775,7 @@ def my_f(x): "traced_for=shard_map, fun=my_f, arg_names=,, result_paths=", ], expected_tracer_debug_infos=[ - "None" # TODO(necula): missing + "traced_for=shard_map, fun=my_f, arg_names=x, from x" ]) def test_remat_saved_residuals(self): diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index a8d59bc39e36..d9d50a546e57 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -25,6 +25,7 @@ from jax._src import debugging from jax._src import dispatch from jax._src import test_util as jtu +from jax._src.lib import jaxlib_extension_version import jax.numpy as jnp import numpy as np @@ -1203,6 +1204,169 @@ def f_(x): f(arr) +def _get_output_set(output, num_lines): + """Return a set of strings where each string is num_lines.""" + output = output().strip().split("\n") + return { + "\n".join(output[i : i + num_lines]) + for i in range(0, len(output), num_lines) + } + + +@jtu.thread_unsafe_test_class() # printing isn't thread-safe +class PartitionedDebugCallbackTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if (jtu.device_under_test() not in ("cpu", "gpu")): + raise unittest.SkipTest( + f"Test requires CPU or GPU devices. Got {jtu.device_under_test()}" + ) + if jaxlib_extension_version < 329: + self.skipTest( + "Requires jaxlib_extension_version >= 329. Got" + f" {jaxlib_extension_version}." 
+ ) + if len(jax.devices()) < 2: + raise unittest.SkipTest("Test requires >= 2 devices.") + + def tearDown(self): + super().tearDown() + dispatch.runtime_tokens.clear() + + def test_partitioned_debug_callback(self): + def f_(x): + debug_print("hello: {x}", x=x, partitioned=True) + + f = pjit.pjit(f_) + mesh = jtu.create_mesh((1, 1, 2,), ("x", "y", "z")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y", "z")) + arr = jax.device_put(np.arange(24).reshape(2, 3, 4), s) + + with jtu.capture_stdout() as output: + with mesh: + f(arr) + jax.effects_barrier() + + expected = { + _format_multiline(""" + hello: [[[ 0 1] + [ 4 5] + [ 8 9]] + + [[12 13] + [16 17] + [20 21]]]"""), + _format_multiline(""" + hello: [[[ 2 3] + [ 6 7] + [10 11]] + + [[14 15] + [18 19] + [22 23]]]"""), + } + self.assertEqual(_get_output_set(output, 7), expected) + + def test_debug_print_batching(self): + @jax.vmap + def f_(x): + debug_print("hello: {}", x, partitioned=True) + + f = pjit.pjit(f_) + mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y", "z")) + arr = np.arange(24).reshape(2, 3, 4) + arr = jax.device_put(arr, s) + + with jtu.capture_stdout() as output: + with mesh: + f(arr) + jax.effects_barrier() + + expected = { + _format_multiline(""" + hello: [[0 1] + [4 5] + [8 9]]"""), + _format_multiline(""" + hello: [[ 2 3] + [ 6 7] + [10 11]]"""), + _format_multiline(""" + hello: [[14 15] + [18 19] + [22 23]]"""), + _format_multiline(""" + hello: [[12 13] + [16 17] + [20 21]]"""), + } + + self.assertEqual(_get_output_set(output, 3), expected) + + def test_debug_print_batching_with_diff_axes(self): + @functools.partial(jax.vmap, in_axes=(0, 1)) + def f_(x, y): + debug_print("hello: {} {}", x, y, partitioned=True) + + f = pjit.pjit(f_) + mesh = jtu.create_mesh((2,), ("x")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x")) + x = np.arange(4).reshape(2, 2) + x = jax.device_put(x, s) + y = np.arange(4).reshape(2, 2) + 6 + y = jax.device_put(y, s) + + with jtu.capture_stdout() as output: + with mesh: + f(x, y) + jax.effects_barrier() + + expected = { + "hello: [2 3] [9]", + "hello: [0 1] [6]", + "hello: [0 1] [8]", + "hello: [2 3] [7]", + } + + self.assertEqual(_get_output_set(output, 1), expected) + + def test_debug_print_with_nested_vmap(self): + @jax.vmap + @jax.vmap + def f_(x): + debug_print("hello: {}", x, partitioned=True) + + f = pjit.pjit(f_) + mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y", "z")) + arr = np.arange(24).reshape(2, 3, 4) + arr = jax.device_put(arr, s) + + with jtu.capture_stdout() as output: + with mesh: + f(arr) + jax.effects_barrier() + + expected = { + "hello: [14 15]", + "hello: [12 13]", + "hello: [18 19]", + "hello: [16 17]", + "hello: [22 23]", + "hello: [20 21]", + "hello: [2 3]", + "hello: [0 1]", + "hello: [6 7]", + "hello: [10 11]", + "hello: [4 5]", + "hello: [8 9]", + } + + self.assertEqual(_get_output_set(output, 1), expected) + + if not rich: del VisualizeShardingTest diff --git a/tests/distributed_initialize_test.py b/tests/distributed_initialize_test.py new file mode 100644 index 000000000000..33242a41a68e --- /dev/null +++ b/tests/distributed_initialize_test.py @@ -0,0 +1,44 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu + +try: + import portpicker +except ImportError: + portpicker = None + +jax.config.parse_flags_with_absl() + + +@unittest.skipIf(not portpicker, "Test requires portpicker") +class DistributedInitializeTest(jtu.JaxTestCase): + + @jtu.skip_under_pytest( + """Side effects from jax.distributed.initialize conflict with other tests + in the same process. pytest runs multiple tests in the same process.""" + ) + def test_is_distributed_initialized(self): + port = portpicker.pick_unused_port() # type: ignore + self.assertFalse(jax.distributed.is_initialized()) + jax.distributed.initialize(f"localhost:{port}", 1, 0) + self.assertTrue(jax.distributed.is_initialized()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/distributed_test.py b/tests/distributed_test.py index 3961932dfad0..5e47228c1719 100644 --- a/tests/distributed_test.py +++ b/tests/distributed_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import sys import threading import unittest @@ -67,22 +65,6 @@ def task(i): for thread in threads: thread.join() - def test_is_distributed_initialized(self): - # Run in subprocess to isolate side effects from jax.distributed.initialize which conflict with other - # tests. Unfortunately this can't be avoided by calling jax.distributed.shutdown, as the XLA backend - # will be warmed up, which yields a RuntimeError on subsequent calls to initialize. 
- port = portpicker.pick_unused_port() # type: ignore - cmd = f"""import jax; - assert not jax.distributed.is_initialized(); - jax.distributed.initialize('localhost:{port}', 1, 0); - assert jax.distributed.is_initialized(); - """.replace("\n", ' ') - - result = subprocess.run([sys.executable, "-c", cmd], capture_output=True) - self.assertEqual( - result.returncode, 0, msg=f"Test failed with:\n{result.stdout}\n{result.stderr}" - ) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 87380443f4cb..d8fb30397b27 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -46,30 +46,19 @@ np.dtype('uint64')] unsigned_dtypes = list(np_unsigned_dtypes) -intn_dtypes = [np.dtype('int4'), np.dtype('uint4')] -signed_dtypes += [np.dtype('int4')] -unsigned_dtypes += [np.dtype('uint4')] -if dtypes.int2 is not None: - assert dtypes.uint2 is not None - intn_dtypes[:0] = [np.dtype('int2'), np.dtype('uint2')] - signed_dtypes[:0] = [np.dtype('int2')] - unsigned_dtypes[:0] = [np.dtype('uint2')] - -np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), - np.dtype('float64')] +intn_dtypes = [np.dtype('int2'), np.dtype('uint2'), np.dtype('int4'), np.dtype('uint4')] +signed_dtypes += [np.dtype('int2'), np.dtype('int4')] +unsigned_dtypes += [np.dtype('uint2'), np.dtype('uint4')] + +np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), np.dtype('float64')] float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes custom_float_dtypes = [np.dtype(dtypes.bfloat16)] fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)] -if dtypes.float8_e3m4 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] -if dtypes.float8_e4m3 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] -if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] + np.dtype(dtypes.float8_e5m2fnuz), np.dtype(dtypes.float8_e3m4), + np.dtype(dtypes.float8_e4m3), np.dtype(dtypes.float8_e8m0fnu)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes diff --git a/tests/error_check_test.py b/tests/error_check_test.py index b96c6281411f..0c77989b8a43 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -13,12 +13,16 @@ # limitations under the License. +import traceback + from absl.testing import absltest from absl.testing import parameterized import jax from jax._src import config from jax._src import error_check +from jax._src import mesh as mesh_lib from jax._src import test_util as jtu +import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P @@ -30,7 +34,9 @@ jtu.request_cpu_devices(4) -@jtu.with_config(jax_check_tracer_leaks=True) +# TODO: AOT tests fails with the tracer leak checker. +# Reenable once https://github.com/jax-ml/jax/issues/27315 is fixed. +# @jtu.with_config(jax_check_tracer_leaks=True) class ErrorCheckTests(jtu.JaxTestCase): @parameterized.product(jit=[True, False]) @@ -107,6 +113,32 @@ def g(x): with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"): error_check.raise_if_error() + @parameterized.product(jit=[True, False]) + def test_error_includes_traceback(self, jit): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback. 
+ x <= 0, "x must be greater than 0" + ) + return x + 1 + + if jit: + function_that_triggers_error_for_traceback_test = jax.jit( + function_that_triggers_error_for_traceback_test + ) + + x = jnp.zeros((4,), dtype=jnp.int32) + function_that_triggers_error_for_traceback_test(x) + + tb_string = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + + self.assertIn("function_that_triggers_error_for_traceback_test", tb_string) + self.assertIn("This line must be included in the traceback", tb_string) + @parameterized.product(jit=[True, False]) def test_error_check_works_with_cond(self, jit): def f(x): @@ -202,13 +234,144 @@ def f(x): if jit: f = jax.jit(f) - sharding = NamedSharding(mesh, P("x", "y")) - x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) with error_check.error_checking_context(): + x = jnp.full((4, 4), -1, dtype=jnp.int32) + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + sharding = NamedSharding(mesh, P("x", "y")) + with error_check.error_checking_context(): + y = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) + f(y) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + # The unsharded version of `f` should still be able to check errors after + # exiting the error checking context. + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + @jtu.with_user_mesh( + (2, 2), + ("x", "y"), + axis_types=(mesh_lib.AxisType.Auto, mesh_lib.AxisType.Auto), + ) + @jtu.ignore_warning( + message=( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode." + ), + category=RuntimeWarning, + ) + def test_error_check_auto_mode(self, jit, mesh): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + with error_check.error_checking_context(): + sharding = NamedSharding(mesh, P("x", "y")) + x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) f(x) with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): error_check.raise_if_error() + def test_error_check_aot(self): + def run_export(): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + f = jax.jit(error_check.wrap_for_export(jax.jit(f))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.) 
+ _ = f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + serialized = run_export() + run_import(serialized) + + def test_error_check_aot_includes_traceback(self): + def run_export(): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback + x <= 0, "x must be greater than 0" + ) + return x + 1 + + f = jax.jit( + error_check.wrap_for_export( + jax.jit(function_that_triggers_error_for_traceback_test) + ) + ) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.0) + _ = f(x) + + msg = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + msg = str(e) + + self.assertIn("function_that_triggers_error_for_traceback_test", msg) + self.assertIn("This line must be included in the traceback", msg) + + serialized = run_export() + run_import(serialized) + + def test_error_check_aot_should_not_override_existing_error(self): + def f1(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f1") + return x + 1 + + def run_export(): + def f2(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f2") + return x + 1 + + f2 = jax.jit(error_check.wrap_for_export(jax.jit(f2))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f2)(x).serialize() + return serialized + + def run_import(serialized): + f2 = jax.export.deserialize(serialized).call + f2 = jax.jit(error_check.unwrap_from_import(jax.jit(f2))) + return f2 + + x = jnp.float32(-3.) + _ = f1(x) # check fails. 
so it should set error + + serialized = run_export() + f2 = run_import(serialized) + _ = f2(x) # check fails, but should not override the error + + with self.assertRaisesRegex( + JaxValueError, "x must be greater than 0 in f1" + ): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/errors_test.py b/tests/errors_test.py index 25f29cfee224..63618e646dfd 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -455,7 +455,7 @@ class FakeTracer(core.Tracer): ErrorClass = getattr(jax.errors, errorclass) err = ErrorClass(FakeTracer(None)) - self.assertIn(f'https://jax.readthedocs.io/en/latest/errors.html#jax.errors.{errorclass}', str(err)) + self.assertIn(f'https://docs.jax.dev/en/latest/errors.html#jax.errors.{errorclass}', str(err)) if __name__ == '__main__': diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 9b457b8f27a5..888b234e94c0 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -31,6 +31,7 @@ from jax._src.internal_test_util import export_back_compat_test_util as bctu +from jax._src.internal_test_util.export_back_compat_test_data import annotate_data_placement from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev @@ -38,7 +39,6 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_eigh_lapack_syev from jax._src.internal_test_util.export_back_compat_test_data import cpu_lu_lapack_getrf from jax._src.internal_test_util.export_back_compat_test_data import cuda_qr_cusolver_geqrf -from jax._src.internal_test_util.export_back_compat_test_data import rocm_qr_hipsolver_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_qr_lapack_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd @@ -120,7 +120,7 @@ def test_custom_call_coverage(self): targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) cpu_ffi_testdatas = [ cpu_cholesky_lapack_potrf.data_2024_05_31, - cpu_qr_lapack_geqrf.data_2024_08_22, + cpu_qr_lapack_geqrf.data_2025_04_02, cpu_eig_lapack_geev.data_2024_08_19, cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, @@ -134,25 +134,17 @@ def test_custom_call_coverage(self): # stable covering_testdatas = [ *cpu_ffi_testdatas, - cpu_cholesky_lapack_potrf.data_2023_06_19, - cpu_eig_lapack_geev.data_2023_06_19, - cpu_eigh_lapack_syev.data_2023_03_17, - cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2024_07_30, - cpu_lu_lapack_getrf.data_2023_06_14, - cuda_lu_pivots_to_permutation.data_2024_08_08, + cuda_lu_pivots_to_permutation.data_2025_04_01, cuda_lu_cusolver_getrf.data_2024_08_19, cuda_qr_cusolver_geqrf.data_2024_09_26, cuda_eigh_cusolver_syev.data_2024_09_30, cuda_svd_cusolver_gesvd.data_2024_10_08, cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09, cuda_tridiagonal_cusolver_sytrd.data_2025_01_09, - rocm_qr_hipsolver_geqrf.data_2024_08_05, rocm_eigh_hipsolver_syev.data_2024_08_05, cpu_schur_lapack_gees.data_2023_07_16, - cpu_svd_lapack_gesdd.data_2023_06_19, cpu_triangular_solve_blas_trsm.data_2023_07_16, - cpu_hessenberg_lapack_gehrd.data_2024_08_30, 
cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03, tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, @@ -163,6 +155,8 @@ def test_custom_call_coverage(self): stablehlo_dynamic_top_k.data_2023_07_16, stablehlo_dynamic_top_k.data_2023_08_11, # with shape_assertion stablehlo_dynamic_approx_top_k.data_2024_05_30, + annotate_data_placement.data_2025_04_07_tpu, + annotate_data_placement.data_2025_04_07_cuda, ] # Some of the above are nested structures. covering_testdatas = itertools.chain( @@ -212,10 +206,6 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -276,10 +266,6 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_eig_results) - data = self.load_testdata(cpu_eig_lapack_geev.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_eig_results, - expect_current_custom_calls=info["custom_call_targets"]) @staticmethod def eigh_input(shape, dtype): @@ -333,12 +319,6 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) - # Legacy custom call test - data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand), - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{variant}", dtype_name=dtype_name, variant=variant) @@ -411,14 +391,14 @@ def lu_pivots_to_permutation_harness(shape): def test_cuda_lu_pivots_to_permutation(self): shape = (2, 3, 4) func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape) - data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08) + data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2025_04_01) self.run_one_test(func, data) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) - def test_cuda_lu_lapack_getrf(self, dtype_name:str): + def test_cuda_lu_cusolver_getrf(self, dtype_name:str): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") dtype = dict(f32=np.float32, f64=np.float64, @@ -445,38 +425,10 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): c64=np.complex64, c128=np.complex128)[dtype_name] func = lambda: CompatTest.qr_harness((3, 3), dtype) - info = cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + info = cpu_qr_lapack_geqrf.data_2025_04_02[dtype_name] data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol) - # TODO(b/369826500): Remove legacy custom call test after mid March 2025. 
- data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) - self.run_one_test(func, data, rtol=rtol, - expect_current_custom_calls=info["custom_call_targets"]) - - # TODO(b/369826500): Remove legacy custom call test after mid March 2025. - @parameterized.named_parameters( - dict(testcase_name=f"_dtype={dtype_name}_{batched}", - dtype_name=dtype_name, batched=batched) - for dtype_name in ("f32",) - # For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched. - for batched in ("batched", "unbatched")) - def test_gpu_qr_solver_geqrf_legacy(self, dtype_name, batched): - if jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched]) - prefix = "hip" - elif jtu.test_device_matches(["cuda"]): - data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched]) - prefix = "cu" - else: - self.skipTest("Unsupported platform") - dtype = dict(f32=np.float32)[dtype_name] - rtol = dict(f32=1e-3)[dtype_name] - shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched] - func = lambda: CompatTest.qr_harness(shape, dtype) - self.run_one_test(func, data, rtol=rtol, expect_current_custom_calls=[ - f"{prefix}solver_geqrf_ffi", f"{prefix}solver_orgqr_ffi"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -551,14 +503,6 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str): check_results=partial(self.check_lu_results, operand, dtype=dtype)) - # TODO(b/357034884): Remove legacy custom call test after mid March 2025. - legacy_data = self.load_testdata( - cpu_lu_lapack_getrf.data_2023_06_14[dtype_name]) - self.run_one_test(func, legacy_data, rtol=rtol, atol=atol, - check_results=partial(self.check_lu_results, operand, - dtype=dtype), - expect_current_custom_calls=info["custom_call_targets"]) - def check_svd_results(self, input, res_run, res_exp, rtol=None, atol=None): # Following linalg_test.testSVD @@ -677,12 +621,6 @@ def func(operand): check_results=partial(self.check_svd_results, *data.inputs)) - data = self.load_testdata(cpu_svd_lapack_gesdd.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_svd_results, - *data.inputs), - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_algorithm={algorithm_name}", dtype_name=dtype_name, algorithm_name=algorithm_name) @@ -773,12 +711,6 @@ def func(): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata( - cpu_hessenberg_lapack_gehrd.data_2024_08_30[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -842,7 +774,7 @@ def func(x): ) self.run_one_test(func, data, rtol=rtol, atol=atol) - def test_approx_top_k(self): + def test_tpu_approx_top_k(self): def func(): x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0]) y = lax.approx_max_k(x, 3) @@ -859,7 +791,7 @@ def func(x): data = self.load_testdata(cuda_threefry2x32.data_2024_07_30) self.run_one_test(func, data) - def test_sharding(self): + def test_tpu_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2: 
self.skipTest("Test runs only on TPU with at least 2 devices") @@ -881,6 +813,31 @@ def func(x): # b: f32[2, 4] with mesh: self.run_one_test(func, data) + @parameterized.named_parameters( + dict(testcase_name=f"_platform={platform}", platform=platform) + for platform in ("tpu", "gpu")) + def test_annotate_device_placement(self, platform): + if not jtu.test_device_matches([platform]): + self.skipTest(f"Test enabled only for {platform}") + + mesh = Mesh(jax.local_devices()[0:1], axis_names=("a")) + + dev_sharding = NS(mesh, P("a")) + host_sharding = NS(mesh, P("a"), memory_kind="pinned_host") + + @partial(jax.jit, + in_shardings=(dev_sharding, host_sharding), + out_shardings=host_sharding) + def func(x, y): + return x + y + + if platform == "tpu": + data = self.load_testdata(annotate_data_placement.data_2025_04_07_tpu) + else: + data = self.load_testdata(annotate_data_placement.data_2025_04_07_cuda) + + self.run_one_test(func, data) + def test_tpu_stablehlo_dynamic_reduce_window_unary(self): # stablehlo.dynamic_reduce_window is used temporarily on TPU for a # reduce window with dynamic shapes. diff --git a/tests/export_test.py b/tests/export_test.py index 2b083f3121f4..e50738ba2480 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -19,7 +19,6 @@ import dataclasses import functools import logging -import json import math import re import unittest @@ -37,6 +36,7 @@ from jax import tree_util from jax._src import config +from jax._src import compute_on from jax._src import core from jax._src import dtypes from jax._src import effects @@ -204,7 +204,14 @@ def test_basic(self): f = jnp.sin x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) + self.assertAllClose(f(x), exp_f.call(x)) + def test_basic_single_device_sharding(self): + device = jax.local_devices()[0] + s = jax.sharding.SingleDeviceSharding(device) + x = np.arange(16, dtype=np.float32).reshape(4, -1) + f = jax.jit(lambda x: x * 2., in_shardings=s, out_shardings=s) + exp_f = get_exported(f)(x) self.assertAllClose(f(x), exp_f.call(x)) def test_jit_static_arg(self): @@ -281,6 +288,18 @@ def test_unused_args(self): self.assertAllClose(f(x, y), exp_f.call(x, y)) + def test_override_lowering_rules(self): + @jax.jit + def f(x): + return jnp.sin(x) + + def my_lowering_rule(ctx, arg, **_): + return mlir.hlo.CosineOp(arg).results + + exp = get_exported(f, _override_lowering_rules=( + (lax.sin_p, my_lowering_rule),))(42.) + self.assertIn("stablehlo.cosine", exp.mlir_module()) + def test_pytree(self): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) @@ -410,6 +429,18 @@ def f(x1, x2): self.assertEqual(tree_util.tree_structure(res2), tree_util.tree_structure(res)) + @jtu.parameterized_filterable( + kwargs=[dict(impl=p) + for p in ("rbg", "unsafe_rbg", "threefry2x32")]) + def test_prng_keys(self, *, impl): + + key = jax.random.key(42, impl=impl) + @jax.jit + def f(key): + return key + exp_f = get_exported(jax.jit(f))(key) + self.assertEqual(f(key), exp_f.call(key)) + def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c @@ -941,7 +972,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . 
" - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -950,7 +981,7 @@ def outer(x): # x: outer_poly_spec "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -960,7 +991,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -969,7 +1000,7 @@ def outer(x): # x: outer_poly_spec "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." )), ]) def test_shape_constraints_errors(self, *, @@ -1695,6 +1726,22 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] res_exp = exp.call(a_device) self.assertArraysAllClose(res_native, res_exp) + def test_compute_on_host(self): + operand = np.float32(0.) + + @jax.jit + @compute_on.compute_on("device_host") + def f_host(x): + # Adds 1 on CPU, which should be the result on all platforms because + # this code should always run on the host. 
+ return jax.lax.platform_dependent(x, + cpu=lambda x: x + np.float32(1.), + default=lambda x: x + np.float32(2.)) + + self.assertAllClose(np.float32(1.), f_host(operand)) + exp = get_exported(f_host, platforms=("cpu", "tpu", "cuda", "rocm"))(operand) + self.assertAllClose(np.float32(1.), exp.call(operand)) + @jtu.parameterized_filterable( kwargs=[ dict(v=v) @@ -1903,8 +1950,8 @@ def f_jax(x): @jtu.parameterized_filterable( kwargs=[ - {"m": 5, "k": 4, "n": 3, "group_sizes": [5]}, - {"m": 10, "k": 9, "n": 8, "group_sizes": [3, 7]}, + {"m": 64, "k": 4, "n": 3, "group_sizes": [5]}, + {"m": 64, "k": 9, "n": 8, "group_sizes": [3, 7]}, ]) def test_ragged_dot(self, m, k, n, group_sizes): def f_jax(x, y, gs): diff --git a/tests/ffi_test.py b/tests/ffi_test.py index 46aaefa8f521..978415194e55 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -200,21 +200,6 @@ def test_ffi_call_batching(self, shape, vmap_method): else: self.assertArraysEqual(a, b) - @jtu.run_on_devices("gpu", "cpu") - def test_vectorized_deprecation(self): - x = self.rng().randn(3, 5, 4).astype(np.float32) - with self.assertWarns(DeprecationWarning): - ffi_call_geqrf(x, vectorized=True) - with self.assertWarns(DeprecationWarning): - jax.vmap(ffi_call_geqrf)(x) - - def test_backward_compat_syntax(self): - def fun(x): - return jax.ffi.ffi_call("test_ffi", x, x, param=0.5) - msg = "Calling ffi_call directly with input arguments is deprecated" - with self.assertDeprecationWarnsOrRaises("jax-ffi-call-args", msg): - jax.jit(fun).lower(jnp.ones(5)) - def test_input_output_aliases(self): def fun(x): return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x) diff --git a/tests/filecheck/custom_call.filecheck.py b/tests/filecheck/custom_call.filecheck.py index c6af4235ebb4..27cc904e59d8 100644 --- a/tests/filecheck/custom_call.filecheck.py +++ b/tests/filecheck/custom_call.filecheck.py @@ -19,7 +19,7 @@ from absl import app import jax -from jax.interpreters import mlir +from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect import numpy as np diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index af0b18b02f37..084ea3b3b0ae 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -618,6 +618,8 @@ def test_sdpa_packed_layout(self): return if cudnn_version < 90600: self.skipTest("Requires >= cuDNN 9.6.0") + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Requires at least Hopper arch") k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) query = jax.random.normal( k1, (4, 512, 4, 64), dtype=jnp.bfloat16) @@ -737,6 +739,26 @@ def generate_segment_mask(segment_ids, dtype): self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-2, atol=1e-2) self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-2, atol=1e-2) + @jtu.run_on_devices("cuda") + def test_sdpa_residual(self): + k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) + query = jax.random.normal( + k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 64), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad = jax.random.normal( + k4, (4, 1024, 4, 64), dtype=jnp.bfloat16) + + jitted_sdpa_inference = jax.jit( + partial( + dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, + dropout_rate=0, return_residual=True), + ) + outs = jitted_sdpa_inference(query, key, value) + assert len(outs) 
== 2 + @jtu.run_on_devices("cuda") def test_layouts(self): if jax.device_count() < 4: diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 060502ae68cd..79d4dc038fc2 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -78,6 +78,8 @@ def f(x): self.assertAllClose(f(x), to_infeed) @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. + @jtu.ignore_warning(category=DeprecationWarning, + message=".*(infeed|outfeed) was deprecated.*") def testInfeedThenOutfeed(self): @jax.jit @@ -99,6 +101,8 @@ def f(x): execution.join() self.assertAllClose(out, y + np.float32(1)) + @jtu.ignore_warning(category=DeprecationWarning, + message=".*(infeed|outfeed) was deprecated.*") def testInfeedThenOutfeedInALoop(self): def doubler(_, token): diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py new file mode 100644 index 000000000000..dba277289f42 --- /dev/null +++ b/tests/jax_numpy_error_test.py @@ -0,0 +1,278 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import error_check +from jax._src import test_util as jtu +from jax._src.numpy import error as jnp_error +import jax.numpy as jnp + +config.parse_flags_with_absl() + + +JaxValueError = error_check.JaxValueError + + +class JaxNumpyErrorTests(jtu.JaxTestCase): + def setUp(self): + # TODO(b/408148001): Fix thread safety issue. 
+ if jtu.TEST_NUM_THREADS.value > 1: + self.skipTest("Test does not work with multiple threads") + super().setUp() + + @parameterized.product(jit=[True, False]) + def test_set_error_if_nan(self, jit): + def f(x): + jnp_error._set_error_if_nan(x) + return x + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), jnp.nan, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(nan="ignore"): + _ = f(x) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(nan="raise"): + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_set_error_if_divide_by_zero(self, jit): + def f(x, y): + jnp_error._set_error_if_divide_by_zero(y) + return x / y + + if jit: + f = jax.jit(f) + + x = jnp.arange(4, dtype=jnp.float32) + 1 + y = jnp.arange(4, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(divide="ignore"): + _ = f(x, y) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(divide="raise"): + _ = f(x, y) + with self.assertRaisesRegex(JaxValueError, "Division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_oob_check(self, jit): + def f(x, start_indices, slice_sizes): + jnp_error._set_error_if_with_category( + jnp.logical_or( + start_indices < 0, + start_indices + jnp.array(slice_sizes, dtype=jnp.int32) + >= jnp.array(x.shape, dtype=jnp.int32), + ), + "Out of bounds in dynamic_slice", + category="oob", + ) + y = jax.lax.dynamic_slice( + x, start_indices, slice_sizes, allow_negative_indices=False + ) + return y + + if jit: + f = jax.jit(f, static_argnums=(2,)) + + x = jnp.arange(12).reshape(3, 4) + start_indices = jnp.array([0, -1], dtype=jnp.int32) + slice_sizes = (3, 4) + + with jnp_error.error_checking_behavior(oob="ignore"): + _ = f(x, start_indices, slice_sizes) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + _ = f(x, start_indices, slice_sizes) + with self.assertRaisesRegex( + JaxValueError, "Out of bounds in dynamic_slice", + ): + error_check.raise_if_error() + + def test_error_category_invalid_category(self): + with self.assertRaisesRegex(ValueError, "Invalid category"): + jnp_error._set_error_if_with_category( + jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" + ) + + @staticmethod + def nan_cases(cases): + for jit in (True, False): + for func, args_error, args_no_err in cases: + if not isinstance(args_error, tuple): + args_error = (args_error,) + if not isinstance(args_no_err, tuple): + args_no_err = (args_no_err,) + + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + name = f"_{jit_str}_{func_str}" + + yield name, jit, func, args_error, args_no_err + + @parameterized.named_parameters( + nan_cases(( + # List of all NaN-producing jax.numpy functions. + # The first group of numbers is the input that will produce a NaN, and + # the second group is the input that will not produce a NaN. 
+ # go/keep-sorted start + (jnp.acos, 2.0, 0.5), + (jnp.acosh, 0.5, 2.0), + (jnp.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (jnp.arccos, 2.0, 0.5), + (jnp.arccosh, 0.5, 2.0), + (jnp.arcsin, -2.0, 0.5), + (jnp.arctanh, -2.0, 0.5), + (jnp.asin, -2.0, 0.5), + (jnp.atanh, -2.0, 0.5), + (jnp.cos, jnp.inf, 1.0), + (jnp.divide, (0.0, 0.0), (1.0, 1.0)), + (jnp.divmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.float_power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.fmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.log, -1.0, 1.0), + (jnp.log10, -1.0, 1.0), + (jnp.log1p, -1.5, 1.0), + (jnp.log2, -1.0, 1.0), + (jnp.mod, (1.0, 0.0), (1.0, 1.0)), + (jnp.pow, (-1.0, 0.5), (1.0, 1.0)), + (jnp.power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.remainder, (1.0, 0.0), (1.0, 1.0)), + (jnp.sin, jnp.inf, 1.0), + # TODO(https://github.com/jax-ml/jax/issues/27470): Not yet supported. + # (jnp.sinc, jnp.inf, 1.0), + (jnp.sqrt, -4.0, 4.0), + (jnp.subtract, (jnp.inf, jnp.inf), (0.0, 0.0)), + (jnp.tan, jnp.inf, 1.0), + (jnp.true_divide, (0.0, 0.0), (1.0, 1.0)), + (operator.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (operator.mod, (1.0, 0.0), (1.0, 1.0)), + (operator.pow, (-1.0, 0.5), (1.0, 1.0)), + (operator.sub, (jnp.inf, jnp.inf), (0.0, 0.0)), + (operator.truediv, (0.0, 0.0), (1.0, 1.0)), + # go/keep-sorted end + )) + ) + def test_can_raise_nan_error(self, jit, f, args_err, args_no_err): + args_err = [jnp.float32(x) for x in args_err] + args_no_err = [jnp.float32(x) for x in args_no_err] + + if jit: + f = jax.jit(f) + + with jnp_error.error_checking_behavior(nan="raise"): + f(*args_no_err) + error_check.raise_if_error() # should not raise error + + f(*args_err) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + INT_TYPES = (jnp.int32, jnp.uint32, jnp.int64, jnp.uint64, jnp.int16, + jnp.uint16, jnp.int8, jnp.uint8) + FLOAT_TYPES = (jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16) + + @staticmethod + def divide_cases(cases): + for jit in (True, False): + for func, dtypes in cases: + for dtype in dtypes: + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + dtype_str = dtype.__name__ + name = f"_{jit_str}_{func_str}_{dtype_str}" + yield name, jit, func, dtype + + @parameterized.named_parameters( + divide_cases(( + # go/keep-sorted start + (jnp.divmod, FLOAT_TYPES + INT_TYPES), + (jnp.floor_divide, INT_TYPES), + (jnp.mod, FLOAT_TYPES + INT_TYPES), + (jnp.remainder, FLOAT_TYPES + INT_TYPES), + (jnp.true_divide, FLOAT_TYPES), + (operator.mod, FLOAT_TYPES + INT_TYPES), + (operator.truediv, FLOAT_TYPES), + # go/keep-sorted end + )) + ) + def test_can_raise_divide_by_zero_error(self, jit, div_func, dtype): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + args_err = (dtype(1), dtype(0)) + args_no_err = (dtype(1), dtype(1)) + + if jit: + div_func = jax.jit(div_func) + + with jnp_error.error_checking_behavior(divide="raise"): + div_func(*args_no_err) + error_check.raise_if_error() # should not raise error + + div_func(*args_err) + with self.assertRaisesRegex(JaxValueError, "Division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_can_raise_oob_error_take(self, jit): + def f(x, a): + return x[a] + + if jit: + f = jax.jit(f) + + x = jnp.arange(10) + a = jnp.int32(10) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with 
self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + + def test_can_raise_oob_error_dynamic_slice(self): + def f(x, a): + return x[:, a:a+4] # dynamic indices are non-jittable + + x = jnp.arange(10).reshape(2, 5) + a = jnp.array(3, dtype=jnp.int32) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jax_to_ir_test.py b/tests/jax_to_ir_test.py index f600a08f5dc4..4eb8190b712f 100644 --- a/tests/jax_to_ir_test.py +++ b/tests/jax_to_ir_test.py @@ -114,15 +114,13 @@ def test_parse_shape_str(self): self.assertParsedShape('f32[]', [], jnp.float32) self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32) self.assertParsedShape('pred[1]', [1], jnp.bool_) - if hasattr(jnp, 'int2'): - self.assertParsedShape('s2[1]', [1], jnp.int2) + self.assertParsedShape('s2[1]', [1], jnp.int2) self.assertParsedShape('s4[1]', [1], jnp.int4) self.assertParsedShape('s8[1]', [1], jnp.int8) self.assertParsedShape('s16[1]', [1], jnp.int16) self.assertParsedShape('s32[1]', [1], jnp.int32) self.assertParsedShape('s64[1]', [1], jnp.int64) - if hasattr(jnp, 'uint2'): - self.assertParsedShape('u2[1]', [1], jnp.uint2) + self.assertParsedShape('u2[1]', [1], jnp.uint2) self.assertParsedShape('u4[1]', [1], jnp.uint4) self.assertParsedShape('u8[1]', [1], jnp.uint8) self.assertParsedShape('u16[1]', [1], jnp.uint16) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index c331bfaf438a..d5574f8a9a1d 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -947,7 +947,7 @@ def make_fun(index): def f(x): def body(y): input_effect(x, y, index=index) - return y + return 2 * y lax.while_loop(lambda _: True, body, y) return f jaxpr = jax.make_jaxpr(make_fun(0))(0) @@ -959,7 +959,7 @@ def body(y): def f(x): def body(y): input_effect(x, y, index=1) - return y + return 2 * y lax.while_loop(lambda _: (x > 0).all(), body, y) jaxpr = jax.make_jaxpr(f)(0) self.assertIn(InputEffect(0), jaxpr.effects) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index a69f44f37754..aea9d2ad3dff 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -205,14 +205,16 @@ class LaxAutodiffTest(jtu.JaxTestCase): )) def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): rng = rng_factory(self.rng()) - if jtu.test_device_matches(["cpu"]): + if jtu.test_device_matches(["cpu", "tpu"]): if op is lax.cosh and dtype == np.complex64: - tol = 3e-1 # 2nd-order gradients are noisy on CPU + tol = 3e-1 # 2nd-order gradients are noisy on CPU and TPU if jtu.test_device_matches(["tpu"]): if op is lax.pow: raise SkipTest("pow grad imprecise on tpu") if op is lax.cos: order = 1 # 2nd-order gradient is imprecise on TPU. + if op is lax.sin: + order = 1 # 2nd-order gradient is imprecise on TPUv5p. if op is lax.log: order = 1 # 2nd-order gradient is imprecise on TPU. 
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3871a87a7a3e..242b0548023e 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -33,7 +33,7 @@ from jax import random from jax._src import test_util as jtu from jax import tree_util -from jax._src.util import unzip2 +from jax._src.util import unzip2, split_list from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp @@ -588,7 +588,6 @@ def test_fori_loop_returns_init_with_nonpositive_length( init = jnp.float32(10) self.assertEqual(fori_loop_with_static_upper_and_lower(init), init) - def testForiLoopBatched(self): def body_fun(i, loop_carry): x, y = loop_carry @@ -994,16 +993,24 @@ def testCondTypeErrors(self): with self.assertRaisesRegex(TypeError, re.escape("Pred must be a scalar, got (1.0, 1.0) of type ")): lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.) - with self.assertRaisesRegex(TypeError, - re.compile("true_fun output must have same type structure " - "as false_fun output, but there are differences:.*" - r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)): + + with self.assertRaisesRegex( + TypeError, + re.compile( + r"cond branch outputs must have the same pytree structure, but they" + r" differ:.*true_fun output at path \['a'\] is a pytree leaf but" + r" false_fun output at path \['a'\] is a ", + re.DOTALL)): lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.) + with self.assertRaisesRegex( TypeError, - "true_fun output and false_fun output must have identical types, got\n" - r"DIFFERENT ShapedArray\(float32\[1\]\) vs. " - r"ShapedArray\(float32\[\].*\)."): + re.compile( + r"cond branches must have equal output types but they differ.*The" + r" output of true_fun has type float32\[1\] but the corresponding" + r" output of false_fun has type float32\[\], so the shapes do not" + r" match", + re.DOTALL)): lax.cond(True, lambda top: jnp.array([1.], jnp.float32), lambda fop: jnp.float32(1.), @@ -1023,16 +1030,26 @@ def testSwitchErrors(self): with self.assertRaisesRegex(ValueError, re.escape("Empty branch sequence")): lax.switch(0, [], 1.) - with self.assertRaisesRegex(TypeError, - re.compile("branch 0 output must have same type structure " - "as branch 1 output, but there are differences:.*" - r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)): + + with self.assertRaisesRegex( + TypeError, + re.compile( + "switch branch outputs must have the same pytree structure, but" + r" they differ.*branch 0 output at path \['a'\] is a pytree leaf" + r" but branch1 output at path \['a'\] is a , so" + r" their" + " Python types differ.", + re.DOTALL)): lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.) + with self.assertRaisesRegex( TypeError, - "branch 0 output and branch 1 output must have identical types, got\n" - r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) " - r"vs. ShapedArray\(float32\[\].*\)'}."): + re.compile( + "switch branches must have equal output types but they differ.*The" + r" output of branch 0 at path \['a'\] has type float32\[1\] but the" + r" corresponding output of branch1 has type float32\[\], so the" + " shapes do not match", + re.DOTALL)): lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)), lambda _: dict(a=jnp.float32(1.))], 1.) 
@@ -1309,6 +1326,34 @@ def f(x): self.assertAllClose(ans, expected, check_dtypes=False) jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"]) + @parameterized.parameters(itertools.product(range(4), repeat=3)) + @jtu.run_on_devices("cpu") + def testSwitchGradWithForwarding(self, seed, num_input_fwd, num_output_fwd): + num_args = 3 + num_branches = 4 + rng = np.random.RandomState(seed) + in_perm = rng.permutation(num_args) + out_perm = rng.permutation(num_args) + + def branch(s, inputs): + inputs = [inputs[i] for i in in_perm] + outputs = inputs[:num_input_fwd] + [ + s * jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i]) + for i in range(num_args - num_input_fwd)] + return [outputs[i] for i in out_perm] + + branches = [partial(branch, i) for i in range(num_branches)] + + @jax.jit + def f_(idx, inputs): + idx = lax.convert_element_type(idx // 1, np.int32) + return lax.switch(idx, branches, inputs) + + for idx in range(num_branches): + f = partial(f_, idx) + jtu.check_grads(f, (jnp.arange(float(num_args)),), + order=1, modes=['fwd', 'rev'], atol=1e-2, rtol=1e-2) + def testSwitchGradWithWeakTypeMismatch(self): # issue #4696, PR #4896 dtype = dtypes.canonicalize_dtype(np.float64) dtype = jnp.float32 if dtype == jnp.float32 else jnp.float64 @@ -1955,7 +2000,7 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, re.escape("function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "types, but they differ:\n\n" "The input carry x has type int32[] but the corresponding " "output carry component has type float32[], so the dtypes do " "not match" @@ -1966,7 +2011,7 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, re.escape("function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "types, but they differ:\n\n" "The input carry component x[1] has type int32[] but the " "corresponding output carry component has type float32[], " "so the dtypes do not match" @@ -1977,13 +2022,13 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, re.escape("function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "types, but they differ:\n\n" " * the input carry component x[0] has type int32[] but the " "corresponding output carry component has type float32[], " "so the dtypes do not match;\n" " * the input carry component x[1] has type int32[] but the " "corresponding output carry component has type float32[1,1], " - "so the dtypes do not match and also the shapes do not match." + "so the dtypes do not match, and the shapes do not match." 
)): jax.lax.scan(lambda x, _: ((x[0].astype('float32'), x[1].astype('float32').reshape(1, 1), @@ -2317,7 +2362,7 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"): elif loop == "fori_inside_cond": func = lambda x: lax.cond( True, - x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x), + x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), 1., lambda x: x) elif loop == "fori_inside_scan": func = lambda x: lax.scan( @@ -2467,7 +2512,7 @@ def f(c, a): self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo)) # and the lowering should contain a while loop, unless the scan is fully - # unrolled + # unrolled self.assertIn("while(", scan_hlo) self.assertIn("while(", scan_unrolled_hlo) self.assertNotIn("while(", scan_fully_unrolled_hlo) @@ -2758,7 +2803,6 @@ def cond_fun(val): self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False) - def test_while_loop_fixed_point_with_batched_pred_and_consts(self): def f(i, x): def cond(carry): @@ -3048,7 +3092,7 @@ def test_cond_memory_leak(self): def leak(): data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1) def g(): - return jax.lax.cond( + return jax.lax.cond( True, lambda: data[0], # noqa: F821 lambda: data[1], # noqa: F821 @@ -3066,6 +3110,100 @@ def g(): leak() self.assertEqual(base, nbufs()) + def test_grad_remat_while_fixpoint(self): + @jax.remat + def f(x, y): + def cond(_): + return False + def body(c): + x, y = c + return (y, x) + x, y = jax.lax.while_loop(cond, body, (x, y)) + return x + y + jax.linearize(f, 1., 2.) # don't crash + + def test_readonly_carry_optimization(self): + # https://github.com/google/flax/issues/4700 + def foo(w, x, c_max): + def while_cond(val): + c, x, w = val + return c < c_max + + def while_body(val): + c, x, w = val + return c + 1, x @ w, w + + _, x, w = jax.lax.while_loop(while_cond, while_body, (0, x, w)) + return w, x + + w = jnp.ones((2, 2)) + xs = jnp.ones((4, 2)) + c_maxs = jnp.arange(4) + w_, _ = jax.vmap(foo, in_axes=(None, 0, 0), out_axes=(None, 0) + )(w, xs, c_maxs) # doesn't crash + self.assertAllClose(w, w_, check_dtypes=False) + + @parameterized.parameters(itertools.product(range(3), repeat=5)) + @jtu.run_on_devices("cpu") + def test_while_constification_correctness( + self, + seed, + num_body_consts, + num_inplace_fwds_cond_uses, + num_inplace_fwds_cond_doesnt_use, + num_noninplace_fwds): + + num_fwds = (num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use + + num_noninplace_fwds) + num_carry = num_fwds + 4 + + rng = np.random.RandomState(seed) + perm = rng.permutation(num_carry) + iperm = np.argsort(perm) + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def cond_fun(c): + i, c = c + c = [c[i] for i in iperm] + c, _ = split_list(c, [num_inplace_fwds_cond_uses]) + return (i < 2) + (0. 
* jnp.array(sum(c))).astype(bool) + + def body_fun(c): + i, c = c + c = [c[i] for i in iperm] + inplace_fwds, noninplace_fwds, dont_fwd = split_list( + c, [num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use, + num_noninplace_fwds]) + dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in dont_fwd] + new_c_perm = [*inplace_fwds, *dont_fwd, *noninplace_fwds] + new_c = [new_c_perm[i] for i in perm] + return (i + 1, new_c) + + i, outs = jax.lax.while_loop(cond_fun, body_fun, (0, init_vals)) + self.assertEqual(i, 2) + _, outs_ref = body_fun(body_fun((0, init_vals))) + self.assertAllClose(outs, outs_ref, check_dtypes=False) + + def test_while_constification_correctness_manually(self): + # regression test for a particular index-offset logic bug + + def cond_fun(c): + # cond doesn't use first or third element of the carry + _, i, _ = c + return i == 0 + + def body_fun(c): + # two body consts + for _ in range(2): jnp.sin(np.zeros(3)) + # first element of the carry is forwarded to third element of the carry + return 0., 1., c[0] + + outs = jax.lax.while_loop(cond_fun, body_fun, (5., 0., 3.14)) + self.assertAllClose(outs, (0., 1., 5.)) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 63a725ad3643..ca9ba9c88806 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -926,12 +926,20 @@ def testSimpleIndexingUsesSlice(self): self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) - # Indexing with `Ellipsis` is not lowered to `gather`. + # Indexing with `Ellipsis` is not lowered to `gather` ... jaxpr = jax.make_jaxpr(lambda x: x[..., 0])(jnp.ones((3, 4, 5))) self.assertLen((jaxpr.jaxpr.eqns), 2) self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + # ... even when the ellipsis expands to no dimensions. 
+ jaxpr = jax.make_jaxpr(lambda x: x[..., 0:1])(jnp.ones((3,))) + self.assertLen((jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) + jaxpr = jax.make_jaxpr(lambda x: x[0:1, ...])(jnp.ones((3,))) + self.assertLen((jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) + # Simple reverses lower to lax.rev_p jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4))) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 0c3f1d1471fb..aa5e08e96a3e 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -905,8 +905,8 @@ def testCumulativeSumBool(self): @jtu.ignore_warning(category=NumpyComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial): - if jtu.is_device_tpu(6): - raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6+") rng = jtu.rand_some_zero(self.rng()) # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 98f10d9c02b3..de943c3b613a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -50,7 +50,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace +from jax._src.util import safe_zip, NumpyComplexWarning, tuple_update config.parse_flags_with_absl() @@ -3496,11 +3496,6 @@ def testReshape(self, arg_shape, out_shape, dtype, order): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def testReshapeDeprecatedArgs(self): - msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36." - with self.assertRaisesRegex(TypeError, msg): - jnp.reshape(jnp.arange(4), newshape=(2, 2)) - @jtu.sample_product( [dict(arg_shape=arg_shape, out_shape=out_shape) for arg_shape, out_shape in [ @@ -3800,9 +3795,10 @@ def testArrayFromList(self): with self.assertRaisesRegex(OverflowError, "Python int too large.*"): jnp.array([0, val]) - def testArrayNoneWarning(self): - # TODO(jakevdp): make this an error after the deprecation period. 
- with self.assertWarnsRegex(FutureWarning, r"None encountered in jnp.array\(\)"): + def testArrayNone(self): + with self.assertRaisesRegex( + ValueError, 'None is not a valid value for jnp.array' + ): jnp.array([0.0, None]) def testIssue121(self): @@ -6042,7 +6038,10 @@ def np_fun(a, i, v): dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis) for a_shape in nonempty_array_shapes for axis in list(range(-len(a_shape), len(a_shape))) - for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)] + for i_shape in [ + tuple_update(a_shape, axis if axis >= 0 else axis + len(a_shape), J) + for J in range(a_shape[axis] + 1) + ] for v_shape in [(), (1,), i_shape] ] + [ dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index f4e4e4f48213..4b3945a84453 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -288,35 +288,71 @@ def testExpiDisableJit(self): self.assertAllClose(result_jit, result_nojit) def testGammaIncBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jax.dtypes.canonicalize_dtype(float) nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammainc, osp_special.gammainc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol) def testGammaIncCBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. 
+ dtype = jax.dtypes.canonicalize_dtype(float) nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammaincc, osp_special.gammaincc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol) + def testBetaIncBoundaryValues(self): + dtype = jax.dtypes.canonicalize_dtype(float) + fi = jax.numpy.finfo(dtype) + nan = float('nan') + inf = float('inf') + tiny = fi.tiny + eps = fi.eps + if jtu.parse_version(scipy.__version__) >= (1, 16): + # TODO(pearu): enable tiny samples when a fix to scipy/scipy#22682 + # will be available + a_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + b_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + elif jtu.parse_version(scipy.__version__) >= (1, 12): + # disabled samples that contradict with scipy/scipy#22425 + a_samples = [nan, -0.5, 0.5] + b_samples = [nan, -0.5, 0.5] + else: + a_samples = [-0.5, 0.5] + b_samples = [-0.5, 0.5] + x_samples = [nan, -0.5, 0, 0.5, 1, 1.5] + + a_samples = np.array(a_samples, dtype=dtype) + b_samples = np.array(b_samples, dtype=dtype) + x_samples = np.array(x_samples, dtype=dtype) + + args_maker = lambda: np.meshgrid(a_samples, b_samples, x_samples) + + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5 + self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 388d053d9608..bc80ed4e1cc2 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -339,8 +339,8 @@ def scipy_fun(z): ) @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*") def testLpmn(self, l_max, shape, dtype): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -461,8 +461,8 @@ def testSphHarmOrderOneDegreeOne(self): @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): """Tests against JIT compatibility and Numpy.""" - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) @@ -508,8 +508,8 @@ def testSphHarmCornerCaseWithWrongNmax(self): ) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmY(self, l_max, num_z, dtype): - if 
jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8764caeb2e49..6792b08c37fa 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -29,6 +29,7 @@ import jax from jax._src import core +from jax import export from jax import jvp, grad from jax import lax import jax.numpy as jnp @@ -49,7 +50,6 @@ from jax._src.lax import lax as lax_internal from jax._src.util import NumpyComplexWarning, safe_zip from jax._src.tree_util import tree_map -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -1128,11 +1128,6 @@ def testDotAlgorithm(self, algorithm, dtype): raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on CPU.") if jtu.test_device_matches(["gpu"]): - if (algorithm == lax.DotAlgorithmPreset.BF16_BF16_F32_X9 and - xla_extension_version < 320): - raise SkipTest( - f"The dot algorithm ${algorithm} requires XLA extension version " - ">= 320.") # GPU algorithm support is a little spotty. It is checked in # xla/service/algorithm_util.cc and the logic is copied here. if algorithm in { @@ -3627,6 +3622,37 @@ def f(x): g = jax.grad(f)(5.) # doesn't crash self.assertAllClose(g, 3., check_dtypes=False) + def test_shape_as_value_handles_static_shapes(self): + result = lax.shape_as_value(()) + self.assertArraysEqual(result, lax.full((0,), np.array(0, np.int64))) + + result = lax.shape_as_value((2,)) + self.assertArraysEqual(result, np.asarray((2,), np.int64)) + + result = lax.shape_as_value((2, 3)) + self.assertArraysEqual(result, np.asarray((2, 3), np.int64)) + + def test_shape_as_value_handles_polymorphic_shapes(self): + @jax.jit + def f(x): + return lax.shape_as_value(x.shape) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a"), jnp.float32) + ) + result = exported.call(np.ones((1), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1,), np.int64)) + result = exported.call(np.ones((2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((2,), np.int64)) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), jnp.float32) + ) + result = exported.call(np.ones((1, 2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1, 2), np.int64)) + result = exported.call(np.ones((3, 4), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((3, 4), np.int64)) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): @@ -4747,7 +4773,7 @@ def my_square(x): ValueError, "JVP rule for composite not implemented. You can use `jax.custom_jvp` " "to add support. See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ): jvp(my_square, (1.0,), (2.0,)) @@ -4760,7 +4786,7 @@ def my_square(x): ValueError, "JVP rule for composite not implemented. You can use `jax.custom_jvp` " "to add support. 
See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ): grad(my_square)(1.0) @@ -4802,10 +4828,10 @@ class RaggedTest(jtu.JaxTestCase): @jtu.sample_product( [ - {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1}, - {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2}, + {'m': 64, 'k': 4, 'n': 3, 'num_groups': 1}, + {'m': 64, 'k': 9, 'n': 8, 'num_groups': 2}, ], - dtype=jtu.dtypes.numeric, + dtype=jtu.dtypes.all_floating, ) def test_ragged_dot(self, m, k, n, num_groups, dtype): """Tests ragged_dot. @@ -4816,6 +4842,8 @@ def test_ragged_dot(self, m, k, n, num_groups, dtype): Raises: SkipTest: in the case dtype is not supported. """ + if (dtype == np.float16): + raise SkipTest(f"unsupported dtype for ragged_dot: {dtype}") lhs_shape = (m, k) rhs_shape = (num_groups, k, n) @@ -4837,6 +4865,25 @@ def group_sizes(m, num_groups): self._CheckAgainstNumpy( lax_reference.ragged_dot, lax.ragged_dot, args_maker) + @parameterized.parameters( + { "m": 5, "k": 4, "n": 3, "num_groups": 1}, + { "m": 10, "k": 9, "n": 8, "num_groups": 2}, + ) + def test_ragged_dot_unsupported( + self, m, k, n, num_groups): + lhs_shape = (m, k) + rhs_shape = (num_groups, k, n) + group_sizes_shape = (num_groups,) + + args_maker = lambda: [ + jnp.ones(lhs_shape, dtype=jnp.float32), + jnp.ones(rhs_shape, dtype=jnp.float32), + jnp.ones(group_sizes_shape, dtype=jnp.int32), + ] + if jtu.test_device_matches(["tpu"]): + with self.assertRaises(jax.errors.JaxRuntimeError): + self._CompileAndCheck(lax.ragged_dot, args_maker) + @parameterized.parameters( { "lhs_shape": lhs_shape, @@ -5055,10 +5102,69 @@ def test_ragged_dot_general_shape_inference_success( lhs = jnp.ones(lhs_shape, dtype=jnp.float32) rhs = jnp.ones(rhs_shape, dtype=jnp.float32) group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) - self.assertEqual( - lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape, - out_shape, + if jtu.test_device_matches(["tpu"]): + actual_shape = lax_internal._ragged_dot_general_shape_rule( + lhs, rhs, group_sizes, ragged_dot_dimension_numbers=ragged_dnums, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=jnp.float32, + ) + else: + actual_shape = lax.ragged_dot_general( + lhs, rhs, group_sizes, ragged_dnums + ).shape + self.assertEqual(actual_shape, out_shape) + + @parameterized.product( + batch_size=[3, 5], + m=[128, 1024], + k=[128, 1024], + n=[128, 1024], + num_groups=[2, 4], + ) + def test_ragged_dot_general_vmap( + self, batch_size: int, m: int, k: int, n: int, num_groups: int + ): + if (jtu.test_device_matches(["tpu"])): + raise SkipTest("batched ragged_dot not yet supported on TPU") + + lhs_shape = (batch_size, m, k) + rhs_shape = (batch_size, num_groups, k, n) + dtype = jnp.float32 + + def make_group_sizes(m, num_groups): + ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1)) + ends = jnp.concatenate( + [ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)]) + starts = jnp.concatenate( + [jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final]) + return ends - starts + + rng = jtu.rand_small(self.rng()) + args_maker = lambda: [ + rng(lhs_shape, dtype), + rng(rhs_shape, dtype), + jnp.array([make_group_sizes(m, num_groups) for _ in range(batch_size)]), + ] + lhs, rhs, group_sizes = args_maker() + + out_dtype = jnp.float32 + precision = jax.lax.Precision.HIGHEST + ragged_dot = partial( + jax.lax.ragged_dot, + preferred_element_type=out_dtype, + precision=precision, ) + tol = 1e-5 + + batch_res = 
jax.vmap(ragged_dot)(lhs, rhs, group_sizes) + for i in range(batch_size): + # The ragged_dot does not zero out the output in the case sum(group_sizes) + # < m, hence we need to compare only the valid part of the output. + upper_bound = group_sizes[i].sum(axis=0) + ref_res = ragged_dot(lhs[i], rhs[i], group_sizes[i])[0:upper_bound, :] + self.assertArraysAllClose( + batch_res[i, 0:upper_bound, :], ref_res, rtol=tol, atol=tol + ) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/layout_test.py b/tests/layout_test.py index b9062b8d21dc..ae10013a5f60 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -21,9 +21,10 @@ import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding from jax._src import config -from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip +from jax.experimental.layout import (with_dll_constraint, Layout, + DeviceLocalLayout as DLL) from jax.experimental.compute_on import compute_on config.parse_flags_with_absl() @@ -744,6 +745,36 @@ def f(x): self.assertArraysEqual(out, np_inp * 2) self.assertEqual(out.layout, out_layout) + def test_with_dll_constraint(self): + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (16, 128) + s = NamedSharding(mesh, P('x')) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=arr.layout.dll.major_to_minor[::-1]) + + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + y = with_dll_constraint(y, custom_dll) + return y * 2 + + f(arr) # doesn't crash + + f = jax.jit(f) + out = f(arr) + self.assertEqual(out.layout.device_local_layout.major_to_minor, + custom_dll.major_to_minor) + self.assertArraysEqual(out, np_inp.T * 2) + + lowered_text = f.lower(arr).as_text() + self.assertIn('LayoutConstraint', lowered_text) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index feab105ccbe2..1670f1ee4abd 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -96,7 +96,7 @@ def args_maker(): a = rng(factor_shape, dtype) return [np.matmul(a, jnp.conj(T(a)))] - jnp_fun = partial(jnp.linalg.cholesky, upper=upper) + jnp_fun = partial(jnp.linalg.cholesky, upper=upper, symmetrize_input=True) def np_fun(x, upper=upper): # Upper argument added in NumPy 2.0.0 @@ -867,9 +867,6 @@ def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorith self.skipTest("Hermitian SVD doesn't support the algorithm parameter.") if not jtu.test_device_matches(["cpu", "gpu"]): self.skipTest("SVD algorithm selection only supported on CPU and GPU.") - # TODO(danfm): Remove this check after 0.5.2 is released. 
- if jtu.test_device_matches(["cpu"]) and jtu.jaxlib_version() <= (0, 5, 1): - self.skipTest("SVD algorithm selection on CPU requires a newer jaxlib version.") if jtu.test_device_matches(["cpu"]) and algorithm == lax.linalg.SvdAlgorithm.JACOBI: self.skipTest("Jacobi SVD not supported on GPU.") @@ -2332,6 +2329,22 @@ def testSymmetricProduct(self, shape, dtype, symmetrize_output): self.assertAllClose( new_product_with_batching, old_product, atol=atol) + @jtu.sample_product( + n=[0, 1, 5, 10, 20], + kind=["symmetric", "lower", "upper"], + ) + @jax.default_matmul_precision("float32") + def testPascal(self, n, kind): + args_maker = lambda: [] + osp_fun = partial(osp.linalg.pascal, n=n, kind=kind, exact=False) + jsp_fun = partial(jsp.linalg.pascal, n=n, kind=kind) + self._CheckAgainstNumpy(osp_fun, + jsp_fun, args_maker, + atol=1e-3, + rtol=1e-2 if jtu.test_device_matches(['tpu']) else 1e-3, + check_dtypes=False) + self._CompileAndCheck(jsp_fun, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py index fc2b0df849d1..76d6006432f4 100644 --- a/tests/lobpcg_test.py +++ b/tests/lobpcg_test.py @@ -272,7 +272,7 @@ def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype): self._possibly_plot(A, eigs, X, m, matrix_name) def _possibly_plot(self, A, eigs, X, m, matrix_name): - if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'): + if os.getenv('LOBPCG_EMIT_DEBUG_PLOTS', '0') != '1': return if isinstance(A, (np.ndarray, jax.Array)): diff --git a/tests/memories_test.py b/tests/memories_test.py index 0ca973c4d221..203bc4abb613 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -756,9 +756,6 @@ def init(): def test_compute_no_inputs_host_replicated(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: self.skipTest("This test requires an xla_version >= 3.") - if config.use_shardy_partitioner.value: - self.skipTest("XLA failure due to b/370786664 and b/366411266. " - "Enable when fixed.") mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) @@ -794,6 +791,36 @@ def f(x): lowered_text = f.lower(jnp.arange(8)).as_text() self.assertIn('_xla_compute_type', lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) + def h(x): + y = g(x) + return y * 3 + + out2 = h(inp) + self.assertArraysEqual(out2, inp * 6) + self.assertEqual(out2.sharding.memory_kind, "pinned_host") + + def test_compute_on_2d(self): + out_s = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host") + + @compute_on("device_host") + @jax.jit + def g(x): + return x * 2 + + @jax.jit + def f(x): + y = g(x) + return y * 3 + + inp = jnp.arange(9943.0) + inp = jnp.reshape(inp, (61, 163)) + out = f(inp) + self.assertArraysEqual(out, inp * 6) + + lowered_text = f.lower(inp).as_text() + self.assertIn("_xla_compute_type", lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) def h(x): y = g(x) @@ -1474,8 +1501,8 @@ def test_mem_kind_donation_pinned_host(self): s = NamedSharding(mesh, P(), memory_kind='pinned_host') s_dev = s.with_memory_kind('device') - @compute_on('device_host') @functools.partial(jax.jit, out_shardings=(s, s_dev), donate_argnums=(0, 1)) + @compute_on('device_host') def f(inp1, inp2): return inp1 * 2, inp2 * 2 @@ -1638,6 +1665,20 @@ def f(x): # 2 for `f` and `2` for `mul` (compute type changes for `mul`) self.assertEqual(count(), 4) + def test_compute_on_aot(self): + operand = np.float32(0.) 
+ + @jax.jit + @compute_on("device_host") + def f_host(x): + # Adds 1 on CPU and adds 2 on other platforms + return jax.lax.platform_dependent(x, + cpu=lambda x: x + 1., + default=lambda x: x + 2.) + + self.assertAllClose(1., f_host(operand)) + self.assertAllClose(1., f_host.lower(operand).compile()(operand)) + def test_offload_take_host(self): # TODO(apaszke): Remove after 12 weeks have passed. if not jtu.if_cloud_tpu_at_least(2024, 12, 19): @@ -1664,30 +1705,43 @@ class StreamAnnotationTest(jtu.JaxTestCase): def test_stream_annotation_inside_shmap(self): if not jtu.test_device_matches(["gpu"]): self.skipTest("Stream annotation is only supported on GPU.") + mesh = jtu.create_mesh((2,), ('x',)) s = NamedSharding(mesh, P('x')) np_inp = np.ones((8,)) arr1 = jax.device_put(np_inp, s) arr2 = jax.device_put(np_inp, s) + # Makes sure the compute wrapped here is fusible. + # This is a workaround for limitations in XLA: + # 1) Compute-on boxes that contain only a single instruction cannot work. + # 2) Compute-on boxes that contain a tiny matmul cannot work. @compute_on('gpu_stream:1') @jax.jit def g(x, y): - return x * y + return x * y + x @compute_on('gpu_stream:2') @jax.jit def h(x, y): - return x * y + return x * y + x def f(x, y): z = g(x, y) w = h(3 * x, 2 * y) return z + w - out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), - out_specs=P('x')))(arr1, arr2) - self.assertArraysEqual(out, arr1 * 7) + compiled_f = jax.jit( + shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), + out_specs=P('x'))).lower(arr1, arr2).compile( + {"xla_gpu_experimental_stream_annotation": True} + ) + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('call-start.1', compiled_f.as_text()) + self.assertIn('_xla_stream_annotation="2"', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 11) class ActivationOffloadingTest(jtu.JaxTestCase): diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 136b507942e7..28efb266b281 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/monitoring_test.py b/tests/monitoring_test.py index 52b53895c2cc..89c7148a2a42 100644 --- a/tests/monitoring_test.py +++ b/tests/monitoring_test.py @@ -29,7 +29,7 @@ def tearDown(self): def test_record_event(self): events = [] - counters = {} # Map event names to frequency. + counters = {} # Map event names to frequency. def increment_event_counter(event): if event not in counters: counters[event] = 0 @@ -48,7 +48,7 @@ def increment_event_counter(event): "test_common_event": 2}) def test_record_event_durations(self): - durations = {} # Map event names to frequency. + durations = {} # Map event names to frequency. def increment_event_duration(event, duration): if event not in durations: durations[event] = 0.
@@ -62,6 +62,30 @@ def increment_event_duration(event, duration): self.assertDictEqual(durations, {"test_short_event": 3, "test_long_event": 10}) + def test_record_scalar(self): + observed_keys = [] + observed_values = [] + + monitoring.register_scalar_listener( + lambda key, _: observed_keys.append(key), + ) + monitoring.register_scalar_listener( + lambda _, value: observed_values.append(value), + ) + + monitoring.record_scalar("test_unique_event", 1) + monitoring.record_scalar("test_common_event", 2.5) + monitoring.record_scalar("test_common_event", 5e5) + + self.assertListEqual( + observed_keys, + ["test_unique_event", "test_common_event", "test_common_event"], + ) + self.assertListEqual( + observed_values, + [1, 2.5, 5e5], + ) + def test_unregister_exist_callback_success(self): original_duration_listeners = jax_src_monitoring.get_event_duration_listeners() callback = lambda event, durations: None diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index ba9d23fa5b4f..2d75c42424ef 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -34,6 +34,7 @@ from jax.experimental.mosaic import gpu as mgpu from jax.experimental.mosaic.gpu import layouts from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import dialect_lowering as lowering _cext = mgpu.dialect._cext if mgpu.dialect is not None else None @@ -592,6 +593,18 @@ def test_wgmma_b_n_dim_not_equal_to_acc_n_dim(self): ): self.module.operation.verify() + def test_tiled_layout_attr_parsing(self): + with ir.InsertionPoint(self.module.body): + for layout in ( + mgpu.WGMMA_LAYOUT, + mgpu.WGMMA_ROW_LAYOUT, + mgpu.WGMMA_COL_LAYOUT, + mgpu.WGMMA_TRANSPOSED_LAYOUT, + ): + attr = layouts.to_tiled_layout_attr(layout) + parsed_layout = layouts.from_tiled_layout_attr(attr) + self.assertEqual(layout, parsed_layout) + class DialectLoweringTest(MosaicGpuTest): @@ -862,6 +875,71 @@ def test_lower_conversion_op_lowers_to_same_op(self, op, in_dtype, out_dtype): self.assertLen(conversion_ops, 1) self.assertEqual(conversion_ops[0].result.type, scalar_out_ty) + @parameterized.parameters( + (True, False, False), + (False, True, False), + (False, False, True), + ) + def test_custom_primitive_op_must_have_number_of_annotations_matching_operands_and_results( + self, omit_in_layouts, omit_in_transforms, omit_out_layouts + ): + vec_ty = ir.VectorType.get((4, 32), ir.BF16Type.get()) + out_layouts = [ + layouts.to_layout_attr( + mgpu.WGStridedFragLayout.from_shaped_type(vec_ty) + ) + ] + in_layouts = out_layouts * 2 + in_transforms = [ + ir.ArrayAttr.get([mgpu.dialect.SwizzleTransformAttr.get(128)]) + ] + + in_layouts = [] if omit_in_layouts else in_layouts + in_transforms = [] if omit_in_transforms else in_transforms + out_layouts = [] if omit_out_layouts else out_layouts + + def body(vec1, vec2, ref): + mgpu.dialect.custom_primitive( + [vec_ty], [vec1, vec2, ref], in_layouts, in_transforms, out_layouts + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + ref_ty = ir.MemRefType.get((4, 32), ir.BF16Type.get(), memory_space=smem) + func.FuncOp.from_py_func(vec_ty, vec_ty, ref_ty)(body) + + if omit_in_layouts: + error = "layout for each vector operand" + elif omit_in_transforms: + error = "transforms for each memref operand in smem" + else: + assert omit_out_layouts + error = "layout for each result" + + with self.assertRaisesRegex(ir.MLIRError, error): + self.module.operation.verify() + + def 
test_memref_transforms_with_transpose(self): + with ir.InsertionPoint(self.module.body): + ty_in = ir.MemRefType.get( + (64, 128), + ir.BF16Type.get(), + memory_space=ir.Attribute.parse("#gpu.address_space"), + ) + ref = memref.alloc(ty_in, [], []) + + ref = mgpu_utils.memref_transpose(ref, (1, 0)) + # This tiling is applied to the transposed memref. + transforms = [mgpu.TileTransform(tiling=(16, 32))] + + ref_transformed = lowering.reinterpret_smem_ref(ref, transforms) + ty_transformed = ir.MemRefType(ref_transformed.type) + self.assertEqual(ty_transformed.shape, [8, 2, 16, 32]) + strides, _ = ty_transformed.get_strides_and_offset() + self.assertEqual(strides, [512, 4096, 1, 16]) + + + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 36c8ff9cf47e..104b088bbdd2 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -19,6 +19,7 @@ from absl.testing import parameterized import jax from jax._src import config +from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir @@ -212,6 +213,32 @@ def body(lhs, rhs): ) self.assertSequenceEqual(add.attributes["out_layouts"], [layout_attr]) + def test_infer_layout_cast_layout(self): + # TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. + if jaxlib.version < (0, 5, 4): + self.skipTest("Test requires jaxlib version >= 0.5.4") + add = cast = None + + shape = (128, 64) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) + wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) + + def body(x): + nonlocal add, cast + add = arith.AddFOp(x, x) + cast = mgpu.dialect.LayoutCastOp(add.result, wgmma_layout) + + with ir.InsertionPoint(self.module.body): + elt_type = ir.BF16Type.get() + ty = ir.VectorType.get(shape, elt_type) + func_op = func.FuncOp.from_py_func(ty)(body).func_op + + func_op.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout]) + mgpu.infer_layout(self.module) + self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout]) + self.assertSequenceEqual(cast.attributes["in_layouts"], [wgmma_layout]) + self.assertSequenceEqual(cast.attributes["out_layouts"], [wgmma_layout]) + def test_infer_layout_traverses_ops_correctly(self): shape = (16, 8) elt_type = ir.BF16Type.get() @@ -442,6 +469,57 @@ def body(lhs, rhs): self.assertNotIn("in_layouts", f.attributes) self.assertNotIn("out_layouts", f.attributes) + def test_optimization_barrier_op_propagates_user_layouts(self): + add = optimization_barrier = None + + def body(lhs, rhs): + nonlocal add, optimization_barrier + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([lhs, rhs]) + lhs, rhs = optimization_barrier.results + add = arith.AddFOp(lhs, rhs) + + with ir.InsertionPoint(self.module.body): + shape = (32, 4) + ty = ir.VectorType.get(shape, ir.BF16Type.get()) + func.FuncOp.from_py_func(ty, ty)(body) + + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + add.attributes["out_layouts"] = ir.ArrayAttr.get([splat_layout]) + mgpu.infer_layout(self.module) + + self.assertSequenceEqual( + optimization_barrier.attributes["in_layouts"], + [splat_layout, splat_layout], + ) + self.assertSequenceEqual( + optimization_barrier.attributes["out_layouts"], + [splat_layout, splat_layout], + ) + + def 
test_optimization_barrier_op_propagates_producer_layouts(self): + add = optimization_barrier = None + + def body(lhs, rhs): + nonlocal add, optimization_barrier + add = arith.AddFOp(lhs, rhs) + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([add]) + + with ir.InsertionPoint(self.module.body): + shape = (32, 4) + ty = ir.VectorType.get(shape, ir.BF16Type.get()) + func.FuncOp.from_py_func(ty, ty)(body) + + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + add.attributes["out_layouts"] = ir.ArrayAttr.get([splat_layout]) + mgpu.infer_layout(self.module) + + self.assertSequenceEqual( + optimization_barrier.attributes["in_layouts"], [splat_layout] + ) + self.assertSequenceEqual( + optimization_barrier.attributes["out_layouts"], [splat_layout] + ) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index e7bd7fad3798..9ffaff121849 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -489,19 +489,12 @@ def get_packed_shape(strides, shape): class WGMMALayoutTest(TestCase): - @parameterized.product(dtype=[jnp.float16, jnp.float32], - transposed_smem=[False, True]) - def test_store_untiled(self, dtype, transposed_smem): + @parameterized.product(dtype=[jnp.float16, jnp.float32]) + def test_store_untiled(self, dtype): def kernel(ctx, out, _): del ctx - if transposed_smem: - out = memref_transpose(out, (1, 0)) - iota_tensor(64, 64, dtype).store_untiled( - out, vector_store=not transposed_smem - ) + iota_tensor(64, 64, dtype).store_untiled(out, optimized=False) expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64) - if transposed_smem: - expected = expected.T iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() @@ -749,7 +742,7 @@ def kernel(ctx, lhs, rhs, out, scratch): acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) def quantize(x): # Quantize the input to avoid rounding when feeding the WGMMA @@ -821,7 +814,7 @@ def kernel(ctx, rhs, out, rhs_smem): acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) y_shape = (n, k) if rhs_transpose else (k, n) y = self.prng.uniform(-1, 1, y_shape).astype(dtype) @@ -881,7 +874,7 @@ def kernel(ctx, rhs, out, smem): acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) jax_dtype = jnp.float16 y_shape = (n, k) if rhs_transpose else (k, n) @@ -908,19 +901,82 @@ def setUp(self): if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities): self.skipTest("Only works on GPU with capability sm_100a or sm_101a") + @parameterized.parameters([jnp.float32, jnp.float16]) + def test_load_store_tmem(self, jax_dtype): + swizzle = 128 + in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + tiling = (8, swizzle_elems) + + def kernel(ctx, input, output, scratch): + smem, barrier, tmem = scratch + ctx.async_copy( + src_ref=input, + dst_ref=smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barrier, + ) + 
barrier.wait() + tmem[:] = fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT) + tcgen05.commit_tmem() + tmem[:].store_tiled(smem, swizzle) + mgpu.commit_shared() + ctx.async_copy( + src_ref=smem, dst_ref=output, swizzle=swizzle, gmem_transform=mgpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + + x = self.prng.uniform(-1, 1, (128, 128)).astype(jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype), + mgpu.TMABarrier(), + mgpu.TMEM(x.shape, jax_dtype), + ] + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x) + np.testing.assert_array_equal(x, y) + @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation + out_jax_dtype=(jnp.float16, jnp.float32,), m=(128,), # TODO(apaszke): 64, 192, 256 n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 - k_steps=(1, 2), swizzle=(32, 64, 128,), - rhs_transpose_tiles=(False, True), + ) + def test_mma_basic(self, *args, **kwargs): + self._basic_mma_test( + *args, + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. + lhs_transpose_tiles=False, + rhs_transpose_tiles=False, + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), + m=(128,), + n=(128, 512), + swizzle=(32, 64, 128,), lhs_transpose_tiles=(False, True), + rhs_transpose_tiles=(False, True), ) - def test_mma_basic( + def test_mma_transposed_tiles(self, *args, **kwargs): + if not kwargs["lhs_transpose_tiles"] and not kwargs["rhs_transpose_tiles"]: + self.skipTest("This is already tested in test_mma_basic") + self._basic_mma_test( + *args, + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. 
+ ) + + def _basic_mma_test( self, m, n, @@ -979,18 +1035,12 @@ def kernel(ctx, lhs, rhs, out, scratch): ) tcgen05.commit_arrive(barriers[2]) barriers[2].wait(for_tensor_core=True) - acc[:].store_untiled(out) - - in_finfo = jnp.finfo(in_jax_dtype) - exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant - def quantize(x): - # Quantize the input to avoid rounding when feeding the TensorCore - return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) + acc[:].store_untiled(out, optimized=False) x_shape = (k, m) if lhs_transpose else (m, k) - x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) if rhs_transpose_tiles: rhs_smem_shape = ( @@ -1015,14 +1065,15 @@ def quantize(x): )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) - atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 - np.testing.assert_allclose(z, ref, atol=atol) + atol = 2e-2 if out_jax_dtype == jnp.float16 else 2e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 1e-7 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), m=(256,), # TODO(apaszke): 64, 192, 256 n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 k_steps=(1, 2), @@ -1087,7 +1138,7 @@ def kernel(ctx, lhs, rhs, out, scratch): tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx) barriers[2].wait(for_tensor_core=True) m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) - acc[:].store_untiled(memref_slice(out, m_slice)) + acc[:].store_untiled(memref_slice(out, m_slice), optimized=False) in_finfo = jnp.finfo(in_jax_dtype) exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant @@ -1140,7 +1191,7 @@ def kernel(ctx, dst, scratch): final_arr = arr + mgpu.FragmentedArray.load_strided( tmp, is_signed=False ) - final_arr.store_untiled(memref_slice(dst, 0)) + final_arr.store_untiled(memref_slice(dst, 0), optimized=False) scf.yield_([]) with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block): barriers[0].wait() @@ -1151,7 +1202,7 @@ def kernel(ctx, dst, scratch): barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp. arr.store_untiled(tmp) barriers[1].arrive() # Signal that tmp is ready. 
- final_arr.store_untiled(memref_slice(dst, 1)) + final_arr.store_untiled(memref_slice(dst, 1), optimized=False) scf.yield_([]) out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32) y = mgpu.as_gpu_kernel( @@ -1579,7 +1630,7 @@ def run_kernel(shape): run_kernel([1] * 6) with self.assertRaisesRegex( - ValueError, "last dimension to be divisible by 16" + ValueError, "last dimension to be divisible by 128" ): run_kernel([23]) @@ -1612,7 +1663,7 @@ def kernel(ctx, dst, _): mlir_dtype = utils.dtype_to_ir_type(dtype) iota = iota_tensor(m, n, dtype) rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype) - op(iota, rhs).store_untiled(dst) + op(iota, rhs).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1658,7 +1709,7 @@ def test_division(self, op, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst) + op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1688,22 +1739,46 @@ def kernel(ctx, dst, _): rhs = 0 if rhs_is_literal else iota + 1 res = op(iota, rhs) assert not res.is_signed - res.astype(i8, is_signed=False).store_untiled(dst) + res.astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() iota = np.arange(m * n, dtype=dtype).reshape(m, n) - rhs = rhs = 0 if rhs_is_literal else iota + 1 + rhs = 0 if rhs_is_literal else iota + 1 np.testing.assert_array_equal(result, op(iota, rhs).astype(jnp.int8)) + def test_foreach_wgmma_row_array(self): + def kernel(ctx, out, smem): + del ctx, smem + x = iota_tensor(128, 128, jnp.float32) + row = x.reduce("add", 1) + # Test returning an array + row = row.foreach( + lambda x, _: arith.addf(x, c(1, row.mlir_dtype)), create_array=True + ) + # Test no array return + @row.foreach + def _(v, idx): + memref.store(v, out, idx) + + result = mgpu.as_gpu_kernel( + kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(shape=(128,), dtype=jnp.float32), + smem_scratch_shape=(), + )() + iota = np.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + np.testing.assert_array_equal(result, iota.sum(axis=1) + 1) + def test_foreach(self): dtype = jnp.int32 swizzle = 128 - tile = 64, swizzle // jnp.dtype(dtype).itemsize + tiling = (8, swizzle // jnp.dtype(dtype).itemsize) shape = 128, 192 - tiled_shape = mgpu.tile_shape(shape, tile) mlir_dtype = utils.dtype_to_ir_type(dtype) cst = 9999 def causal(val, idx): @@ -1711,12 +1786,16 @@ def causal(val, idx): mask = arith.cmpi(arith.CmpIPredicate.uge, row, col) return arith.select(mask, val, c(cst, mlir_dtype)) - tiling = mgpu.TileTransform(tile) def kernel(ctx, dst, smem): x = iota_tensor(shape[0], shape[1], dtype) - x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem) + x.foreach(causal, create_array=True, is_signed=False).store_tiled(smem, swizzle=128) mgpu.commit_shared() - ctx.async_copy(src_ref=smem, dst_ref=dst) + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + gmem_transform=mgpu.TileTransform(tiling), + swizzle=128, + ) ctx.await_async_copy(0) iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape) @@ -1726,7 +1805,7 @@ def kernel(ctx, dst, smem): (128, 1, 1), (), jax.ShapeDtypeStruct(shape=shape, 
dtype=dtype), - jax.ShapeDtypeStruct(shape=shape, dtype=dtype), + jax.ShapeDtypeStruct(shape=mgpu.tile_shape(shape, tiling), dtype=dtype), )() expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst np.testing.assert_array_equal(result, expected) @@ -1738,7 +1817,7 @@ def kernel(ctx, dst, smem): def test_bitwise(self, op, dtype, m=64, n=8): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota, iota + 1).store_untiled(dst) + op(iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1762,7 +1841,7 @@ def test_unary(self, ops, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota).store_untiled(dst) + op(iota).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1775,7 +1854,7 @@ def test_select(self, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.int32) - (iota < 16).select(iota * 2, iota * 3).store_untiled(dst) + (iota < 16).select(iota * 2, iota * 3).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32) result = mgpu.as_gpu_kernel( @@ -1798,7 +1877,7 @@ def test_math(self, ops, approx, m=64, n=32): op, np_op = ops def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - op(iota).store_untiled(dst) + op(iota).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1819,7 +1898,7 @@ def kernel(ctx, src, dst, scratch): src, is_signed=utils.is_signed(dtype) ) acc = src.reduce_sum(scratch).broadcast((m,)) - acc.store_untiled(dst) + acc.store_untiled(dst, optimized=False) in_shape = jax.ShapeDtypeStruct((m, n), dtype) out_shape = jax.ShapeDtypeStruct((m,), dtype) @@ -1847,7 +1926,7 @@ def kernel(ctx, dst, _): is_signed=utils.is_signed(dtype), ) acc = src.reduce_sum().broadcast((m,)) - acc.store_untiled(dst) + acc.store_untiled(dst, optimized=False) kernel_fn = mgpu.as_gpu_kernel( kernel, @@ -1867,7 +1946,7 @@ def kernel(ctx, dst, _): def test_reduce(self, op, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) + iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1888,7 +1967,7 @@ def kernel(ctx, dst, _): cte = c(1, iota.mlir_dtype) cte_arr = mgpu.FragmentedArray.splat(cte, ()) cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) - (iota + cte_arr).store_untiled(dst) + (iota + cte_arr).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1903,7 +1982,7 @@ def kernel(ctx, dst, _): t = mgpu.FragmentedArray.splat( v, (128,), mgpu.WGMMA_ROW_LAYOUT ) - t.broadcast_minor(32).store_untiled(dst) + t.broadcast_minor(32).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1922,7 +2001,7 @@ def kernel(ctx, src, dst, _): assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout) pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout) - (pi_arr == pi_arr).select(pi_splat, 
pi_arr_cube).store_untiled(dst) + (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) inp = jnp.ones_like(out_shape) * 3.14 @@ -1946,20 +2025,63 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(inp, result) - @parameterized.product(in_shape=((128,), (64,))) - def test_wgmma_row_load_store_with_layout(self, in_shape): + @parameterized.product( + in_shape=((1024,), (256,), (128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128) + ) + def test_wgmma_row_load_store_with_layout(self, in_shape, dtype, swizzle): + def kernel(ctx, gmem_input, gmem_output, smem): + smem_input, smem_output = smem + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, layout=mgpu.WGMMA_ROW_LAYOUT, swizzle=swizzle + ) + t.store_untiled(smem_output) + copy(smem_output, gmem_output) + + inp = out = self.prng.uniform(-1, 1, in_shape).astype(dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], + )(inp) + np.testing.assert_array_equal(inp, result) + + @parameterized.product( + in_shape=((128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128), + ) + def test_wgmma_col_load_store_with_layout(self, in_shape, dtype, swizzle): def kernel(ctx, *args): gmem_input, gmem_output, (smem_input, smem_output) = args - copy(gmem_input, smem_input) - t = mgpu.FragmentedArray.load_wgmma_row(smem_input) + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, swizzle=swizzle, layout=mgpu.WGMMA_COL_LAYOUT + ) t.store_untiled(smem_output) copy(smem_output, gmem_output) - inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) + inp = out = self.prng.uniform(-1, 1, in_shape).astype(dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], )(inp) - np.testing.assert_array_equal(inp, result) + np.testing.assert_array_equal(result, inp) + + @parameterized.parameters((128, 128), (128, 64), (64, 128)) + def test_broadcast_major(self, m, n): + def kernel(ctx, gmem_input, gmem_output, _): + t = mgpu.FragmentedArray.load_untiled( + gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False + ) + t.broadcast_major(m).store_untiled(gmem_output, optimized=False) + + inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp + )(inp) + out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) + np.testing.assert_array_equal(result, out_ref) def test_warp_tree_reduce(self): def kernel(ctx, out, *_): @@ -1988,7 +2110,7 @@ def kernel(ctx, inp, out, smem): del ctx, smem arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length] - arr.astype(mlir_dtype_to).store_untiled(out) + arr.astype(mlir_dtype_to).store_untiled(out, optimized=False) x = jnp.arange(-128, 128, dtype=jax_dtype_from) x = jnp.tile(x, reg_length // 2) @@ -2064,7 +2186,7 @@ def test_convert_bool_to_u8(self): def kernel(ctx, dst, _): i8 = ir.IntegerType.get_signless(8) iota = iota_tensor(m, n, jnp.uint8) - (iota > 10).astype(i8, is_signed=False).store_untiled(dst) + (iota > 10).astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( @@ -2192,7 +2314,7 @@ def kernel(ctx, 
dst, _): ) self.assertEqual(tiled.shape, shape) self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype) - tiled.store_untiled(dst) + tiled.store_untiled(dst, optimized=False) ty = jax.ShapeDtypeStruct(shape, dtype) f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ()) expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape) @@ -2430,7 +2552,7 @@ def set_in_transforms( in_transforms = [] smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable - for _, result_transforms in jax.util.safe_zip(smem_refs, transforms): + for _, result_transforms in jax._src.util.safe_zip(smem_refs, transforms): in_transforms.append( ir.ArrayAttr.get([t.attr() for t in result_transforms]) ) @@ -2476,7 +2598,7 @@ def add(ctx, a, b, result, smem): in_shape=(jax_shape, jax_shape), out_shape=jax_shape, smem_scratch_shape=[], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, shape).astype(dtype) @@ -2621,7 +2743,7 @@ def add( jax_shape_sliced, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) @@ -2720,7 +2842,7 @@ def add( spec, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, spec.shape).astype(dtype) @@ -2868,7 +2990,7 @@ def matmul( result_jax_shape, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) prng_key = jax.random.key(1234) diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py index b7cd146dfdb6..983efebc4f86 100644 --- a/tests/mosaic/gpu_transform_inference_test.py +++ b/tests/mosaic/gpu_transform_inference_test.py @@ -24,7 +24,9 @@ from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import fragmented_array as fa @@ -418,6 +420,105 @@ def body(offset): with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): mgpu.infer_transforms(self.module) + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_op_propagates_undisturbed_tile_and_swizzle_transforms( + self, annotate_input + ): + subview_op = user_op = None + shape = (2, 64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + out_ref_ty = ir.MemRefType.get(shape[2:], elt_ty, memory_space=smem) + + def body(in_ref): + nonlocal subview_op, user_op + subview_op = memref.SubViewOp( + out_ref_ty, + in_ref, + [], + [], + [], + static_offsets=[1, 0, 0], + static_sizes=[1, 64, 64], + static_strides=[1, 1, 1], + ) + user_op = builtin.UnrealizedConversionCastOp( + [out_ref_ty], [subview_op.result] + ) + + with ir.InsertionPoint(self.module.body): + f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if 
annotate_input: + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + mgpu.infer_transforms(self.module) + + self.assertSequenceEqual( + inference_utils.in_transforms(subview_op), [transforms] + ) + self.assertSequenceEqual( + inference_utils.out_transforms(subview_op), [transforms] + ) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_op_raises_on_disturbed_transforms( + self, annotate_input + ): + subview_op = user_op = None + shape = (2, 64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + out_ref_ty = ir.MemRefType.get((2, 64, 32), elt_ty, memory_space=smem) + + def body(in_ref): + nonlocal subview_op, user_op + subview_op = memref.SubViewOp( + out_ref_ty, + in_ref, + [], + [], + [], + static_offsets = [1, 0, 0], + static_sizes = [2, 64, 32], + static_strides = [1, 1, 1] + ) + user_op = builtin.UnrealizedConversionCastOp( + [out_ref_ty], [subview_op.result] + ) + + with ir.InsertionPoint(self.module.body): + f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + with self.assertRaises(NotImplementedError): + mgpu.infer_transforms(self.module) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index d598d7d0c0ec..9634718d2d44 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -15,12 +15,15 @@ """Test different parameterizations of a matmul.""" import os -import unittest from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp + +import hypothesis as hp +import hypothesis.strategies as hps + try: # We only import this to see if Mosaic is available.
import jax.experimental.mosaic.gpu # noqa: F401 @@ -28,11 +31,6 @@ matmul = None else: from jax.experimental.mosaic.gpu.examples import matmul -try: - import hypothesis as hp - import hypothesis.strategies as hps -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("these tests require hypothesis") config.parse_flags_with_absl() @@ -48,6 +46,7 @@ def wrapper(self, seed): @jtu.with_config(jax_traceback_filtering="off") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class MatmulTestCase(jtu.JaxTestCase): def setUp(self): diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index e962653ed32d..5b7669f3db2d 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -116,6 +116,18 @@ def f(y_mut, z): check_dtypes=False) self.assertAllClose(w, 10, check_dtypes=False) + @parameterized.parameters([True, False]) + def test_len_mutable_array(self, jit): + x_mut = core.mutable_array(jnp.zeros(3)) + + def f(): + return jnp.int32(len(x_mut)) + + if jit: + f = jax.jit(f) + + self.assertEqual(f(), 3) + @parameterized.parameters([True, False]) def test_internal_mutarray_basic(self, jit): def f(): @@ -227,6 +239,26 @@ def f(x_ref): x_ref = core.mutable_array(x) y = f(x_ref) + def test_vmap_basic(self): + @jax.vmap + def f(x): + x_ref = core.mutable_array(x) + x_ref[...] = x_ref[...] * x_ref[...] + return x_ref[...] + xs = jnp.arange(4.) + ys = f(xs) + self.assertAllClose(ys, xs ** 2, check_dtypes=False) + + def test_implicit_bitcast_regression(self): + # https://github.com/jax-ml/jax/issues/27683 + v = core.mutable_array(jnp.array([0, 0, 0])) + with self.assertRaises(ValueError): + v[...] += 1.0 + + def test_implicit_cast_in_swap(self): + v = core.mutable_array(jnp.array(0, dtype='bfloat16')) + v[...] 
+= 1.0 # don't crash + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): diff --git a/tests/nn_test.py b/tests/nn_test.py index ed016ec349ef..385b216aeb57 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -31,7 +31,6 @@ from jax._src.cudnn.scaled_matmul_stablehlo import ( quantize, shape_normalization, - BlockScaleConfig, ) from jax.test_util import check_grads from jax import nn @@ -110,17 +109,7 @@ def create_mxfp8_configs_if_available(): if _dtypes.float8_e8m0fnu is None: raise unittest.SkipTest("float8_e8m0fnu is not available.") - def _create_mxfp8_config(): - return BlockScaleConfig( - mode='mxfp8', - block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, - global_scale=None, - infer_only=False - ) - - return [_create_mxfp8_config() for _ in range(3)] + return [nn.get_scaled_dot_general_config("mxfp8") for _ in range(3)] @jtu.with_config(jax_legacy_prng_key="allow", @@ -130,10 +119,9 @@ class NNFunctionsTest(jtu.JaxTestCase): contract=[160, 96], lhs_non_contract=[240, 100], dtype=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) - def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + def testScaledMatmul(self, contract, lhs_non_contract, dtype): + if not _is_required_cudnn_version_satisfied("10.0", 90700): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") # Check if float8_e8m0fnu is available configs = create_mxfp8_configs_if_available() @@ -153,11 +141,10 @@ def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): @parameterized.product( is_training=[True, False], output_type=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) def testScaledDotGeneral( - self, is_training, output_type, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + self, is_training, output_type): + if not _is_required_cudnn_version_satisfied("10.0", 90700): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") configs = create_mxfp8_configs_if_available() @@ -422,6 +409,7 @@ def testSparseplusAndSparseSigmoid(self): jax.grad(nn.sparse_plus)(-2.), nn.sparse_sigmoid(-2.), check_dtypes=False) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def testSquareplusGrad(self): check_grads(nn.squareplus, (1e-8,), order=4, rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) @@ -442,6 +430,7 @@ def testSquareplusGradNan(self): def testSquareplusZero(self, dtype): self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4))) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def testMishGrad(self): check_grads(nn.mish, (1e-8,), order=4, rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) @@ -541,7 +530,7 @@ def gelu_reference(x): (jnp.float32, jnp.bfloat16, jnp.float16), (partial(nn.gelu, approximate=False), partial(nn.gelu, approximate=True), - nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) + nn.relu, nn.identity, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) def testDtypeMatchesInput(self, dtype, fn): x = jnp.zeros((), dtype=dtype) out = fn(x) @@ -829,6 +818,12 @@ def testVarianceScalingError(self): ): initializer(rng, shape) + def testIdentity(self): + x = jnp.array([1., 2., 3.]) + self.assertAllClose(nn.identity(x), x, check_dtypes=False) + grad = jax.grad(nn.identity)(6.0) + self.assertEqual(grad, 1.) 
+ def testAccidentalUpcasting(self): rng = random.PRNGKey(0) shape = (4, 4) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 987a3aa9d50a..fa98c0af4be8 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -107,7 +107,7 @@ jax_multiplatform_test( "gpu_a100_x32", "gpu_h100", "gpu_h100_x32", - "tpu_v6e_1x1", + "tpu_v6e", ], shard_count = { "cpu": 16, @@ -215,7 +215,6 @@ jax_multiplatform_test( "gpu_h100", ], env = { - "JAX_PALLAS_USE_MOSAIC_GPU": "1", "JAX_PALLAS_VERBOSE_ERRORS": "0", }, deps = [ @@ -316,7 +315,7 @@ jax_multiplatform_test( ], enable_backends = [], enable_configs = [ - "tpu_v5e_4x2", + "tpu_v5e_x8", ], deps = [ "//jax:pallas_tpu_ops", @@ -329,7 +328,7 @@ jax_multiplatform_test( "tpu_gmm_test.py", ], enable_backends = ["tpu"], - shard_count = 50, + shard_count = 5, tags = [ "noasan", # Times out. "nomsan", # Times out. @@ -353,7 +352,7 @@ jax_multiplatform_test( enable_backends = ["tpu"], enable_configs = [ "tpu_v5e", - "tpu_v5p_1x1", + "tpu_v5p", ], deps = [ "//jax:extend", @@ -384,10 +383,10 @@ jax_multiplatform_test( srcs = ["tpu_pallas_distributed_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_2x2", - "tpu_v4_2x2", - "tpu_v3_2x2", + "tpu_v5e_x8", + "tpu_v5p_x4", + "tpu_v4_x4", + "tpu_v3_x4", ], deps = [ "//jax:extend", @@ -401,8 +400,8 @@ jax_multiplatform_test( srcs = ["tpu_pallas_pipeline_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_1x1", + "tpu_v5e_x8", + "tpu_v5p", ], shard_count = 5, tags = [ @@ -422,8 +421,8 @@ jax_multiplatform_test( srcs = ["tpu_pallas_async_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_1x1", + "tpu_v5e_x8", + "tpu_v5p", ], deps = [ "//jax:pallas_tpu", @@ -452,7 +451,7 @@ jax_multiplatform_test( ], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5p_2x2", + "tpu_v5p_x4", ], deps = [ "//jax:pallas", @@ -492,7 +491,7 @@ jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], disable_configs = [ - "tpu_v5p_1x1", + "tpu_v5p", ], enable_backends = ["tpu"], shard_count = 5, @@ -510,7 +509,7 @@ jax_multiplatform_test( name = "tpu_ragged_paged_attention_test", srcs = ["tpu_ragged_paged_attention_test.py"], disable_configs = [ - "tpu_v5p_1x1", + "tpu_v5p", ], enable_backends = ["tpu"], shard_count = 24, @@ -541,6 +540,21 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) +jax_multiplatform_test( + name = "tpu_splash_attention_kernel_sharded_test", + srcs = ["tpu_splash_attention_kernel_sharded_test.py"], + enable_configs = [ + "tpu_v5e_x8", + "tpu_v5p_x4", + ], + shard_count = 5, + deps = [ + "//jax:extend", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ], +) + # This test doesn't need a TPU; it only tests numpy-using helpers. 
jax_py_test( name = "tpu_splash_attention_mask_test", @@ -667,10 +681,31 @@ jax_multiplatform_test( ) jax_multiplatform_test( - name = "tpu_fusable_matmul_test", - srcs = ["tpu_fusable_matmul_test.py"], + name = "fusion_test", + srcs = [ + "fusion_test.py", + ], + disable_configs = [ + "cpu", + "cpu_shardy", + ], + enable_backends = ["cpu"], + tags = [ + "noasan", + "nomsan", + "notsan", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_fuser", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_multiplatform_test( + name = "tpu_fusible_matmul_test", + srcs = ["tpu_fusible_matmul_test.py"], disable_configs = [ - "tpu_v3_1x1", + "tpu_v3", "tpu_pjrt_c_api", "gpu_v100", "gpu_v100_x32", @@ -684,10 +719,10 @@ jax_multiplatform_test( ], enable_backends = ["tpu"], enable_configs = [ - "tpu_v4_1x1", + "tpu_v4", "tpu_v5e", - "tpu_v5p_1x1", - "tpu_v6e_1x1", + "tpu_v5p", + "tpu_v6e", ], shard_count = 4, tags = [ diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index 1b3a215876ec..377901933b4e 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -769,6 +769,41 @@ def f(x): kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), relu_x ) + def test_pull_block_spec_handles_closed_over_constants(self): + x = jnp.ones((2, 512, 512)) + i = jnp.array(1) + + def f(): + return x[i] + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertLen(scalar_prefetch_values, 1) + + block_spec = pl.BlockSpec( + (None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l) + ) + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 2, 4, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + scalar_prefetch_values = jax.tree.map( + lambda x: x[None], scalar_prefetch_values + ) + fn = lambda x: kernel_fn((0, 0, 0, 0), scalar_prefetch_values, x) + new_values_type = (jax.ShapeDtypeStruct((1, 128, 128), jnp.float32),) + # Try pulling again + # This should not raise an error. + _ = block_spec_lib.pull_block_spec( + fn, + block_spec, + grid=(1,), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values_type) + class PushBlockSpecTest(parameterized.TestCase): diff --git a/tests/pallas/fusion_test.py b/tests/pallas/fusion_test.py new file mode 100644 index 000000000000..4bd02345ca62 --- /dev/null +++ b/tests/pallas/fusion_test.py @@ -0,0 +1,234 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas import fuser +import jax.numpy as jnp +import numpy as np + +jax.config.parse_flags_with_absl() + + +class FusionTest(jtu.JaxTestCase): + + def test_basic_fusion(self): + + @jax.jit + @fuser.fuse + @fuser.fusible + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + np.testing.assert_array_equal(f(x), x) + + def test_separate_output_fusions_trivial(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x, y = f(x, y) + return x, y * 2 + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + x_out, y_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_separate_output_fusions_should_error_if_not_disjoint(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return x_res + y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + + with self.assertRaisesRegex( + ValueError, + "Outputs must be disjoint in order to use separate output fusions", + ): + g(x, y) + + def test_separate_output_fusions_allows_permute(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res * 2, x_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, x_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_separate_output_fusions_with_nesting(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return (x_res * 2, x_res + x_res), y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + (x1_out, x2_out), y_out = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_nesting_and_permutation(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res, (x_res * 2, x_res + x_res) + + x = 
jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_deep_output_mask(self): + + @fuser.fusible(output_fusion_prefix=(True, (True, True))) + def f(x_fn, y_fn, z_fn, o_fns): + x = x_fn() + y = y_fn() + z = z_fn() + if o_fns is None: + o_fns = lambda x: x, (lambda x: x, lambda x: x) + o_fn1, (o_fn2, o_fn3) = o_fns + return o_fn1(x), (o_fn2(y), o_fn3(z)) + + @jax.jit + @fuser.fuse + def g(x, y, z): + x_res, (y_res, z_res) = f(x, y, z) + return (x_res * 2, (y_res, z_res + z_res)) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + z = jax.random.normal(jax.random.key(1), (128, 1), dtype=jnp.float32) + x_out, (y_out, z_out) = g(x, y, z) + np.testing.assert_array_equal(x_out, x * 2) + np.testing.assert_array_equal(y_out, y) + np.testing.assert_array_equal(z_out, z + z) + + def test_separate_output_fusions_with_reused_value(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y, a): + x_res, y_res = f(x, y) + return y_res + a, (x_res * 2, x_res + x_res + a) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y, a) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x + a) + np.testing.assert_array_equal(y_out, y + a) + + def test_empty_fusion(self): + + @fuser.fusible + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + @jax.jit + @fuser.fuse + def g(x, a): + _ = f(x) + return a + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + y_out = g(x, a) + np.testing.assert_array_equal(y_out, a) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index c3f3fa6e80a8..cb862b406603 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -14,7 +14,6 @@ from __future__ import annotations import sys -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -32,11 +31,7 @@ else: pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.extra.numpy as hnp import hypothesis.strategies as hps @@ -95,7 +90,7 @@ def array_indexer_strategy(draw, shape) -> jax.Array: @hps.composite def indexer_strategy(draw, dim, int_indexer_shape - ) -> int | Slice | jax.Array: + ) -> int | Slice | jax.Array: return draw(hps.one_of( int_indexer_strategy(dim), slice_indexer_strategy(dim), @@ -127,6 +122,7 @@ def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread 
safe class IndexerTest(jtu.JaxTestCase): """These are unit tests for the indexer logic, not using pallas_call.""" @@ -217,12 +213,13 @@ def test_indexer_with_all_types(self): indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None]) indexer = NDIndexer.from_indices_shape(indices, shape) - self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2)) + self.assertTupleEqual(indexer.get_indexer_shape(), (2, 5, 4)) @hp.given(hps.data()) def test_ndindexer(self, data): shape = data.draw(hnp.array_shapes()) indexer = data.draw(nd_indexer_strategy(shape)) + is_int_indexer = [not isinstance(idx, Slice) for idx in indexer.indices] rest_indexers, int_indexers = util.partition_list( is_int_indexer, indexer.indices @@ -234,18 +231,15 @@ def test_ndindexer(self, data): self.assertTupleEqual( indexer.int_indexer_shape, expected_int_indexer_shape ) + for idx in rest_indexers: self.assertIsInstance(idx, (np.ndarray, Slice)) if isinstance(idx, np.ndarray): self.assertTupleEqual(idx.shape, ()) self.assertEqual(idx.dtype, np.dtype("int32")) - rest_shape = tuple( - r.size for r in rest_indexers if not isinstance(r, np.ndarray) - ) - self.assertTupleEqual((*indexer.int_indexer_shape, *rest_shape), - indexer.get_indexer_shape()) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class IndexerOpsTest(PallasBaseTest): def test_multi_indexing_interpreter_only(self): @@ -373,7 +367,7 @@ def permute_columns_in_row_kernel(left, right, new_left, new_right): def test_vmap_nd_indexing(self, data): self.skipTest("TODO(necula): enable this test; was in jax_triton.") vmap_shape = data.draw(hnp.array_shapes(min_dims=1, max_dims=3, min_side=2), - label="vmap_shape") + label="vmap_shape") el_shape = data.draw(hnp.array_shapes(min_dims=2), label="el_shape") # TODO(sharadmv,apaszke): enable rank 0 and rank 1 Refs # hp.assume(len(el_shape) >= 2) @@ -390,7 +384,7 @@ def kernel(x_ref, y_ref): shape = el_shape for vmap_dim in vmap_shape[::-1]: index = data.draw(hps.integers(min_value=0, - max_value=max(0, len(shape) - 2)), + max_value=max(0, len(shape) - 2)), label="index") # hp.assume(index <= max(0, len(shape) - 2)) # TODO(sharadmv,apaszke): enable vmapping over batch axes in 2 minormost @@ -641,6 +635,34 @@ def kernel(x_ref, indices, y_ref): )(x, indices) self.assertAllClose(res[:, start : start + 1, :], x, atol=0., rtol=0.) + def test_scalar_load_from_vmem(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Requires TPU v4 or later") + def kernel(x_ref, o_ref, sem_ref): + o_ref[...] = jnp.zeros_like(o_ref) + scalar_val = x_ref[1, 2] + # Use scalar_val in both async_copy and store. 
+ o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val + desc = pltpu.make_async_copy( + o_ref.at[scalar_val], + o_ref.at[scalar_val + 1], + sem_ref, + ) + desc.start() + desc.wait() + + x = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.int32) + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 8, 128), jnp.int32), + grid=(1,), + scratch_shapes=[pltpu.SemaphoreType.DMA] + )(x) + expected = jnp.zeros_like(res) + expected = expected.at[6].set(jnp.ones((8, 128), jnp.int32) * 6) + expected = expected.at[7].set(jnp.ones((8, 128), jnp.int32) * 6) + self.assertArraysEqual(res, expected) + class IndexerOpsInterpretTest(IndexerOpsTest): INTERPRET = True @@ -662,18 +684,18 @@ class IndexerOpsInterpretTest(IndexerOpsTest): ((4, 3), lambda arr, a, b, c, d: arr[a, 2]), # slice + 1-D array ((4, 3), lambda arr, a, b, c, d: arr[a, :]), - # ((4, 3), lambda arr, a, b, c, d: arr[:, a]), + ((4, 3), lambda arr, a, b, c, d: arr[:, a]), ((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]), - # ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), - # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), - # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), + ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), + ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), + ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]), ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]), # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]), ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]), ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]), # slice + array w/ broadcasting - ((8, 8, 3, 6), lambda arr, a, b, c, d: \ + ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b[:, None], ::4, a[None], a[:, None]]), # integer + slice + 1-D array ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]), diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index cf8ed30925bf..27588683d0e9 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -62,6 +62,7 @@ def setUp(self): attention_mgpu.attention, attention_mgpu.attention_with_pipeline_emitter, ), + save_residuals=(True,), ) def test_flash_attention( self, @@ -71,22 +72,28 @@ def test_flash_attention( num_q_and_kv_heads, head_dim, attention_impl, + save_residuals, ): num_q_heads, num_kv_heads = num_q_and_kv_heads k1, k2, k3 = jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) - out = attention_impl( + out, *res = attention_impl( q, k, v, attention_mgpu.TuningConfig( block_q=64, block_kv=64, max_concurrent_steps=2 ), + save_residuals=save_residuals, ) - out_ref = attention_mgpu.attention_reference(q, k, v) + out_ref, *res_ref = attention_mgpu.attention_reference(q, k, v, save_residuals=save_residuals) np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + if save_residuals: + (lse,) = res[0] + (lse_ref,) = res_ref[0] + np.testing.assert_allclose(lse, lse_ref, atol=2e-3, rtol=1e-3) if __name__ == "__main__": diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b3c3ddb84e09..9ad9038dfc49 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -13,23 +13,33 @@ # limitations under the 
License. import contextlib +import dataclasses import functools import math import operator import os import re import tempfile +from typing import ClassVar from absl.testing import absltest from absl.testing import parameterized import jax from jax import lax +from jax._src import lib as jaxlib from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline +from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives from jax.experimental import pallas as pl +import jax.experimental.mosaic.gpu as mgpu from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np + try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib except ImportError: @@ -54,14 +64,44 @@ def _sum_same_dtype(x): return jnp.sum(x, dtype=x.dtype) -class PallasTest(jtu.JaxTestCase): +class PallasTestMetaclass(parameterized.TestGeneratorMetaclass): + + def __new__(mcs, *args, lowering_semantics=plgpu.LoweringSemantics.Lane): + cls = super().__new__(mcs, *args) + cls.LOWERING_SEMANTICS = lowering_semantics + return cls + + +class PallasTest(jtu.JaxTestCase, metaclass=PallasTestMetaclass): + LOWERING_SEMANTICS: ClassVar[plgpu.LoweringSemantics] def setUp(self): if not jtu.is_cuda_compute_capability_at_least("9.0"): self.skipTest("Only works on a GPU with capability >= sm90") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) super().setUp() + def skip_if_wg_semantics(self): + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + self.skipTest("Not supported under WG semantics") + + def kernel(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), + lowering_semantics=self.LOWERING_SEMANTICS, + ) + return plgpu.kernel(*args, compiler_params=compiler_params, **kwargs) + + def pallas_call(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), + lowering_semantics=self.LOWERING_SEMANTICS, + ) + return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs) + @contextlib.contextmanager def capture_stdout(self): if mosaic_gpu_lib is None: @@ -79,6 +119,13 @@ def setUp(self): super().setUp() +class PallasSm100ATest(PallasTest, jtu.CudaArchSpecificTest): + + def setUp(self): + self.skip_unless_sm100a() + super().setUp() + + class PallasCallTest(PallasTest): @parameterized.product( @@ -93,17 +140,14 @@ class PallasCallTest(PallasTest): lax.log, ], approx_math=[True, False], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_unary_op(self, op, approx_math, thread_semantics): + def test_unary_op(self, op, approx_math): dtype = jnp.int32 if op is lax.bitwise_not else jnp.float32 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - approx_math=approx_math, thread_semantics=thread_semantics - ), + compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math), ) def kernel(x_ref, o_ref): o_ref[...] 
= op(x_ref[...]) @@ -124,16 +168,10 @@ def kernel(x_ref, o_ref): jnp.maximum, ], dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_binary_op(self, op, dtype, thread_semantics): - + def test_binary_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(x_ref, y_ref, o_ref): o_ref[...] = op(x_ref[...], y_ref[...]) @@ -154,16 +192,10 @@ def kernel(x_ref, y_ref, o_ref): ], # TODO(slebedev): Support integral types. dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_comparison_op(self, op, dtype, thread_semantics): - + def test_comparison_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(o_ref): o_ref[...] = jnp.broadcast_to( @@ -173,8 +205,9 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], op(42, 24), dtype)) def test_add_first(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, y_ref, o_ref): @@ -184,16 +217,10 @@ def kernel(x_ref, y_ref, o_ref): y = jnp.flip(x).reshape(1, 256) np.testing.assert_array_equal(kernel(x, y), x + y[0]) - @parameterized.product( - shape=[(128,), (128, 128)], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_reduce_sum(self, shape, thread_semantics): + @parameterized.product(shape=[(128,), (128, 128)]) + def test_reduce_sum(self, shape): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32) ) def kernel(x_ref, o_ref): o_ref[...] 
= jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape) @@ -202,11 +229,12 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), jnp.sum(x)) def test_reshape(self): + self.skip_if_wg_semantics() + shape1, shape2 = (128,), (2, 16, 4) @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32) ) def kernel(x_ref, out_ref): x_ref_reshaped = x_ref.reshape(shape2) @@ -217,14 +245,9 @@ def kernel(x_ref, out_ref): x = jnp.arange(math.prod(shape1)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_add_xy_indexed(self, thread_semantics): + def test_add_xy_indexed(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32) ) def kernel(x_ref, y_ref, o_ref): idx = _sum_same_dtype(y_ref[...]) @@ -235,8 +258,9 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)]) def test_add_one_grid(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), @@ -249,9 +273,8 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_with_scratch(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), @@ -267,9 +290,8 @@ def kernel(x_ref, o_ref, scratch_ref): @parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16]) def test_add_one_grid_pipelined(self, max_concurrent_steps): - @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), @@ -286,9 +308,8 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_pipelined_program_id(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32), compiler_params=plgpu.GPUCompilerParams( @@ -306,8 +327,9 @@ def kernel(o_ref): ) def test_add_one_grid_pipelined_sequential_invariant_output(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)), out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32), @@ -334,30 +356,71 @@ def kernel(x_ref, o_ref): @parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32) def test_iota(self, dtype): + self.skip_if_wg_semantics() + dimension = 1 + @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128, 128), dtype) ) def kernel(o_ref): - o_ref[...] = plgpu.broadcasted_iota(dtype, (128, 128), dimension, layout=plgpu.Layout.WGMMA) + o_ref[...] 
= plgpu.broadcasted_iota( + dtype, o_ref.shape, dimension, layout=plgpu.Layout.WGMMA + ) - np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension)) + np.testing.assert_array_equal( + kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension) + ) - @parameterized.product( - indexer=[..., slice(128), slice(None, 128)], - thread_semantics=[*plgpu.ThreadSemantics], - ) - def test_copy_smem_to_gmem(self, indexer, thread_semantics): + def test_inline_mgpu(self): + dtype = jnp.bfloat16 + self.skip_if_wg_semantics() + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((128, 128), dtype), + plgpu.Barrier(num_arrivals=1), + ], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_ref, o_ref, smem_ref, barrier): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier) + plgpu.barrier_wait(barrier) + layout = plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4) + @plgpu.inline_mgpu( + arg_types=(plgpu.RefType(),), + return_type=plgpu.GPUShapeDtypeStruct( + (128, 128), dtype, layout=layout + ), + ) + def foo(ctx, smem_ref): + del ctx + x = mgpu.FragmentedArray.load_strided(smem_ref) + y = mgpu.FragmentedArray.splat( + mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout + ) + return (x + y) + + arr = foo(smem_ref) + @plgpu.inline_mgpu(arg_types=(layout, plgpu.RefType())) + def store(ctx, arr, o_ref): + del ctx + arr.store_untiled(o_ref) + store(arr, o_ref) + key = jax.random.key(0) + x = (jax.random.uniform(key, (128, 128)) * 42).astype(dtype) + np.testing.assert_array_equal(kernel(x), x + 1) + + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) + def test_copy_smem_to_gmem(self, indexer): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM((256,), jnp.float32)], - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] + 1 @@ -368,6 +431,29 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) + @parameterized.parameters(jnp.bfloat16, jnp.float16, jnp.float32) + def test_copy_smem_to_gmem_reduction(self, dtype): + @functools.partial( + pl.pallas_call, + grid=(200,), + in_specs=[pl.BlockSpec((128,), lambda *i: i), pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct([128], dtype), + scratch_shapes=[plgpu.SMEM((128,), dtype)], + input_output_aliases={1:0} + ) + def kernel(x_ref, o_ref_gmem, o_ref_gmem_alias, scratch_ref): + del o_ref_gmem_alias + scratch_ref[...] = x_ref[...] 
+ plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_ref.at[...], o_ref_gmem.at[...], reduction_op="add") + plgpu.wait_smem_to_gmem(0) + x = jnp.ones(200 * 128).astype(dtype) # 200 blocks + output = jnp.zeros(128).astype(dtype) + output = kernel(x, output) + output_val = x.reshape(-1, 128).sum(axis=0) + np.testing.assert_array_equal(output, output_val) + @parameterized.named_parameters( {"testcase_name": "1d_none", "shape": (256,), "indexers": (slice(0, 128), slice(None, 32))}, @@ -377,8 +463,9 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): "shape": (64, 64), "indexers": (4, slice(0, 64))}, ) def test_copy_smem_to_gmem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM(shape, jnp.float32)], @@ -402,8 +489,9 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_gmem_to_smem(self, indexer): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ @@ -447,13 +535,15 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): }, ) def test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), - scratch_shapes=[plgpu.SMEM(shape, jnp.float32), - plgpu.Barrier(num_arrivals=1), - ], + scratch_shapes=[ + plgpu.SMEM(shape, jnp.float32), + plgpu.Barrier(num_arrivals=1), + ], grid=(1,), ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -478,7 +568,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): def test_gmem_to_smem_with_multiple_smem_indexers(self): x = jax.random.uniform(jax.random.key(0), (2, 64, 64), dtype=jnp.float32) @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ @@ -495,21 +585,31 @@ def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref): np.testing.assert_array_equal(extract_x0(x), x[0]) def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): + self.skip_if_wg_semantics() + x = jnp.arange(512 * 512, dtype=jnp.int32).reshape(512, 512) @functools.partial( - pl.pallas_call, + self.pallas_call, grid=(4, 4), out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32), - in_specs=(plgpu.GPUBlockSpec( - block_shape=(128, 128), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM, - transforms=(plgpu.TilingTransform((64, 32)), - plgpu.SwizzleTransform(128))),), - out_specs=(plgpu.GPUBlockSpec( - block_shape=(64, 32), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM,)), + in_specs=( + plgpu.GPUBlockSpec( + block_shape=(128, 128), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + transforms=( + plgpu.TilingTransform((8, 32)), + plgpu.SwizzleTransform(128), + ), + ), + ), + out_specs=( + plgpu.GPUBlockSpec( + block_shape=(64, 32), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + ) + ), ) def kernel(x_ref, o_ref): x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32] # get x_ref[0:64, 32:64] @@ -521,8 +621,9 @@ def kernel(x_ref, o_ref): @parameterized.product(indexer=[0, 1, 2, 3]) def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): 
+ @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ @@ -542,6 +643,8 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @parameterized.named_parameters(("_g2s", False), ("_s2g", True)) def test_copy_with_transforms(self, to_smem): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): if to_smem: plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier_ref) @@ -553,17 +656,15 @@ def kernel(x_ref, o_ref, barrier_ref): in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) out_spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), ), memory_space=plgpu.SMEM, ) if not to_smem: in_spec, out_spec = out_spec, in_spec - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), @@ -574,7 +675,9 @@ def kernel(x_ref, o_ref, barrier_ref): np.testing.assert_array_equal(f(x), x) def test_scoped_copy_with_transforms(self): - ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) + self.skip_if_wg_semantics() + + ts = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) @@ -583,20 +686,40 @@ def body(tmp_ref): pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM, + out_spec = plgpu.GPUBlockSpec(transforms=ts, memory_space=plgpu.SMEM) + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), + in_specs=(in_spec,), + out_specs=out_spec, + scratch_shapes=[plgpu.Barrier(num_arrivals=1)], ) + x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + np.testing.assert_array_equal(f(x), x * 2) + + def test_scoped_copy_with_user_transforms(self): + def kernel(x_ref, o_ref, barrier_ref): + def body(tmp_ref): + tmp_ref = plgpu.unswizzle_ref(tmp_ref, 128) + tmp_ref = plgpu.untile_ref(tmp_ref, (8, 32)) + plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + o_ref[...] = tmp_ref[...] 
* 2 + pl.run_scoped(body, plgpu.SMEM((16, 4, 8, 32), jnp.float32)) + + in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) f = pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), - out_specs=out_spec, scratch_shapes=[plgpu.Barrier(num_arrivals=1)], ) x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), x * 2) def test_copy_with_transforms_and_indexing(self): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref) @@ -604,16 +727,14 @@ def kernel(x_ref, o_ref, barrier_ref): in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) out_spec = plgpu.GPUBlockSpec( - (2, 128, 128), - lambda: (0, 0, 0), transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.TransposeTransform((0, 2, 1, 3, 4)), plgpu.SwizzleTransform(128), ), memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32), in_specs=(in_spec,), @@ -623,7 +744,104 @@ def kernel(x_ref, o_ref, barrier_ref): x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0)) + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[plgpu.Layout.WG_STRIDED((128,), vec_size=1), None, + ] + ) + def test_load_to_strided_layout_with_indexing(self, src_memory_space, layout): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,), layout=layout) + o_ref[i, ...] = x + + x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) + np.testing.assert_array_equal(kernel(x), x) + + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + m=[64, 128, 192], + ) + def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layout, m): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load( + x_ref, (i,), layout=layout, optimized=src_memory_space != plgpu.GMEM + ) + o_ref[i, ...] 
= x + + x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m) + np.testing.assert_array_equal(kernel(x), x) + + @parameterized.product( + src_memory_space=[plgpu.SMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + ) + def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space, layout): + self.skip_if_wg_semantics() + + m, k, n = 64, 128, 192 + key1, key2 = jax.random.split(jax.random.key(42), 2) + if layout == plgpu.Layout.WGMMA_ROW: + input_shape = (m,) + broadcast_dim = 0 + expand_dim = 1 + else: + input_shape = (k,) + broadcast_dim = 1 + expand_dim = 0 + a = jax.random.uniform(key1, shape=input_shape, dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + def kernel(x_ref, y_ref, o_ref): + x = plgpu.load(x_ref, (), layout=layout) + x = lax.broadcast_in_dim(x, (m, k), [broadcast_dim]) + + def compute(acc_ref): + plgpu.wgmma(acc_ref, x, y_ref) + return acc_ref[...] + + out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) + o_ref[...] = out + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), + in_specs=( + pl.BlockSpec(memory_space=src_memory_space), + plgpu.GPUBlockSpec( + transforms=( + plgpu.TilingTransform((8, 64)), + plgpu.SwizzleTransform(128), + ), + ), + ), + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + ) + + out_ref = ( + jnp.broadcast_to(jnp.expand_dims(a, axis=expand_dim), (m, k)) @ b + ) + np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) + def test_indexing_before_transpose(self): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem( @@ -632,10 +850,8 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM, - ) - f = pl.pallas_call( + out_spec = plgpu.GPUBlockSpec(memory_space=plgpu.SMEM) + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), in_specs=(in_spec,), @@ -647,8 +863,9 @@ def kernel(x_ref, o_ref, barrier_ref): np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0)) def test_copy_gmem_to_smem_in_run_scoped(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), ) @@ -665,8 +882,9 @@ def inner_body(scratch_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_doubled_sum(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), ) def kernel(x_ref, o_ref): @@ -675,26 +893,6 @@ def kernel(x_ref, o_ref): x = jnp.arange(128).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + x.sum()*2) - @parameterized.named_parameters( - ("rsqrt", jax.lax.rsqrt, ), - ("log", jax.lax.log, 5e-7), - ("exp", jax.lax.exp, ), - ("exp2", jax.lax.exp2, 5e-7), - ("logistic", jax.lax.logistic, ), - ("tanh", jax.lax.tanh, 5e-7), - ) - def test_approx_math_unary_op(self, unary_op, rtol=1e-7): - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - ) - def kernel(x_ref, o_ref): - o_ref[...] 
= unary_op(x_ref[...]) - - x = jnp.arange(128).astype(jnp.float32) / 128 - np.testing.assert_allclose(kernel(x), unary_op(x), rtol=rtol, atol=1e-5) - @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): eps = 1e-5 @@ -702,7 +900,7 @@ def test_layer_norm(self, input_factor): beta = 1.0 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def layer_norm(x_ref, o_ref): @@ -730,8 +928,9 @@ def layer_norm_np(x): np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5) def test_print(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, o_ref): @@ -744,16 +943,30 @@ def kernel(x_ref, o_ref): self.assertEqual(output(), "It works!\n") def test_print_wgmma_tiled_layout(self): + self.skip_if_wg_semantics() + shape = (128, 64) size = math.prod(shape) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + in_specs=[ + plgpu.GPUBlockSpec( + transforms=( + plgpu.TilingTransform((8, 32)), + plgpu.SwizzleTransform(128), + ) + ) + ], + ) def kernel(x_ref, o_ref): + del o_ref # Unused. pl.debug_print("prefix {}", x_ref[...]) - spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))) - x = jnp.arange(size, dtype=jnp.float32).reshape(shape) - f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) + x = jnp.arange(size, dtype=jnp.float32).reshape(shape) with self.capture_stdout() as get_output: - jax.block_until_ready(f(x)) + jax.block_until_ready(kernel(x)) output = get_output() results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output) @@ -763,8 +976,10 @@ def kernel(x_ref, o_ref): self.assertEqual(v, i * shape[1] + j) def test_print_scalar(self): + self.skip_if_wg_semantics() + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -778,8 +993,10 @@ def kernel(x_ref, o_ref): self.assertIn(f"x.sum() = {x.sum()}", output()) def test_print_scalar_array(self): + self.skip_if_wg_semantics() + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -793,10 +1010,12 @@ def kernel(x_ref, o_ref): self.assertIn(f"x.sum() = {x.sum() + 1}", output()) def test_print_array(self): + self.skip_if_wg_semantics() + in_shape = [2, 1, 64, 64] @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32), ) def kernel(x_ref, o_ref): @@ -811,7 +1030,7 @@ def kernel(x_ref, o_ref): def test_load_scalar(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], ) @@ -821,9 +1040,11 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)), jnp.full((128,), 10, dtype=jnp.int32)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_run_scoped(self, thread_semantics): - + def test_run_scoped(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) def kernel(x_ref, o_ref): def body(tmp_ref): self.assertEqual(tmp_ref.shape, (8, 128)) @@ -834,20 +1055,32 @@ def body(tmp_ref): self.assertEqual(tmp.shape, (8, 128)) o_ref[...] 
= tmp - inp = np.ones((8, 128), jnp.float32) - f = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + x = np.ones((8, 128), jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + def test_run_scoped_in_cond(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GPUMemorySpace.SMEM), ) - o = f(inp) - np.testing.assert_array_equal(o, inp + 1.0) + def kernel(x_ref_gmem, o_ref): + def scoped_kernel(barrier_ref): + plgpu.copy_gmem_to_smem(x_ref_gmem, o_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + + def branch(): + pl.run_scoped(scoped_kernel, plgpu.Barrier(num_arrivals=1)) + + jax.lax.cond(x_ref_gmem[0] % 2 == 0, branch, branch) + + x = jnp.full((256,), 1234, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x), x) def test_program_id(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -866,7 +1099,7 @@ def test_program_id_in_squashed_grid(self): # 3 CUDA grid dimensions. grid = (2, 3, 4, 5) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)), out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32), @@ -887,7 +1120,7 @@ def kernel(o_ref): def test_program_id_in_block_spec(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),), out_specs=pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)), out_shape=jax.ShapeDtypeStruct([2, 128], jnp.int32), @@ -901,7 +1134,7 @@ def kernel(x_ref, o_ref): def test_num_programs(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -916,17 +1149,18 @@ def kernel(o_ref): ) def test_swizzled_blockspec_shapes(self): + self.skip_if_wg_semantics() spec = plgpu.GPUBlockSpec( (128, 64), lambda *i: i, transforms=( - plgpu.TilingTransform((64, 64)), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[spec], out_specs=spec, out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), @@ -939,14 +1173,10 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) np.testing.assert_array_equal(kernel(x), x) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_array(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_array(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): # Equivalent to x_ref[...] + 2 + 3. 
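
The hunks above and below drop the per-test thread_semantics parameter in favor of the class-level self.pallas_call helper introduced earlier in this patch. A minimal sketch of that pattern, assuming only the GPUCompilerParams and LoweringSemantics names used in the patch (the class name below is illustrative, not part of the change):

import dataclasses

from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


class LaneSemanticsCalls:
  # Illustrative mixin: pins the lowering semantics once for every
  # pallas_call issued by a test class.
  LOWERING_SEMANTICS = plgpu.LoweringSemantics.Lane

  def pallas_call(self, *args, **kwargs):
    # Keep whatever compiler params the caller passed, but force the
    # class-level lowering semantics, mirroring the patched PallasTest helper.
    compiler_params = dataclasses.replace(
        kwargs.pop("compiler_params", plgpu.GPUCompilerParams()),
        lowering_semantics=self.LOWERING_SEMANTICS,
    )
    return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs)

Subclasses such as the warpgroup variants then only rebind LOWERING_SEMANTICS; the individual tests stay unchanged.
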
@@ -955,14 +1185,10 @@ def kernel(x_ref, o_ref): x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), x + 2 + 3) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_scalar(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_scalar(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): # Equivalent to 2 + 3. @@ -974,9 +1200,8 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, jnp.int32)) def test_fori_loop_dynamic_bounds(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), grid=(1,) ) @@ -989,16 +1214,10 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_tuple(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_tuple(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): def body(step, xs): @@ -1017,16 +1236,11 @@ def body(step, xs): kernel(), jnp.full([256], 3 * (0 + 1), jnp.int32) ) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_indexed_store(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_indexed_store(self, force_while): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, y_ref, o_ref): def body(idx, _): @@ -1039,17 +1253,11 @@ def body(idx, _): y = x + 1 np.testing.assert_array_equal(kernel(x, y), x + y) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_while_loop(self, thread_semantics): - if thread_semantics == plgpu.ThreadSemantics.Warpgroup: - self.skipTest("WG lowering does not support reduce_sum_p needed for this test") + def test_while_loop(self): + self.skip_if_wg_semantics() @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(x_ref, o_ref): o_ref[...] = jnp.zeros(o_ref.shape, dtype=jnp.int32) @@ -1070,8 +1278,10 @@ def body(acc): ) def test_while_loop_layout_mismatch(self): + self.skip_if_wg_semantics() # while and conditional are not yet supported. 
+ @functools.partial( - pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(o_ref): def cond(acc): @@ -1090,12 +1300,9 @@ def body(acc): with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"): kernel() - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond(self, thread_semantics): + def test_cond(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): jax.lax.cond( @@ -1111,27 +1318,49 @@ def kernel(x_ref, o_ref): self.assertIn("acc % 2", output()) - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond_returning_array(self, thread_semantics): + def test_cond_returning_array(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): - acc = _sum_same_dtype(x_ref[...]) + acc_sum = _sum_same_dtype(x_ref[...]) acc2, acc = jax.lax.cond( - acc % 2 == 0, - lambda: (acc * 2, acc), - lambda: (acc, acc * 2), + acc_sum % 2 == 0, + lambda: (acc_sum * 2, x_ref[...]), + lambda: (acc_sum, x_ref[...]), ) - o_ref[...] = jnp.broadcast_to(acc + acc2, o_ref.shape) + o_ref[...] = jnp.broadcast_to(_sum_same_dtype(acc) + acc2, o_ref.shape) x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) + def test_tile_slicing(self): + # Not testing with warpgroup semantics, because we want to enforce a layout. + self.skip_if_wg_semantics() + + shape = (256, 128) + block_spec = plgpu.GPUBlockSpec( + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + ) + @functools.partial( + self.pallas_call, + in_specs=[block_spec], + out_specs=block_spec, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.uint16), + ) + def kernel(x_ref, o_ref): + def sum_tiles(row, acc): + row_slice = pl.ds(row * 64, 64) + for col in range(128 // 64): + acc += x_ref[row_slice, pl.ds(col * 64, 64)] + return acc + acc = plgpu.layout_cast(jnp.zeros((64, 64), jnp.uint16), plgpu.Layout.WGMMA) + o_ref[...] = _fori_loop(False, 0, 256 // 64, sum_tiles, acc) + + x = jnp.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape) + y = x.reshape(256 // 64, 64, 128 // 64, 64).sum(axis=(0, 2), dtype=jnp.uint16) + np.testing.assert_array_equal(kernel(x), y) + def test_input_output_aliases(self): # Note that we're writing to the input pointer, which should alias b_ptr. def kernel(a_ref, b_ref): @@ -1139,7 +1368,7 @@ def kernel(a_ref, b_ref): a_ref[...] = jnp.ones_like(a_ref) a = np.zeros((64, 64), dtype=jnp.float32) - b = pl.pallas_call( + b = self.pallas_call( kernel, in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM), @@ -1149,6 +1378,8 @@ def kernel(a_ref, b_ref): np.testing.assert_array_equal(b, np.ones_like(a)) def test_slicing(self): + self.skip_if_wg_semantics() + left = upper = slice(None, 64) right = lower = slice(64, None) # We rotate the four quadrants of the input clockwise. 
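
Several hunks in this file switch the SMEM transform pair from TilingTransform((64, ...)) to TilingTransform((8, N)) combined with SwizzleTransform(128), where N is 128 bytes' worth of elements. A minimal identity-copy sketch of that combination, assuming a GPU supported by the surrounding tests (the helper name is illustrative, not part of the patch):

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def swizzled_identity(x):
  # 128 bytes' worth of elements: 64 for float16, 32 for float32.
  swizzle_elems = 128 // x.dtype.itemsize
  spec = plgpu.GPUBlockSpec(
      transforms=(
          plgpu.TilingTransform((8, swizzle_elems)),
          plgpu.SwizzleTransform(128),
      ),
  )

  @functools.partial(
      pl.pallas_call,
      in_specs=[spec],
      out_specs=spec,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )
  def kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]

  return kernel(x)

# Example (requires a supported GPU):
# swizzled_identity(jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128))
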
@@ -1160,21 +1391,20 @@ def rotate(src, dst): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) - f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) + f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) expected = np.empty_like(x) rotate(x, expected) np.testing.assert_array_equal(f(x), expected) def test_layout_cast(self, shape=(256, 64)): + # TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. + if jaxlib.version < (0, 5, 4): + self.skip_if_wg_semantics() + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), ) def kernel(o_ref): @@ -1183,7 +1413,52 @@ def kernel(o_ref): x = jnp.full(shape, 42.0, jnp.float32) np.testing.assert_array_equal(kernel(), x) + @parameterized.parameters(False, True) + def test_wgmma_transposed_layout(self, store_transposed): + """Tests that the result of wgmma can be store transposed using + the WGMMA_TRNASPOSED layout. + """ + + dtype = jnp.dtype(jnp.float16) + swizzle_elems = 128 // dtype.itemsize + shape = (128, 128) + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.SMEM( + shape, dtype, + transforms=( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ), + ) + ] + ) + def kernel(o_ref, smem): + iota = plgpu.broadcasted_iota( + dtype, o_ref.shape, 0, layout=plgpu.Layout.WGMMA + ) * o_ref.shape[0] + iota += plgpu.broadcasted_iota( + dtype, o_ref.shape, 1, layout=plgpu.Layout.WGMMA + ) + + smem_trns = plgpu.transpose_ref(smem, (1, 0)) + smem_trns[...] = plgpu.layout_cast(iota, plgpu.Layout.WGMMA_TRANSPOSED) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_trns if store_transposed else smem, o_ref) + + x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128)).T + if store_transposed: + with self.assertRaises(ValueError): + kernel() + else: + np.testing.assert_array_equal(kernel(), x) + def test_profiler(self): + self.skip_if_wg_semantics() # Transform inference fails. + def kernel(x_ref, o_ref): with jax.named_scope("add"): with jax.named_scope("load"): @@ -1193,7 +1468,7 @@ def kernel(x_ref, o_ref): o_ref[...] = o with tempfile.TemporaryDirectory() as tmpdir: x = jnp.arange(256).astype(jnp.float32) - y = pl.pallas_call( + y = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), compiler_params=plgpu.GPUCompilerParams( @@ -1221,20 +1496,13 @@ def kernel(x_ref, o_ref): (jnp.uint32, jnp.int32), (jnp.int32, jnp.uint32), ], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_bitcast_convert_type(self, dtypes, thread_semantics): + def test_bitcast_convert_type(self, dtypes): in_dtype, out_dtype = dtypes m, n = 16, 8 out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) - @functools.partial( - pl.pallas_call, - out_shape=out_shape, - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), - ) + @functools.partial(self.pallas_call, out_shape=out_shape) def convert(x_ref, y_ref): y_ref[...] 
= jax.lax.bitcast_convert_type(x_ref[...], out_shape) @@ -1243,17 +1511,140 @@ def convert(x_ref, y_ref): convert(x), jax.lax.bitcast_convert_type(x, out_dtype) ) + def test_optimization_barrier(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.optimization_barrier(x_ref[...]) + + x = jax.lax.iota(jnp.float32, 128) + np.testing.assert_array_equal(kernel(x), x) + + def test_optimization_barrier_multiple_inputs(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + x, y = lax.optimization_barrier([x_ref[...], y_ref[...]]) + o_ref[...] = x + y + + x = jax.lax.iota(jnp.float32, 128) + y = jax.lax.iota(jnp.float32, 128) * 3 + np.testing.assert_array_equal(kernel(x, y), x + y) + + def test_warp_specialization_axis_index(self): + if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane: + self.skipTest("Test only works on Lane semantics") + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32)) + def kernel(y_ref): + def scope(ones_smem_ref, threes_smem_ref): + # Prepare data to copy. + ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32) + threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3 + plgpu.commit_smem() + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + # We cannot load/store inside of core_map, so we issue async + # copies instead to produce a testable result. + @pl.when(warp_id == 1) + def _(): + plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1]) + @pl.when(warp_id == 3) + def _(): + plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2]) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + plgpu.SMEM((1, 128), jnp.int32), + plgpu.SMEM((1, 128), jnp.int32) + ) + result = kernel() + expected = jnp.stack((jnp.ones((128,), jnp.int32), + jnp.ones((128,), jnp.int32) * 3), axis=0) + np.testing.assert_array_equal(result, expected) + + def test_warp_mesh_errors_when_closing_over_array(self): + if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane: + self.skipTest("Test only works on Lane semantics") + # We currently do not allow closing over arrays when mapping over + # a mesh, since we would need to present a view of the array local + # to each warp. + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32), + scratch_shapes=[plgpu.SMEM((32, 32), jnp.float32)]) + def kernel(out_ref, smem_ref): + arr = jnp.ones((32, 32), dtype=jnp.float32) + @pl.core_map(warp_mesh) + def _(): + smem_ref[...] = arr + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, out_ref) + plgpu.wait_smem_to_gmem(0) + with self.assertRaisesRegex( + mgpu_lowering.LoweringError, + "Can only close over scalars and Refs when using core_map with " + "WarpMesh", + ): + kernel() + + +class PallasCallWGTest( + PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + def test_missing_primitive_lowerings_are_tracked(self): + # This test is a way to keep track of which primitives need to be adapted + # to using warpgroup semantics. Once the set is empty, we should be able to + # enable warpgroup semantics by default (assuming we haven't overspecialized + # lowerings). 
+ rules = mgpu_lowering.mosaic_lowering_rules + wg_wg_lowered_primitives = set( + rules[(plgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup)]) + lane_wg_lowered_primitives = set(rules[ + (plgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup)]) + + actual_missing_primitives = (lane_wg_lowered_primitives - + wg_wg_lowered_primitives) + expected_missing_primitives = { + mgpu_primitives.inline_mgpu_p, + mgpu_primitives.broadcasted_iota_p, + mgpu_primitives.load_p, + lax.slice_p, + pallas_core.core_map_p, + } + + # TODO(dasenov): Remove this after the minimal jaxlib version is 0.5.4. + if jaxlib.version < (0, 5, 4): + expected_missing_primitives.add(mgpu_primitives.layout_cast_p) + + self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) + class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) def test_fori_loop_accumulator(self, force_while): - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + if force_while: + # Layout inference and lowering for 'while' are not yet implemented for + # warpgroup semantics. + self.skip_if_wg_semantics() + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + else: + transforms = () @functools.partial( - pl.pallas_call, - in_specs=[plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)], + self.pallas_call, + in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)], out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), - out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)), + out_specs=plgpu.GPUBlockSpec((64, 64)), ) def kernel(i_ref, o_ref): def scope(acc_ref): @@ -1263,7 +1654,8 @@ def scope(acc_ref): acc_ini = jnp.ones((64, 64), dtype=jnp.float16) np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) - def test_realistic_matmul(self): + @parameterized.product(lhs_transpose=[False, True], rhs_transpose=[False, True]) + def test_realistic_matmul(self, lhs_transpose, rhs_transpose): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1273,7 +1665,11 @@ def test_realistic_matmul(self): m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n def kernel(a_ref, b_ref, o_ref, acc_ref): # Make sure tiling does not alter the shape of references + if lhs_transpose: + a_ref = plgpu.transpose_ref(a_ref, (1, 0)) assert a_ref.shape == (tile_m, tile_k) + if rhs_transpose: + b_ref = plgpu.transpose_ref(b_ref, (1, 0)) assert b_ref.shape == (tile_k, tile_n) assert o_ref.shape == acc_ref.shape == (tile_m, tile_n) plgpu.wgmma(acc_ref, a_ref, b_ref) @@ -1284,37 +1680,66 @@ def _epilogue(): plgpu.wgmma_wait(1) # We don't await the last WGMMA, hence delay_release=1 key1, key2 = jax.random.split(jax.random.key(42), 2) - a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) - b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + a_shape = (k, m) if lhs_transpose else (m, k) + a = jax.random.uniform(key1, shape=a_shape, dtype=dtype) + b_shape = (n, k) if rhs_transpose else (k, n) + b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) - res = pl.pallas_call( + if lhs_transpose: + lhs_spec = pl.BlockSpec( + (tile_k, tile_m), + lambda m, n, k: (k, m), + ) + else: + lhs_spec = pl.BlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + ) + if rhs_transpose: + rhs_spec = pl.BlockSpec( + (tile_n, tile_k), + lambda m, n, k: (n, k), + ) + else: + rhs_spec = pl.BlockSpec( + (tile_k, 
tile_n), + lambda m, n, k: (k, n), + ) + out_spec = pl.BlockSpec( + (tile_m, tile_n), + lambda m, n, k: (m, n), + ) + + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + lhs_spec = plgpu.GPUBlockSpec( + lhs_spec.block_shape, + lhs_spec.index_map, + transforms=( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ), + ) + rhs_spec = plgpu.GPUBlockSpec( + rhs_spec.block_shape, + rhs_spec.index_map, + transforms=( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ), + ) + out_spec = plgpu.GPUBlockSpec( + out_spec.block_shape, + out_spec.index_map, + transforms=( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ), + ) + + res = self.pallas_call( kernel, - in_specs=[ - plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda m, n, k: (m, k), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda m, n, k: (k, n), - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - ], - out_specs=plgpu.GPUBlockSpec( - (tile_m, tile_n), - lambda m, n, k: (m, n), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + in_specs=[lhs_spec, rhs_spec], + out_specs=out_spec, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], grid=(grid_m, grid_n, grid_k), @@ -1324,10 +1749,16 @@ def _epilogue(): delay_release=1, ), )(a, b) - np.testing.assert_allclose(res, a @ b, rtol=1e-3) + np.testing.assert_allclose( + res, + (a.T if lhs_transpose else a) @ (b.T if rhs_transpose else b), + rtol=1e-3, + ) @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): + self.skip_if_wg_semantics() + # TensorCores can only fuse transposes of 16-bit values, and RHS # is expected to be column major by default. 
rhs_transpose = jnp.dtype(dtype).itemsize != 2 @@ -1349,17 +1780,15 @@ def scope(acc_ref): b_shape = b_shape[::-1] b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) - rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),) - if rhs_transpose: - rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),) - res = pl.pallas_call( + rhs_transforms = (plgpu.TilingTransform((8, elems_128b)),) + res = self.pallas_call( kernel, in_specs=[ plgpu.GPUBlockSpec( (64, 128), lambda i, j: (i, j), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -1388,14 +1817,15 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) @@ -1411,20 +1841,24 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + else: + transforms = () + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((64, 192), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), )(a, b, i) np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) def test_wgmma_sliced_ref(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. 
+ def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0]) @@ -1436,30 +1870,23 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(2, 64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16) - res = pl.pallas_call( + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (2, 64, 128), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (2, 128, 192), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) def test_wgmma_sliced_acc(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + swizzle = 128 elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize def kernel(a_ref, b_ref, o_ref): @@ -1472,33 +1899,66 @@ def scope(acc_ref): key1, key2 = jax.random.split(jax.random.key(42), 2) a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16) - res = pl.pallas_call( + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + res = self.pallas_call( kernel, in_specs=[ plgpu.GPUBlockSpec( - (64, 128), - lambda i, j: (i, j), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), + (64, 128), lambda *ij: ij, transforms=transforms ), plgpu.GPUBlockSpec( - (128, 128), - lambda *i: i, - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), + (128, 128), lambda *ij: ij, transforms=transforms ), ], - out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i), + out_specs=plgpu.GPUBlockSpec((64, 128), lambda *ij: ij), out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32), grid=(1, 1), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) +class PallasCallSm90AWGTest( + PallasCallSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class PallasCallSm100ATest(PallasSm100ATest): + + def test_tmem_alloc(self): + + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 128), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32), + ], + num_threads=1, + thread_name="x", + ) + def kernel(y_ref, tmem_ref, smem_ref): + # Issue a write so the TMEM load is not DCE'd. + smem_ref[...] = tmem_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + # Test that this runs without errors. + jax.block_until_ready(kernel()) + + +class PallasCallSm100AWGTest( + PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... 
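For readers following the migration in the hunks above: the rewritten tests allocate outputs, scratch buffers (scratch_shapes=[...]) and named grid/thread axes through a single kernel entry point instead of pl.core_map plus pl.run_state. Below is a minimal sketch of that calling pattern, assuming the test helper self.kernel simply forwards to plgpu.kernel and that the keyword names (out_shape, num_threads, thread_name) match the added test code; it mirrors the test_multiple_wg case further down and is illustrative, not a definitive API reference.

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas import mosaic_gpu as plgpu

@functools.partial(
    plgpu.kernel,
    out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32),
    num_threads=2,      # two warpgroups, each acting as one logical thread
    thread_name="wg",   # queried with jax.lax.axis_index("wg") in the body
)
def wg_index_kernel(o_ref):
  # Each warpgroup writes its own row of the output.
  wg = jax.lax.axis_index("wg")
  o_ref[wg] = jnp.broadcast_to(wg, (128,))

np.testing.assert_array_equal(
    wg_index_kernel(), np.repeat(np.arange(2), 128).reshape(2, 128)
)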
+ + class PipelineTest(PallasTest): def test_pipeline_mode(self): @@ -1520,13 +1980,13 @@ def body(x_ref, y_ref, o_ref): @jax.jit def vadd(x, y): - return pl.pallas_call( - body, - out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), - in_specs=in_specs, - out_specs=out_specs, - grid=data_size // block_size, - )(x, y) + return self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + in_specs=in_specs, + out_specs=out_specs, + grid=data_size // block_size, + )(x, y) with self.assertRaisesRegex(Exception, "Pipeline mode is not supported"): vadd(x, y) @@ -1588,7 +2048,7 @@ def body(step, _): plgpu.wait_smem_to_gmem(0) x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1597,38 +2057,45 @@ def body(step, _): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) - @parameterized.parameters( - ((),), - ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),), + @parameterized.product( + transforms=( + (), + (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)), + ), + repeats=(1, 3), ) - def test_emit(self, transforms): + def test_emit(self, transforms, repeats): + if transforms: + self.skip_if_wg_semantics() + num_steps = 4 def kernel(x_gmem, o_gmem): - plgpu.emit_pipeline( - kernel_body, - in_specs=[ - plgpu.GPUBlockSpec( - (64, 64), lambda i: (0, i), transforms=transforms - ) - ], - out_specs=[ - plgpu.GPUBlockSpec( - (64, 64), lambda i: (0, i), transforms=transforms - ) - ], - grid=(num_steps,), - max_concurrent_steps=2, - )(x_gmem, o_gmem) + for _ in range(repeats): + plgpu.emit_pipeline( + kernel_body, + in_specs=[ + plgpu.GPUBlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], + out_specs=[ + plgpu.GPUBlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): # +1 for the indexing done by ``emit_pipeline`. self.assertLen(x_smem.transforms, len(transforms) + 1) o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(64 * num_steps * 64) x = x.reshape(-1, num_steps * 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1647,7 +2114,7 @@ def kernel(x_gmem, o_gmem): grid=(), )(x_gmem, o_gmem) - def nested_kernel(x_gmem, o_gmem): + def nested_kernel(_, x_gmem, o_gmem): plgpu.emit_pipeline( nested_kernel_body, in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], @@ -1656,12 +2123,12 @@ def nested_kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def nested_kernel_body(x_smem, o_smem): + def nested_kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps * 16) x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1681,12 +2148,12 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] 
+ 1.0 x = jnp.arange(32 * num_steps * 16) x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1714,12 +2181,12 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1729,25 +2196,32 @@ def kernel_body(x_smem, o_smem): y = x + 1.0 np.testing.assert_array_equal(kernel_fn(x), y) - def test_emit_with_2d_grid(self): + @parameterized.product(static=[False, True], short=[False, True]) + def test_emit_with_2d_grid(self, static, short): num_steps1 = 4 num_steps2 = 5 + if short: + num_steps1 = num_steps2 = 1 def kernel(x_gmem, o_gmem): + grid = (num_steps1, num_steps2) + if static: + grid = jax.tree.map(jnp.asarray, grid) + plgpu.emit_pipeline( kernel_body, in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], - grid=(num_steps1, num_steps2), + grid=grid, max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1756,9 +2230,17 @@ def kernel_body(x_smem, o_smem): np.testing.assert_array_equal(kernel_fn(x), x + 1.0) +class PipelineWGTest( + PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + class PipelineSm90ATest(PallasSm90ATest): def test_realistic_matmul(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. 
+ dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1768,8 +2250,15 @@ def test_realistic_matmul(self): tile_k = elems_128b m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + def kernel(a_gmem, b_gmem, o_smem, acc): - def kernel_body(a_smem, b_smem): + def kernel_body(_, a_smem, b_smem): assert a_smem.shape == (tile_m, tile_k) assert b_smem.shape == (tile_k, tile_n) plgpu.wgmma(acc, a_smem, b_smem) @@ -1781,21 +2270,11 @@ def kernel_body(a_smem, b_smem): kernel_body, in_specs=[ plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda k: (pid_m, k), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda k: (k, pid_n), - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.GPUBlockSpec( + (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms + ), ], grid=(grid_k,), max_concurrent_steps=2, @@ -1808,19 +2287,14 @@ def kernel_body(a_smem, b_smem): a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=plgpu.GMEM), - pl.BlockSpec(memory_space=plgpu.GMEM) + pl.BlockSpec(memory_space=plgpu.GMEM), ], out_specs=plgpu.GPUBlockSpec( - (tile_m, tile_n), - lambda m, n: (m, n), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), + (tile_m, tile_n), lambda m, n: (m, n), transforms=transforms ), out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], @@ -1829,17 +2303,23 @@ def kernel_body(a_smem, b_smem): np.testing.assert_array_equal(res, a @ b) +class PipelineSm90AWGTest( + PipelineSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + class WarpSpecializedPipelineTest(PallasTest): - @parameterized.product(m=[512], n=[512], + @parameterized.product(m=[512], n=[512], repeats=[1, 3], manual_consumed_barriers=[False, True]) - def test_pipelined_copy(self, m, n, manual_consumed_barriers): + def test_pipelined_copy(self, m, n, repeats, manual_consumed_barriers): + self.skip_if_wg_semantics() # Times out! + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) - o = jnp.zeros((m, n), dtype=jnp.float16) blk_m = blk_n = 64 - o_last_block = jnp.zeros((blk_m, blk_n), dtype=jnp.float16) - def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): + def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers): # TODO(justinfu): Have each wg compute a separate slice # after multiple-indexers are supported. # This is currently a race, but the values written are the same. 
@@ -1848,99 +2328,109 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): if manual_consumed_barriers: [x_barrier] = consumed_barriers plgpu.barrier_arrive(x_barrier) - block_spec = plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[], - ) - pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( - copy_kernel, - grid=(m // blk_m, n // blk_n), - memory_registers=40, - max_concurrent_steps=2, - num_compute_wgs=2, - wg_axis="wg", - manual_consumed_barriers=manual_consumed_barriers, - in_specs=[block_spec], - out_specs=[block_spec, - # Create an index-invariant output. - plgpu.GPUBlockSpec(block_shape=(blk_m, blk_n), - index_map=lambda i, j: (0, 0)) - ], - ) - mesh = plgpu.GPUMesh(grid=(1,), num_threads=3, axis_names=("_", "wg")) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) + + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) + def body(*gmem_refs): + pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( + copy_kernel, + grid=(m // blk_m, n // blk_n), + memory_registers=40, + max_concurrent_steps=2, + num_compute_wgs=2, + wg_axis="wg", + manual_consumed_barriers=manual_consumed_barriers, + in_specs=[spec], + out_specs=[ + spec, + # Create an index-invariant output. + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (0, 0) + ), + ], ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, o, o_last_block): - _, out, out_last = pl.run_state(run)((x, o, o_last_block)) - return (out, out_last) - out, out_last_block = run_function(x, o, o_last_block) + for _ in range(repeats): + pipeline(*gmem_refs) # Make sure we can run the pipeline multiple times + kernel = self.kernel( + body, + out_shape=( + jax.ShapeDtypeStruct((m, n), jnp.float16), + jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float16), + ), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=3, + thread_name="wg", + ) + out, out_last_block = kernel(x) np.testing.assert_array_equal(out, x) np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) - def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): + @parameterized.product( + m=[256, 64], n=[256, 64], num_compute_wgs=[1, 2], static=[False, True] + ) + def test_elementwise_add(self, m, n, num_compute_wgs, static): + self.skip_if_wg_semantics() # Crashes! + blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) - o = jnp.zeros((m, n), dtype=jnp.float32) + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) - def tiled_add_kernel(x_smem, y_smem, o_smem): + def tiled_add_kernel(_, x_smem, y_smem, o_smem): # TODO(justinfu): Have each wg compute a separate slice # after multiple-indexers are supported. # This is currently a race, but the values written are the same. o_smem[...] = x_smem[...] + y_smem[...] 
- pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( - tiled_add_kernel, - grid=(m // blk_m, n // blk_n), - max_concurrent_steps=2, - num_compute_wgs=num_compute_wgs, - memory_registers=40, - wg_axis="wg", - in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - ], - out_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[])], - ) - mesh = plgpu.GPUMesh( - grid=(1,), num_threads=num_compute_wgs + 1, axis_names=("_", "wg") + def pipeline(*gmem_refs): + grid = (m // blk_m, n // blk_n) + if not static: + grid = jax.tree.map(jnp.asarray, grid) + return mgpu_pipeline.emit_pipeline_warp_specialized( + tiled_add_kernel, + grid=grid, + max_concurrent_steps=2, + num_compute_wgs=num_compute_wgs, + memory_registers=40, + wg_axis="wg", + in_specs=[spec, spec], + out_specs=[spec], + )(*gmem_refs) + + kernel = self.kernel( + pipeline, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=num_compute_wgs + 1, + thread_name="wg", ) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) - ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, y, o): - _, _, out = pl.run_state(run)((x, y, o)) - return out - out = run_function(x, y, o) - reference = x + y - np.testing.assert_allclose(out, reference, atol=1e-4) + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) + y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) + np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4) def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): + self.skip_if_wg_semantics() # `plgpu.layout_cast` is not supported. + blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - acc_init = jnp.zeros((blk_m, blk_n), dtype=jnp.float32) - def _scoped(acc_smem, x_gmem, acc_gmem): + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), + scratch_shapes=[ + plgpu.SMEM((blk_m, blk_n), jnp.float32), + ], + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=num_compute_wgs + 1, + thread_name="wg", + ) + def kernel(x_gmem, acc_gmem, acc_smem): def _compute_thread(): # Cast the init value to the same layout as x_smem, so the pipeline loop # carry has a constant signature. @@ -1958,7 +2448,7 @@ def _compute_thread(): plgpu.copy_smem_to_gmem(acc_smem, acc_gmem) plgpu.wait_smem_to_gmem(0) - def tiled_acc_kernel(x_smem, carry): + def tiled_acc_kernel(_, x_smem, carry): o_carry, = carry new_carry = x_smem[...] 
+ o_carry return (new_carry,) @@ -1972,77 +2462,66 @@ def tiled_acc_kernel(x_smem, carry): wg_axis="wg", carry_coroutine=_compute_thread, in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) ], out_specs=[], ) pipeline(x_gmem) - mesh = plgpu.GPUMesh( - grid=(1,), - num_threads=num_compute_wgs + 1, - axis_names=("_", "wg",), - ) - def run(refs): - x_ref, acc_ref = refs - @pl.core_map(mesh) - def _kernel_entry(): - pl.run_scoped( - functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref), - plgpu.SMEM((blk_m, blk_n), jnp.float32) - ) - @jax.jit - def run_function(x, acc): - _, out_acc = pl.run_state(run)((x, acc)) - return out_acc - out_acc = run_function(x, acc_init) + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) ref = jnp.sum(jnp.stack(np.split(x, m // blk_m, axis=0)), axis=0) ref = jnp.sum(jnp.stack(np.split(ref, n // blk_n, axis=1)), axis=0) - np.testing.assert_allclose(out_acc, ref, atol=1e-4) + np.testing.assert_allclose(kernel(x), ref, atol=1e-4) + +class WarpSpecializedPipelineWGTest( + WarpSpecializedPipelineTest, + lowering_semantics=plgpu.LoweringSemantics.Warpgroup, +): + ... -class CoreMapTest(PallasTest): + +class CoreMapTest(PallasTest, jtu.CudaArchSpecificTest): def test_multiple_wg(self): - mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",)) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - wg_idx = jax.lax.axis_index("y") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + num_threads=2, + thread_name="wg", + ) + def kernel(o_ref): + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) + np.testing.assert_array_equal( - f(), np.repeat(np.arange(2), 128).reshape(2, 128) + kernel(), np.repeat(np.arange(2), 128).reshape(2, 128) ) def test_multiple_wg_with_grid(self): - mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg")) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - xy_idx = jax.lax.axis_index(("x", "y")) - yx_idx = jax.lax.axis_index(("y", "x")) - wg_idx = jax.lax.axis_index("wg") - num_wgs = jax.lax.psum(1, "wg") - y_ref[xy_idx, wg_idx] = jnp.broadcast_to( - yx_idx * num_wgs + wg_idx, (128,) - ) - y_init = jnp.zeros((4, 2, 128), np.int32) - return inner(y_init) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((4, 2, 128), np.int32), + grid=(2, 2), + grid_names=("x", "y"), + num_threads=2, + thread_name="wg", + ) + def kernel(o_ref): + xy_idx = jax.lax.axis_index(("x", "y")) + yx_idx = jax.lax.axis_index(("y", "x")) + wg_idx = jax.lax.axis_index("wg") + num_wgs = jax.lax.psum(1, "wg") + o_ref[xy_idx, wg_idx] = jnp.broadcast_to( + yx_idx * num_wgs + wg_idx, (128,) + ) + np.testing.assert_array_equal( - f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) + kernel(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) ) def test_multiple_wg_with_squashed_grid(self): @@ -2053,104 +2532,353 @@ def test_multiple_wg_with_squashed_grid(self): y_dim = 5 z_dim = 7 num_threads = 2 - mesh = plgpu.GPUMesh(grid=(b, x_dim, y_dim, z_dim), - num_threads=num_threads, - axis_names=("b", "x", "y", "z", "wg")) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def _(): - 
b_idx = jax.lax.axis_index("b") - x_idx = jax.lax.axis_index("x") - y_idx = jax.lax.axis_index("y") - z_idx = jax.lax.axis_index("z") - wg_idx = jax.lax.axis_index("wg") - bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) - y_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( - bxyzw_idx, (128,) - ) - y_init = jnp.zeros((b, x_dim, y_dim, z_dim, num_threads, 128), np.int32) - return inner(y_init) - result = f()[:, :, :, :, :, 0] + @functools.partial( + self.kernel, + out_shape=jnp.zeros( + (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32 + ), + grid=(b, x_dim, y_dim, z_dim), + grid_names=("b", "x", "y", "z"), + num_threads=num_threads, + thread_name="wg", + ) + def kernel(o_ref): + b_idx = jax.lax.axis_index("b") + x_idx = jax.lax.axis_index("x") + y_idx = jax.lax.axis_index("y") + z_idx = jax.lax.axis_index("z") + wg_idx = jax.lax.axis_index("wg") + bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) + o_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( + bxyzw_idx, (128,) + ) + + result = kernel()[:, :, :, :, :, 0] ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape( - result.shape) + result.shape + ) np.testing.assert_array_equal(result, ref) - def test_cross_wg_barrier(self): - mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) + self.skip_if_wg_semantics() # Times out! - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - def scoped(barrier): - plgpu.barrier_arrive(barrier) - plgpu.barrier_wait(barrier) - wg_idx = jax.lax.axis_index("wg") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - # Each warpgroup is a single logical thread! - pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) - np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128)) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + # Each warpgroup is a single logical thread! + scratch_shapes=[plgpu.Barrier(num_arrivals=2)], + num_threads=2, + thread_name="wg", + ) + def kernel(o_ref, barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) + + np.testing.assert_array_equal( + kernel(), np.repeat([0, 1], 128).reshape(2, 128) + ) + + def test_cluster(self): + self.skip_if_wg_semantics() # Needs debug_print in the MGPU dialect. + + @functools.partial( + self.kernel, + out_shape=jnp.zeros(128, np.int32), + grid=(2,), + grid_names=("x",), + cluster=(2,), + cluster_names=("cluster",), + ) + def kernel(ref): + block_idx = jax.lax.axis_index("x") + cluster_idx = jax.lax.axis_index("cluster") + pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx) + + ref[...] = ref[...] + + with self.capture_stdout() as output: + jax.block_until_ready(kernel()) + self.assertEqual( + set(output().splitlines()), + { + "block: 0 cluster: 0", + "block: 1 cluster: 0", + "block: 0 cluster: 1", + "block: 1 cluster: 1", + }, + ) + + def test_realistic_matmul_with_cluster(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + self.skip_unless_sm90a() # Requires WGMMA. + + dtype = jnp.float16 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + grid_m, grid_k, grid_n = 132, 10, 32 + # TODO(slebedev): Remove ``grid_tile_n`` to simplify the test. 
+ grid_tile_n = 4 + assert grid_n % grid_tile_n == 0 + cluster_m = 2 + cluster_n = 2 + cluster_tile_n = min(cluster_n, grid_tile_n) + tile_m = tile_n = 128 + assert tile_m % elems_128b == 0 + tile_k = elems_128b + m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + + max_concurrent_steps = 2 + delay_release = 1 + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + scratch_shapes=[ + plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), + dtype, + transforms=transforms, + ), + plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), + dtype, + transforms=transforms, + ), + plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + plgpu.ACC((tile_m, tile_n), jnp.float32), + plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + plgpu.ClusterBarrier( + collective_axes=(("x", "z"), "y"), + num_barriers=max_concurrent_steps, + ), + ], + grid=(grid_tile_n, grid_m, grid_n // grid_tile_n), + grid_names=("tile_n", "m", "n"), + cluster=(cluster_tile_n, cluster_m, cluster_n // cluster_tile_n), + cluster_names=("x", "y", "z"), + ) + def kernel( + a_gmem, + b_gmem, + o_gmem, + a_smem, + b_smem, + o_smem, + acc, + barrier, + cluster_barrier, + ): + m_slice = pl.ds(lax.axis_index("m") * tile_m, tile_m) + n_slice = pl.ds( + (lax.axis_index("tile_n") + lax.axis_index("n") * grid_tile_n) + * tile_n, + tile_n, + ) + + def fetch(step, slot): + if not isinstance(slot, int): # Skip in initialization. + plgpu.barrier_arrive(cluster_barrier.at[slot]) + plgpu.barrier_wait(cluster_barrier.at[slot]) + + k_slice = pl.ds(step * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], + a_smem.at[slot], + barrier.at[slot], + collective_axes=("x", "z"), + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], + b_smem.at[slot], + barrier.at[slot], + collective_axes="y", + ) + + # Initialize the pipeline. + for slot in range(min(max_concurrent_steps, grid_k)): + fetch(slot, slot) + + def body(step, _): + slot = step % max_concurrent_steps + plgpu.barrier_wait(barrier.at[slot]) + + plgpu.wgmma(acc, a_smem.at[slot], b_smem.at[slot]) + plgpu.wgmma_wait(delay_release) + + fetch_step = step + (max_concurrent_steps - delay_release) + fetch_slot = lax.rem(fetch_step, max_concurrent_steps) + jax.lax.cond( + lax.bitwise_and(step >= delay_release, fetch_step < grid_k), + lambda: fetch(fetch_step, fetch_slot), + lambda: None, + ) + return () + + jax.lax.fori_loop(0, grid_k, body, ()) + + # Finalize the pipeline. + o_smem[...] = acc[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) + b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + np.testing.assert_array_equal(kernel(a, b), a @ b) + + +class CoreMapWGTest( + CoreMapTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class PrettyPrintingTest(PallasTest): + + def test_load(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,)) + o_ref[i, ...] 
= x + + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((2, 128), jnp.float32))) + + def test_copy_primitives(self): + num_steps = 4 + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_gmem, o_gmem): + # ``plgpu.emit_pipeline`` is implemented in terms of async copy and + # synchronization primitives. + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))], + out_specs=[ + pl.BlockSpec( + (64, 64), + lambda i: (0, i), + ) + ], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((64, 64), jnp.float32))) + + def test_wgmma(self): + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), + in_specs=[ + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), + ], + ) + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref[...], b_ref) + return acc_ref[...] + + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32)) + + _ = str( + jax.make_jaxpr(kernel)( + jax.ShapeDtypeStruct((64, 128), jnp.float16), + jax.ShapeDtypeStruct((128, 192), jnp.float16), + ) + ) class ExamplesTest(PallasTest): # Basic def test_stage0(self): - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial(self.kernel, out_shape=x) + def kernel(l_ref, r_ref, o_ref): o_ref[...] = l_ref[...] + r_ref[...] - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x)(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Multi-block kernels def test_stage1(self): row_block = 64 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice] - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Async copies def test_stage3(self): row_block, col_block = 64, 128 - def body(l_ref, r_ref, o_ref): + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), + scratch_shapes=[ + *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), + plgpu.Barrier(num_arrivals=2), + ], + grid=(2,), + grid_names=("rows",), + ) + def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) - def scoped(l_smem, r_smem, o_smem, barrier): - plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) - plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) - plgpu.barrier_wait(barrier) - o_smem[...] = l_smem[...] + r_smem[...] 
- plgpu.commit_smem() - plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) - plgpu.wait_smem_to_gmem(0) - pl.run_scoped( - scoped, - *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), - plgpu.Barrier(num_arrivals=2), - ) + plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) + plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) + plgpu.barrier_wait(barrier) + o_smem[...] = l_smem[...] + r_smem[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) + plgpu.wait_smem_to_gmem(0) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Pipelining def test_stage4(self): row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") block = pl.BlockSpec((row_block, col_block), lambda c: (r, c)) @@ -2161,20 +2889,25 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Transforms def test_stage5(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") block = plgpu.GPUBlockSpec( (row_block, col_block), lambda c: (r, c), - transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)), + transforms=(plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)), ) plgpu.emit_pipeline( compute, @@ -2183,26 +2916,64 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) + + def test_semaphore_lowering(self): + # This is a smoke test until we add support for lowering of semaphore ops. + def body(i_ref1, i_ref2, o_ref, sem_ref): + del i_ref2 # Only here to have a different number of inputs and outputs. + assert sem_ref.shape == (4,) + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + o_ref[...] = i_ref1[...] 
+ x = jnp.arange(128, dtype=jnp.float32).reshape((128,)) + kernel = self.pallas_call( + body, + out_shape=x, + scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))], + ) + text = jax.jit(kernel).lower(x, x).as_text() + self.assertIn( + r"output_operand_aliases =" + r" [#stablehlo.output_operand_alias]", + text, + ) + self.assertIn( + r"(tensor<128xf32>, tensor<128xf32>, tensor<4xi32>) ->" + r" (tensor<128xf32>, tensor<4xi32>)", + text, + ) + + +class ExamplesWGTest( + ExamplesTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... class ExamplesSm90ATest(PallasSm90ATest): # WGMMA def test_stage6(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + m_block = n_block = 64 k_block = 32 - def body(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2, 2), grid_names=("m", "n") + ) + def kernel(l_ref, r_ref, o_ref): + def compute(_, l_smem, r_smem, o_smem): def do_wgmma(acc_ref): plgpu.wgmma(acc_ref, l_smem, r_smem) return acc_ref[...] o_smem[...] += pl.run_scoped(do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16)) - m, n = lax.axis_index("m"), lax.axis_index("n") - lo_transforms = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)) - r_transforms = (plgpu.TilingTransform((32, 32)), plgpu.SwizzleTransform(64)) + m = lax.axis_index("m") + n = lax.axis_index("n") + lo_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) + r_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) plgpu.emit_pipeline( compute, grid=(l_ref.shape[1] // k_block,), @@ -2211,12 +2982,16 @@ def do_wgmma(acc_ref): out_specs=[plgpu.GPUBlockSpec((m_block, n_block), lambda k: (m, n), transforms=lo_transforms)], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2, 2), axis_names=("m", "n"))(x, x) - np.testing.assert_allclose(out, x @ x) + np.testing.assert_allclose(kernel(x, x), x @ x) # TODO(apaszke): Clusters and multicast +class ExamplesSm90AWGTest( + ExamplesSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... 
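A recurring change in the hunks above: the leading tiling dimension paired with a 128-byte swizzle is reduced to 8 (for example (64, elems_128b) or (elems_128b, elems_128b) becomes (8, elems_128b)), and the transforms are attached only when lowering with lane-level semantics (plgpu.LoweringSemantics.Lane); under warpgroup semantics the specs are left untransformed so the lowering can pick its own layouts. A condensed sketch of that pattern, with the helper name smem_spec and the hard-coded fp16 element size being illustrative only:

import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu

def smem_spec(block_shape, index_map, *, lane_semantics):
  # A 128-byte swizzle spans 64 contiguous fp16 elements.
  elems_128b = 128 // jnp.dtype(jnp.float16).itemsize
  transforms = ()
  if lane_semantics:
    transforms = (
        plgpu.TilingTransform((8, elems_128b)),  # 8-row, 128-byte-wide tiles
        plgpu.SwizzleTransform(128),
    )
  return plgpu.GPUBlockSpec(block_shape, index_map, transforms=transforms)

# Usage mirrors the specs built in the tests above, e.g. for an fp16 LHS tile:
lhs_spec = smem_spec((64, 128), lambda i, j: (i, j), lane_semantics=True)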
+ + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 0fc375bf64a1..709828186480 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -30,6 +30,7 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu +from jax._src.pallas import pallas_call from jax.experimental import pallas as pl from jax.interpreters import partial_eval as pe import jax.numpy as jnp @@ -47,21 +48,18 @@ plgpu_triton = None pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.extra.numpy as hnp import hypothesis.strategies as hps + # There are many inherited redefinitions of _ # ruff: noqa: F811 jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=50) -use_mosaic_gpu = jax.config.read("jax_pallas_use_mosaic_gpu") +use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value intx = dtypes.canonicalize_dtype(jnp.int64) floatx = dtypes.canonicalize_dtype(jnp.float64) @@ -187,7 +185,7 @@ def select_n_strategy( else: pred_dtype = np.int32 pred = draw(arrays(shape=pred_shape, dtype=pred_dtype, - elements=allowed_elements)) + elements=allowed_elements)) cases = ( draw( arrays(shape=case_shape_dtype.shape, dtype=case_shape_dtype.dtype) @@ -203,7 +201,7 @@ def select_n_strategy( # TODO(sharadmv,apaszke): enable zero dim sizes # TODO(sharadmv,apaszke): enable one dim sizes ( - lax.neg_p, + lax.neg_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -213,7 +211,7 @@ def select_n_strategy( ), ), ( - lax.not_p, + lax.not_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -225,6 +223,7 @@ def select_n_strategy( *[ ( prim, + params, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -233,23 +232,23 @@ def select_n_strategy( valid_dtypes=[jnp.dtype("float32")], ), ) - for prim in [ - lax.exp_p, - lax.tanh_p, - lax.logistic_p, - lax.rsqrt_p, - lax.log_p, - lax.exp2_p, - lax.abs_p, - lax.log1p_p, - lax.sin_p, - lax.sqrt_p, + for prim, params in [ + (lax.abs_p, {}), + (lax.exp_p, {"accuracy": None}), + (lax.tanh_p, {"accuracy": None}), + (lax.logistic_p, {"accuracy": None}), + (lax.rsqrt_p, {"accuracy": None}), + (lax.log_p, {"accuracy": None}), + (lax.exp2_p, {"accuracy": None}), + (lax.log1p_p, {"accuracy": None}), + (lax.sin_p, {"accuracy": None}), + (lax.sqrt_p, {"accuracy": None}), ] ], ] UNARY_FUNCTIONS = [ - (prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES + (prim.name, functools.partial(prim.bind, **params), strategy) for prim, params, strategy in UNARY_PRIMITIVES ] + [ ( name, @@ -294,7 +293,7 @@ def pallas_call(cls, *args, **kwargs): if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: assert plgpu_mgpu is not None compiler_params = plgpu_mgpu.GPUCompilerParams( - thread_semantics=plgpu_mgpu.ThreadSemantics.Warpgroup + lowering_semantics=plgpu_mgpu.LoweringSemantics.Warpgroup ) kwargs["compiler_params"] = compiler_params @@ -305,6 +304,7 @@ def skip_if_mosaic_gpu(self): self.skipTest("TODO: Mosaic GPU does not support this yet") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsTest(PallasBaseTest): @parameterized.named_parameters( @@ -329,7 +329,7 @@ def kernel(x_ref, y_ref, o_ref): x = jnp.full((8, 128), 4, dtype=dtype) y = jnp.full((8, 128), 2 if jnp.issubdtype(dtype, jnp.integer) else 2.0, - dtype=dtype) + dtype=dtype) np.testing.assert_allclose(kernel(x, y), fn(x, 
y)) @parameterized.named_parameters( @@ -560,7 +560,8 @@ def kernel(*refs): ) @hp.given(hps.data()) def test_unary_primitives(self, name, func, shape_dtype_strategy, data): - self.skip_if_mosaic_gpu() + if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]: + self.skip_if_mosaic_gpu() if self.INTERPRET: self.skipTest("This hypothesis test is slow, even more so in interpret mode.") @@ -577,6 +578,12 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): def kernel(x_ref, y_ref): y_ref[...] = func(x_ref[...]) x_shape_dtype = data.draw(shape_dtype_strategy) + + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu + if sut_is_mosaic_gpu: + hp.assume(math.prod(x_shape_dtype.shape) % 128 == 0) + hp.assume(x_shape_dtype.shape[-1] >= 16) + key = random.key(0) x = _random_value(key, x_shape_dtype) out = self.pallas_call(kernel, out_shape=x_shape_dtype)(x) @@ -1061,8 +1068,8 @@ def kernel(x_ref, o_ref): ( # fmt: off [jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt, jnp.tan, jnp.asin, - jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh, - jnp.acosh, jnp.atanh], + jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh, + jnp.acosh, jnp.atanh], # fmt: on ["bfloat16", "float32", "float64"], ), @@ -1086,7 +1093,7 @@ def test_elementwise(self, fn, dtype): self.skipTest("int16 and float16 are not supported on TPU") if ( fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log, - jnp.sqrt, lax.rsqrt) + jnp.sqrt, lax.rsqrt) and dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6) ): @@ -1464,7 +1471,7 @@ def kernel(x_ref, y_ref, o_ref): ( # fmt: off [jnp.bitwise_and, jnp.bitwise_or, jnp.bitwise_xor, - jnp.bitwise_left_shift, jnp.bitwise_right_shift], + jnp.bitwise_left_shift, jnp.bitwise_right_shift], # fmt: on ["int32", "uint32"], ), @@ -1525,11 +1532,12 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_allclose(f(x, y), kernel(x, y)) @parameterized.parameters( + ((32,), jnp.int32, 0), ((8, 4), jnp.int32, 0), ((8, 16), jnp.float32, 1), ((8, 16, 2), jnp.int8, 1), ) - def test_broadcasted_iota(self, shape, dtype, dimension): + def test_iota(self, shape, dtype, dimension): self.skip_if_mosaic_gpu() if jtu.test_device_matches(["tpu"]): @@ -1739,6 +1747,27 @@ def f(x_ref, o_ref): expected = x.reshape(out_shape) np.testing.assert_allclose(f(x), expected) + def test_reshape_to_scalar(self): + self.skip_if_mosaic_gpu() + # Test reshapes from (1, 1) to (). + # Because TPUs distinguish between VREGs/SREGs this tests an implicit + # copy from VREG -> SREG that must be inserted by Pallas. + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32), + ) + def f(x_ref, o_ref): + o_ref[...] = jnp.zeros_like(o_ref) + vector_val = x_ref[1:2, 0:1] + scalar_val = jnp.reshape(vector_val, ()) + o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val + + in_shape = (4, 4) + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.int32).reshape(in_shape) + expected = jnp.zeros((8, 128), jnp.int32) + expected = expected.at[x[1, 0]].set(x[1, 0]) + np.testing.assert_allclose(f(x), expected) + def test_num_programs(self): self.skip_if_mosaic_gpu() @@ -1886,7 +1915,7 @@ def dot(x_ref, y_ref, o_ref): # Pallas always accumulates in FP32, so we are explicit about # preferred_element_type here. 
expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y, - preferred_element_type=jnp.float32).astype(dtype) + preferred_element_type=jnp.float32).astype(dtype) np.testing.assert_allclose( out.astype(jnp.float32), expected.astype(jnp.float32), @@ -1936,7 +1965,7 @@ def test_masked_oob_load_store_slice(self): def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): x = pl.load(x_ref, (pl.dslice(start_idx_ref[()], n)), mask=mask_ref[:], other=-1.) - pl.store(o_ref, (pl.dslice(None),), x) + o_ref[...] = x x = random.normal(random.key(0), (n,)) slice_start = random.randint(random.key(2), (), 1, n) @@ -2075,7 +2104,7 @@ def test_masked_oob_swap_slice(self): @functools.partial( self.pallas_call, out_shape=(jax.ShapeDtypeStruct((n,), floatx), - jax.ShapeDtypeStruct((m,), floatx)), + jax.ShapeDtypeStruct((m,), floatx)), input_output_aliases={0: 0, 1: 1}, ) def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref): @@ -2205,7 +2234,7 @@ def swap(_, lock_ref, out_ref): lock, out = swap(init_value) np.testing.assert_allclose(lock, new_value if cmp == init_value else - init_value) + init_value) np.testing.assert_allclose(out, init_value) @parameterized.parameters(1, 2, 3, 4, 8) @@ -2571,15 +2600,15 @@ def body(x_ref): @parameterized.parameters(*[ (lambda: (pl.dslice(0, 4), slice(None), slice(None)), - "c:i32[4,3,2], a[:,:,:] <-"), + "c:i32[4,3,2], a[:,:,:] <-"), (lambda: (pl.dslice(0, 3), slice(None), slice(None)), - "c:i32[3,3,2], a[:3,:,:] <-"), + "c:i32[3,3,2], a[:3,:,:] <-"), (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), - "c:i32[3,3,4], a[1:,:,:4] <-"), + "c:i32[3,3,4], a[1:,:,:4] <-"), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), - "e:i32[5,3,4], a[b,:,:4] <-"), + "e:i32[5,3,4], a[b,:,:4] <-"), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), - "o:i32[5,3,4], a[m,n,:4] <-"), + "o:i32[5,3,4], a[m,n,:4] <-"), ]) def test_swap_pretty_print(self, expr, expected): def body(x_ref): diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py index cd5ceecfc9a8..84e38f3d09db 100644 --- a/tests/pallas/pallas_error_handling_test.py +++ b/tests/pallas/pallas_error_handling_test.py @@ -92,7 +92,7 @@ def kernel_in_jitted_fn(x): tb_string = "".join(tb_string) self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") - def test_invalid_smem_vmem_verification_error(self): + def test_index_with_f32_verification_error(self): input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( @@ -105,7 +105,8 @@ def test_invalid_smem_vmem_verification_error(self): @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) def test_kernel(input_ref, output_ref): - output_ref[0, 0] = input_ref[0, 0] + idx = input_ref[0, 0] + output_ref[idx, 0] = input_ref[0, 0] # Test that a verification error is raised. This assert is a guard against # underlying changes in Pallas lowering. @@ -113,8 +114,8 @@ def test_kernel(input_ref, output_ref): # the test example to force a different error. 
with self.assertRaisesRegex( error_handling.VerificationError, - "'memref.store' op failed to verify that type of 'value' matches " - "element type of 'memref'", + "must be signless-integer-like or memref of signless-integer, " + "but got 'f32'" ): test_kernel(input_arr) @@ -125,7 +126,7 @@ def test_kernel(input_ref, output_ref): except error_handling.MosaicError as e: tb_string = traceback.format_tb(e.__traceback__) tb_string = "".join(tb_string) - self.assertEndsWith(tb_string, "output_ref[0, 0] = input_ref[0, 0]\n") + self.assertEndsWith(tb_string, "output_ref[idx, 0] = input_ref[0, 0]\n") def test_parse_location_string(self): name, frames = error_handling.parse_location_string(LOCATION_TEST_STRING) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 745c30ba98cb..781934ecd682 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -128,8 +128,8 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False): def matmul_kernel(x_ref, y_ref, o_ref): acc = jnp.zeros(o_ref.shape, dtype=jnp.float32) def body(i, acc_ref): - x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk))) - y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None))) + x_block = x_ref[:, pl.ds(i * bk, bk)] + y_block = y_ref[pl.ds(i * bk, bk), :] acc_ref[:, :] += pl.dot(x_block, y_block) acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) o_ref[:, :] = acc @@ -624,8 +624,9 @@ def test_unused_ref(self): out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), ) def dummy(_, o_ref): - pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), - jnp.ones_like(o_ref)) + o_ref[jnp.arange(m)[:, None], jnp.arange(n)[None, :]] = jnp.ones_like( + o_ref + ) key = random.key(0) x = random.normal(key, (m, n)) @@ -667,8 +668,7 @@ def test_using_pallas_slice(self): out_shape=out_shape, ) def slice_kernel(x_ref, y_ref): - x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4))) - pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x) + y_ref[:4, :4] = x_ref[:4, :4] x = random.normal(random.key(0), (m, n)) y = slice_kernel(x) y_ref = x[:4] @@ -702,6 +702,9 @@ def f(x): ("float32", jax.lax.DotAlgorithmPreset.DEFAULT), ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32), ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3), ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32), @@ -731,7 +734,21 @@ def dot_kernel(x_ref, y_ref, o_ref): precision=jax.lax.Precision.HIGHEST, preferred_element_type=jnp.float32, ) - self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3) + if dtype == "bfloat16" or precision in ( + jax.lax.Precision.HIGHEST, + jax.lax.DotAlgorithmPreset.F32_F32_F32, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + atol = 5e-6 + elif precision in ( + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3, + ): + atol = 5e-4 + else: + atol = 5e-2 + self.assertAllClose(dot_kernel(x, y), expected, atol=atol, rtol=atol / 10) @parameterized.parameters(jnp.int8, jnp.uint8) def test_integer_dot(self, dtype): @@ -826,6 +843,25 @@ def dot_kernel(x_ref, y_ref, o_ref): self.assertAllClose(dot_kernel(x, y), expected) + @parameterized.parameters( + ((32,), 2, 0), ((32, 64), 4, 0), ((32, 16), 8, 1), 
((32, 16, 2), 16, 1) + ) + def test_split(self, shape, num_parts, axis): + if jtu.test_device_matches(["tpu"]) and shape[axis] == num_parts: + self.skipTest("TPU doesn't support fully split axis.") + + x = jax.random.normal(jax.random.key(0), shape) + expected = jnp.split(x, num_parts, axis) + + @functools.partial(self.pallas_call, out_shape=expected) + def kernel(x_ref, *o_ref): + x_parts = jnp.split(x_ref[()], num_parts, axis) + for o_ref, x_part in zip(o_ref, x_parts): + o_ref[...] = x_part + + self.assertAllClose(kernel(x), expected) + + class PallasCallInterpretTest(PallasCallTest): INTERPRET = True @@ -1697,7 +1733,7 @@ def test_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(carry): i, j = carry @@ -1709,8 +1745,7 @@ def body(carry): sl = jax.lax.div(i, 128) l = jax.lax.rem(i, 128) v = x_ref[0, sl, l] - s = pl.load(r_ref, (0, 0)) - pl.store(r_ref, (0, 0), s + v) + r_ref[0, 0] += v return io + 1, j i = 128 @@ -1762,7 +1797,7 @@ def test_non_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(state): i, s = state @@ -1772,14 +1807,11 @@ def body(state): i, s = state sl = jax.lax.div(i, jnp.astype(128, i.dtype)) l = jax.lax.rem(i, jnp.astype(128, i.dtype)) - v = pl.load(x_ref, (0, sl, l)) + v = x_ref[0, sl, l] return i + 1, s + v i = jnp.int32(0) - s = pl.load(r_ref, (0, 0)) - - i, s = jax.lax.while_loop(cond, body, (i, s)) - pl.store(r_ref, (0, 0), s) + _, r_ref[0, 0] = jax.lax.while_loop(cond, body, (i, r_ref[0, 0])) x = jnp.arange(4096) x = jnp.reshape(x, [4, 8, 128]) diff --git a/tests/pallas/tpu_all_gather_test.py b/tests/pallas/tpu_all_gather_test.py index 98b3e5b40135..47168e1c35b4 100644 --- a/tests/pallas/tpu_all_gather_test.py +++ b/tests/pallas/tpu_all_gather_test.py @@ -25,114 +25,109 @@ import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.strategies as hps jax.config.parse_flags_with_absl() P = jax.sharding.PartitionSpec -if CAN_USE_HYPOTHESIS: - - hp.settings.register_profile( - "deterministic", - database=None, - derandomize=True, - deadline=None, - max_examples=50, - print_blob=True, - verbosity=hp.Verbosity.verbose, +hp.settings.register_profile( + "deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=50, + print_blob=True, + verbosity=hp.Verbosity.verbose, +) +hp.settings.load_profile("deterministic") + + +@hps.composite +def _array_shapes(draw): + # TODO(sharadmv, apaszke): enable this on a wider variety of shapes + valid_shapes = [ + (128, 128), + (256, 128), + (256, 512), + (256, 1024), + # TODO(sharadmv,apaszke): enable these shapes + # (256, 129), + # (129, 128), + # (64, 64), + # (1, 1), + ] + return draw(hps.sampled_from(valid_shapes)) + + +@hps.composite +def _array_dtypes(draw): + return draw( + hps.sampled_from([ + jnp.float32, + jnp.bfloat16, + jnp.int32, + # jnp.float16, # TODO(sharadmv,apaszke): enable float16 all gather + # jnp.int16, # TODO(sharadmv,apaszke): enable int16 all gather + # jnp.int8, # TODO(sharadmv,apaszke): enable int8 all gather + ]) ) - hp.settings.load_profile("deterministic") - - - @hps.composite - def _array_shapes(draw): - # TODO(sharadmv, apaszke): enable this on a wider variety of shapes - valid_shapes = [ - 
(128, 128), - (256, 128), - (256, 512), - (256, 1024), - # TODO(sharadmv,apaszke): enable these shapes - # (256, 129), - # (129, 128), - # (64, 64), - # (1, 1), - ] - return draw(hps.sampled_from(valid_shapes)) - - - @hps.composite - def _array_dtypes(draw): - return draw( - hps.sampled_from([ - jnp.float32, - jnp.bfloat16, - jnp.int32, - # jnp.float16, # TODO(sharadmv,apaszke): enable float16 all gather - # jnp.int16, # TODO(sharadmv,apaszke): enable int16 all gather - # jnp.int8, # TODO(sharadmv,apaszke): enable int8 all gather - ]) - ) - class AllGatherTest(jtu.JaxTestCase): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Need TPU devices") - if not jtu.is_device_tpu(version=5, variant="e"): - # TODO(sharadmv,apaszke): expand support to more versions - self.skipTest("Currently only supported on TPU v5e") - - super().setUp() - - @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) - def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): - if jax.device_count() < 2: - self.skipTest("Need more devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY - mesh_shape = (jax.device_count(),) - mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] - ) - leading, *rest = shape - shape = (mesh.shape["x"] * leading, *rest) - x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) - x_sharded = jax.device_put(x, jax.sharding.NamedSharding(mesh, P("x"))) - y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x", - memory_space=memory_space) - np.testing.assert_array_equal(y, x) - - @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(), - hps.sampled_from(["x", "y"])) - def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, - axis_name): - if jax.device_count() < 2: - self.skipTest("Need more devices") - if jax.device_count() % 2: - self.skipTest("Need an even number of devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY - mesh_shape = (2, jax.device_count() // 2) - mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] - ) - if axis_name == "x": - sharding = jax.sharding.NamedSharding(mesh, P("x", None)) - else: - sharding = jax.sharding.NamedSharding(mesh, P("y", None)) - leading, *rest = shape - shape = (mesh.shape[axis_name] * leading, *rest) - x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) - x_sharded = jax.device_put(x, sharding) - y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name, - memory_space=memory_space) - np.testing.assert_array_equal(y, x) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class AllGatherTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Need TPU devices") + if not jtu.is_device_tpu(version=5, variant="e"): + # TODO(sharadmv,apaszke): expand support to more versions + self.skipTest("Currently only supported on TPU v5e") + + super().setUp() + + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) + def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): + if jax.device_count() < 2: + self.skipTest("Need more devices") + memory_space = pltpu.VMEM if is_vmem else pltpu.ANY + mesh_shape = (jax.device_count(),) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] + ) + leading, *rest = shape + shape = (mesh.shape["x"] * leading, *rest) + x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) + x_sharded = jax.device_put(x, 
jax.sharding.NamedSharding(mesh, P("x"))) + y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x", + memory_space=memory_space) + np.testing.assert_array_equal(y, x) + + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(), + hps.sampled_from(["x", "y"])) + def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, + axis_name): + if jax.device_count() < 2: + self.skipTest("Need more devices") + if jax.device_count() % 2: + self.skipTest("Need an even number of devices") + memory_space = pltpu.VMEM if is_vmem else pltpu.ANY + mesh_shape = (2, jax.device_count() // 2) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] + ) + if axis_name == "x": + sharding = jax.sharding.NamedSharding(mesh, P("x", None)) + else: + sharding = jax.sharding.NamedSharding(mesh, P("y", None)) + leading, *rest = shape + shape = (mesh.shape[axis_name] * leading, *rest) + x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) + x_sharded = jax.device_put(x, sharding) + y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name, + memory_space=memory_space) + np.testing.assert_array_equal(y, x) if __name__ == "__main__": diff --git a/tests/pallas/tpu_fusable_matmul_test.py b/tests/pallas/tpu_fusible_matmul_test.py similarity index 93% rename from tests/pallas/tpu_fusable_matmul_test.py rename to tests/pallas/tpu_fusible_matmul_test.py index df7c1221bb0c..2382c09f26ac 100644 --- a/tests/pallas/tpu_fusable_matmul_test.py +++ b/tests/pallas/tpu_fusible_matmul_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Fusable matmul test.""" +"""Fusible matmul test.""" import functools from typing import Any @@ -71,10 +71,11 @@ def _(): def _(): acc = acc_ref[...].astype(out_dtype) z_values = jax.tree.map(lambda ref: ref.get(), z_value_refs) - o_ref[...] 
= z_fn(pids, scalar_prefetch, z_values, acc) + out = z_fn(pids, scalar_prefetch, z_values, acc) + jax.tree.map(lambda ref, x: ref.set(x), o_ref, out) -def _fusable_matmul( +def _fusible_matmul( x: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation y: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation z: fuser.Fusion[[jax.Array], jax.Array] | None, # pytype: disable=invalid-annotation @@ -174,12 +175,12 @@ def z_index_map(i, j, k, *_): y_value_block_specs, z_value_block_specs, ], - out_specs=z_out_block_spec, + out_specs=[z_out_block_spec], ), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=dimension_semantics, ), - out_shape=z_out_type, + out_shape=[z_out_type], interpret=interpret, debug=debug, )( @@ -187,10 +188,10 @@ def z_index_map(i, j, k, *_): x_values, y_values, z_values, - ) + )[0] -def fusable_matmul( +def fusible_matmul( x: jax.Array, y: jax.Array, *, @@ -200,9 +201,9 @@ def fusable_matmul( debug: bool = False, interpret: bool = False, ) -> jax.Array: - return fuser.fusable( + return fuser.fusible( functools.partial( - _fusable_matmul, + _fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -212,7 +213,7 @@ def fusable_matmul( )(x, y) -class FusableMatmulTest(jtu.JaxTestCase): +class FusibleMatmulTest(jtu.JaxTestCase): def setUp(self): if not jtu.is_device_tpu_at_least(4): @@ -225,7 +226,7 @@ def test_matmul(self, dtype): x = jax.random.normal(k0, (512, 512), dtype) y = jax.random.normal(k1, (512, 512), dtype) np.testing.assert_allclose( - jax.jit(fusable_matmul)(x, y), mm_ref(x, y), atol=5e-5 + jax.jit(fusible_matmul)(x, y), mm_ref(x, y), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -237,7 +238,7 @@ def test_matmul_with_activation(self, dtype): @jax.jit @fuser.fuse def matmul_relu(x, y): - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) x = jnp.maximum(x, 0.0) return x @@ -257,7 +258,7 @@ def test_matmul_with_bias(self, dtype): @jax.jit @fuser.fuse def matmul_bias(x, y, b): - x = fusable_matmul(x, y).astype(dtype) + b + x = fusible_matmul(x, y).astype(dtype) + b x = jnp.maximum(x, 0.0) return x @@ -276,7 +277,7 @@ def test_matmul_with_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1]) + x = fusible_matmul(x, y[1]) return x np.testing.assert_allclose(matmul_slice(x, y), mm_ref(x, y[1]), atol=5e-5) @@ -290,7 +291,7 @@ def test_matmul_with_dynamic_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i): - x = fusable_matmul(x, y[i]) + x = fusible_matmul(x, y[i]) return x np.testing.assert_allclose( @@ -307,7 +308,7 @@ def test_matmul_with_dynamic_slice_bias(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j): - x = fusable_matmul(x, y[j]).astype(dtype) + b[i] + x = fusible_matmul(x, y[j]).astype(dtype) + b[i] return x np.testing.assert_allclose( @@ -325,7 +326,7 @@ def test_matmul_with_multi_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1, 1]) + x = fusible_matmul(x, y[1, 1]) return x np.testing.assert_allclose( @@ -341,7 +342,7 @@ def test_matmul_with_multiple_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1][1]) + x = fusible_matmul(x, y[1][1]) return x np.testing.assert_allclose( @@ -357,7 +358,7 @@ def test_matmul_with_multiple_dynamic_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[i][j]) + x = fusible_matmul(x, y[i][j]) return x for i in range(2): @@ -375,7 +376,7 @@ def test_matmul_with_mixed_slices(self, dtype): 
@jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[2][i, j]) + x = fusible_matmul(x, y[2][i, j]) return x for i in range(2): @@ -396,7 +397,7 @@ def test_matmul_with_multiple_mixed_slices_and_bias(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j, k): - x = fusable_matmul(x[k][3], y[2][i, j]).astype(dtype) + x = fusible_matmul(x[k][3], y[2][i, j]).astype(dtype) return x + b[i, j] @jit_no_excess_precision @@ -427,7 +428,7 @@ def test_matmul_input_concat_output(self, dtype): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jax.jit @@ -453,7 +454,7 @@ def test_matmul_input_concat_contract(self, dtype): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -481,7 +482,7 @@ def test_matmul_double_concat(self, dtype): def matmul_concat(x, ys, y3): y = jnp.concatenate(ys, axis=0) y = jnp.concatenate([y, y3], axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -508,7 +509,7 @@ def test_matmul_slice_concat(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -533,7 +534,7 @@ def test_matmul_slice_concat_slice(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=1)[1] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -558,7 +559,7 @@ def test_matmul_dynamic_slice_concat(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2, i, j): y = jnp.concatenate([y1, y2[i]], axis=1)[j] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -584,7 +585,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -606,7 +607,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -628,7 +629,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -650,7 +651,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -672,7 +673,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -694,7 +695,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -715,7 +716,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + 
functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -737,7 +738,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -759,7 +760,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -781,7 +782,7 @@ def matmul(impl, x, y): return z.T impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -803,7 +804,7 @@ def matmul(impl, x, y): return z.T * 2 impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -866,7 +867,7 @@ def matmul(impl, x, y): impl = fuser.fuse( functools.partial( matmul, - fusable_matmul, + fusible_matmul, ) ) ref = functools.partial(matmul, dot_ref) @@ -892,7 +893,7 @@ def matmul(impl, x, y): out_ref = jit_no_excess_precision(ref)(x, y) - impl = fuser.fuse(functools.partial(matmul, fusable_matmul)) + impl = fuser.fuse(functools.partial(matmul, fusible_matmul)) out = jax.jit(impl)(x, y) self.assertAllClose(out, out_ref, atol=0) @@ -916,7 +917,7 @@ def matmul(impl, x, y): impl = fuser.fuse( functools.partial( matmul, - functools.partial(fusable_matmul, bk=256, bn=128), + functools.partial(fusible_matmul, bk=256, bn=128), ) ) out = jax.jit(impl)(x, y) @@ -924,7 +925,7 @@ def matmul(impl, x, y): atol = 0 if jtu.is_device_tpu_at_least(6): # 256 MXU changes some tols. 
- atol = 1e-6 + atol = 1e-5 self.assertAllClose(out, out_ref, atol=atol) def test_matmul_f32_out_fused_downcast(self): @@ -952,7 +953,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -989,7 +990,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -1024,7 +1025,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, diff --git a/tests/pallas/tpu_gmm_test.py b/tests/pallas/tpu_gmm_test.py index 9c416dabaeb1..7bc698794f09 100644 --- a/tests/pallas/tpu_gmm_test.py +++ b/tests/pallas/tpu_gmm_test.py @@ -24,12 +24,8 @@ import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.strategies as hps jax.config.parse_flags_with_absl() @@ -37,326 +33,326 @@ partial = functools.partial -if CAN_USE_HYPOTHESIS: - hp.settings.register_profile( - "deterministic", - database=None, - derandomize=True, - deadline=None, - max_examples=10, - print_blob=True, +hp.settings.register_profile( + "deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=10, + print_blob=True, +) +hp.settings.load_profile("deterministic") + +def seed_strategy() -> hps.SearchStrategy[int]: + return hps.integers(min_value=0, max_value=4) + +@hps.composite +def group_strategy( + draw: hps.DrawFn, + max_groups: int = 32, + max_stride: int = 32, + min_groups: int = 1, +) -> tuple[int, int]: + assert max_stride <= max_groups + + # Sample the number of groups owned by each shard. + group_stride = draw(hps.integers(min_value=1, max_value=max_stride)) + + # Sample the number of groups as a multiple of the stride to ensure that we + # have an equal number of groups per shard. Round down s.t. num_groups <= + # max_groups. + num_groups = group_stride * draw( + hps.integers(min_value=min_groups, max_value=max_groups // group_stride) ) - hp.settings.load_profile("deterministic") - - def seed_strategy() -> hps.SearchStrategy[int]: - return hps.integers(min_value=0, max_value=4) - - @hps.composite - def group_strategy( - draw: hps.DrawFn, - max_groups: int = 32, - max_stride: int = 32, - min_groups: int = 1, - ) -> tuple[int, int]: - assert max_stride <= max_groups - - # Sample the number of groups owned by each shard. - group_stride = draw(hps.integers(min_value=1, max_value=max_stride)) - - # Sample the number of groups as a multiple of the stride to ensure that we - # have an equal number of groups per shard. Round down s.t. num_groups <= - # max_groups. - num_groups = group_stride * draw( - hps.integers(min_value=min_groups, max_value=max_groups // group_stride) - ) - return num_groups, group_stride - - @hps.composite - def group_sizes_strategy( - draw: hps.DrawFn, m: int, num_groups: int - ) -> jnp.ndarray: - # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer - # sample with replacement so that it's possible to get zero-sized groups. Get - # 'num_groups - 1' run ends. The final group will end at 'm'. 
- ends_no_final = np.sort( - np.array( - [ - draw(hps.integers(min_value=0, max_value=m)) - for _ in range(num_groups - 1) - ], - dtype=np.int32, - ), + return num_groups, group_stride + +@hps.composite +def group_sizes_strategy( + draw: hps.DrawFn, m: int, num_groups: int +) -> jnp.ndarray: + # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer + # sample with replacement so that it's possible to get zero-sized groups. Get + # 'num_groups - 1' run ends. The final group will end at 'm'. + ends_no_final = np.sort( + np.array( + [ + draw(hps.integers(min_value=0, max_value=m)) + for _ in range(num_groups - 1) + ], + dtype=np.int32, + ), + ) + ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) + + # Calculate the run starts by shifting ends 1 to the right. The first run + # starts at zero. + starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) + return jnp.array(ends - starts, dtype=jnp.int32) + +GROUPED_MATMUL_TESTS = ( + (128, 128, 128), # Small + (512, 2048, 256), # Big + (128, 8, 16), # Test partial tiles. +) + +def random_dense( + shape: tuple[int, ...], + key: jax.Array, + dtype: jnp.dtype, + limit: int | None = None, +) -> jnp.ndarray: + if limit is None: + limit = 1 / np.prod(shape) + x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type + return x.astype(jnp.bfloat16).astype(dtype) + +def dot( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + transpose_lhs: bool = False, + transpose_rhs: bool = False, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + lhs = jnp.transpose(lhs) if transpose_lhs else lhs + rhs = jnp.transpose(rhs) if transpose_rhs else rhs + return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type) + +def reference_gmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + + start = 0 + out = [] + for i, size in enumerate(group_sizes): + result = dot( + lhs[start : start + size, :], + rhs[i, :, :], + preferred_element_type=preferred_element_type, ) - ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) - # Calculate the run starts by shifting ends 1 to the right. The first run - # starts at zero. - starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) - return jnp.array(ends - starts, dtype=jnp.int32) + out.append(result) + start += group_sizes[i] + return jnp.concatenate(out, axis=0) + +def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]: + dtypes = [jnp.float32, jnp.bfloat16] + + result = [] + for x in xs: + for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes): + result.append(x + dtypes_tuple) + return tuple(result) + +def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]: + flags = [False, True] + result = [] + for x in xs: + for flag in flags: + result.append(x + (flag,)) + return tuple(result) + +def tolerances( + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype +) -> tuple[float, float]: + if ( + lhs_dtype == jnp.bfloat16 + or rhs_dtype == jnp.bfloat16 + or out_dtype == jnp.bfloat16 + ): + return 1e-3, 1e-2 # atol, rtol + return 1e-3, 1e-5 # atol, rtol + +# TODO(tgale): Fix errors with strict dtype promotion. 
+@jtu.with_config(jax_numpy_dtype_promotion="standard") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class GroupedMatmulTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test requires TPU device.") + + super().setUp() + self.key = jax.random.PRNGKey(1234) + + def assert_allclose( + self, + out: jnp.ndarray, + expected_out: jnp.ndarray, + *, + atol: float = 1e-5, + rtol: float = 1e-5, + ): + self.assertEqual(out.dtype, expected_out.dtype) + np.testing.assert_allclose( + out.astype(jnp.float32), + expected_out.astype(jnp.float32), + atol=atol, + rtol=rtol, + ) - GROUPED_MATMUL_TESTS = ( - (128, 128, 128), # Small - (512, 2048, 256), # Big - (128, 8, 16), # Test partial tiles. - ) + def gmm_test( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + interpret: bool = False, + ): + seed = data.draw(seed_strategy()) + num_groups, _ = data.draw(group_strategy(max_stride=1)) + lhs_dtype, rhs_dtype, out_dtype = [ + data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + for _ in range(3) + ] + transpose_rhs = data.draw(hps.booleans()) + + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, vjpfun = jax.vjp( + partial( + mblx.gmm, + preferred_element_type=out_dtype, + transpose_rhs=transpose_rhs, + interpret=interpret, + ), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) - def random_dense( - shape: tuple[int, ...], - key: jax.Array, - dtype: jnp.dtype, - limit: int | None = None, - ) -> jnp.ndarray: - if limit is None: - limit = 1 / np.prod(shape) - x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type - return x.astype(jnp.bfloat16).astype(dtype) - - def dot( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - transpose_lhs: bool = False, - transpose_rhs: bool = False, - preferred_element_type: jnp.dtype = jnp.float32, - ) -> jnp.ndarray: - lhs = jnp.transpose(lhs) if transpose_lhs else lhs - rhs = jnp.transpose(rhs) if transpose_rhs else rhs - return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type) - - def reference_gmm( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - group_sizes: jnp.ndarray, - preferred_element_type: jnp.dtype = jnp.float32, - ) -> jnp.ndarray: - - start = 0 - out = [] - for i, size in enumerate(group_sizes): - result = dot( - lhs[start : start + size, :], - rhs[i, :, :], - preferred_element_type=preferred_element_type, + def reference_fn(lhs, rhs, group_sizes, preferred_element_type): + rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs + return reference_gmm( + lhs, rhs, group_sizes, preferred_element_type=preferred_element_type ) - out.append(result) - start += group_sizes[i] - return jnp.concatenate(out, axis=0) - - def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]: - dtypes = [jnp.float32, jnp.bfloat16] - - result = [] - for x in xs: - for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes): - result.append(x + dtypes_tuple) - return tuple(result) - - def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]: - flags = [False, True] - result = [] - for x in xs: - for flag in flags: - result.append(x + (flag,)) - return tuple(result) - - def tolerances( - lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype - ) -> 
tuple[float, float]: - if ( - lhs_dtype == jnp.bfloat16 - or rhs_dtype == jnp.bfloat16 - or out_dtype == jnp.bfloat16 - ): - return 1e-3, 1e-2 # atol, rtol - return 1e-3, 1e-5 # atol, rtol - - # TODO(tgale): Fix errors with strict dtype promotion. - @jtu.with_config(jax_numpy_dtype_promotion="standard") - class GroupedMatmulTest(jtu.JaxTestCase): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Test requires TPU device.") - - super().setUp() - self.key = jax.random.PRNGKey(1234) - - def assert_allclose( - self, - out: jnp.ndarray, - expected_out: jnp.ndarray, - *, - atol: float = 1e-5, - rtol: float = 1e-5, - ): - self.assertEqual(out.dtype, expected_out.dtype) - np.testing.assert_allclose( - out.astype(jnp.float32), - expected_out.astype(jnp.float32), - atol=atol, - rtol=rtol, - ) + expected_out, reference_vjpfun = jax.vjp( + partial(reference_fn, preferred_element_type=out_dtype), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + grad_lhs, grad_rhs, *_ = vjpfun(cotangent) + expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + + @parameterized.parameters(*GROUPED_MATMUL_TESTS) + @hp.given(hps.data()) + def test_gmm( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + self.gmm_test(m, k, n, data) + + # NOTE: Run fewer tests with interpret mode. We just want to sanity check that + # changes do not break running these kernels with interpret=True. 
+ @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1]) + @hp.given(hps.data()) + def test_gmm_interpret( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + self.skipTest("interpret mode with dynamic grids is unsupported") + self.gmm_test( + m, + k, + n, + data=data, + interpret=True, + ) - def gmm_test( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - interpret: bool = False, - ): - seed = data.draw(seed_strategy()) - num_groups, _ = data.draw(group_strategy(max_stride=1)) - lhs_dtype, rhs_dtype, out_dtype = [ - data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) - for _ in range(3) - ] - transpose_rhs = data.draw(hps.booleans()) - - key = jax.random.key(seed) - k1, k2 = jax.random.split(key, 2) - lhs = random_dense((m, k), k1, lhs_dtype, limit=1) - rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) - group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) - - out, vjpfun = jax.vjp( - partial( - mblx.gmm, - preferred_element_type=out_dtype, - transpose_rhs=transpose_rhs, - interpret=interpret, + @parameterized.parameters(*GROUPED_MATMUL_TESTS) + @hp.given(hps.data()) + def test_gmm_sharded_groups( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + seed = data.draw(seed_strategy()) + num_groups, group_stride = data.draw(group_strategy()) + lhs_dtype, rhs_dtype, out_dtype = [ + data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + for _ in range(3) + ] + + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, shard_vjpfun = jax.vjp( + partial(mblx.gmm, preferred_element_type=out_dtype), + lhs, + rhs[0:group_stride], + group_sizes, + ) + vjpfuns = [shard_vjpfun] + for group_offset in range(group_stride, num_groups, group_stride): + out, shard_vjpfun = jax.vjp( + lambda lhs, rhs, group_sizes, out: mblx.gmm( + lhs, + rhs, + group_sizes, + out_dtype, + group_offset=jnp.array(group_offset, dtype=jnp.int32), # pylint: disable=cell-var-from-loop + existing_out=out, ), lhs, - rhs.swapaxes(1, 2) if transpose_rhs else rhs, - group_sizes, - ) - - def reference_fn(lhs, rhs, group_sizes, preferred_element_type): - rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs - return reference_gmm( - lhs, rhs, group_sizes, preferred_element_type=preferred_element_type - ) - - expected_out, reference_vjpfun = jax.vjp( - partial(reference_fn, preferred_element_type=out_dtype), - lhs, - rhs.swapaxes(1, 2) if transpose_rhs else rhs, + rhs[group_offset : group_offset + group_stride], group_sizes, + out, ) - self.assertEqual(out.dtype, out_dtype) - self.assertEqual(expected_out.dtype, out_dtype) - - atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) - self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) - - cotangent = random_dense((m, n), k1, out_dtype, limit=1) - grad_lhs, grad_rhs, *_ = vjpfun(cotangent) - expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) - self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) - self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) - - @parameterized.parameters(*GROUPED_MATMUL_TESTS) - @hp.given(hps.data()) - def test_gmm( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - ): - self.gmm_test(m, k, n, data) - - # NOTE: Run 
fewer tests with interpret mode. We just want to sanity check that - # changes do not break running these kernels with interpret=True. - @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1]) - @hp.given(hps.data()) - def test_gmm_interpret( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - ): - self.skipTest("interpret mode with dynamic grids is unsupported") - self.gmm_test( - m, - k, - n, - data=data, - interpret=True, - ) + vjpfuns.append(shard_vjpfun) - @parameterized.parameters(*GROUPED_MATMUL_TESTS) - @hp.given(hps.data()) - def test_gmm_sharded_groups( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], + expected_out, reference_vjpfun = jax.vjp( + partial(reference_gmm, preferred_element_type=out_dtype), + lhs, + rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent) + grad_lhs = shard_grad_lhs + grad_rhs = [shard_grad_rhs] + for i, group_offset in enumerate( + range(group_stride, num_groups, group_stride) ): - seed = data.draw(seed_strategy()) - num_groups, group_stride = data.draw(group_strategy()) - lhs_dtype, rhs_dtype, out_dtype = [ - data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) - for _ in range(3) - ] - - key = jax.random.key(seed) - k1, k2 = jax.random.split(key, 2) - lhs = random_dense((m, k), k1, lhs_dtype, limit=1) - rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) - group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) - - out, shard_vjpfun = jax.vjp( - partial(mblx.gmm, preferred_element_type=out_dtype), - lhs, - rhs[0:group_stride], - group_sizes, - ) - vjpfuns = [shard_vjpfun] - for group_offset in range(group_stride, num_groups, group_stride): - out, shard_vjpfun = jax.vjp( - lambda lhs, rhs, group_sizes, out: mblx.gmm( - lhs, - rhs, - group_sizes, - out_dtype, - group_offset=jnp.array(group_offset, dtype=jnp.int32), # pylint: disable=cell-var-from-loop - existing_out=out, - ), - lhs, - rhs[group_offset : group_offset + group_stride], - group_sizes, - out, - ) - vjpfuns.append(shard_vjpfun) - - expected_out, reference_vjpfun = jax.vjp( - partial(reference_gmm, preferred_element_type=out_dtype), - lhs, - rhs, - group_sizes, - ) - self.assertEqual(out.dtype, out_dtype) - self.assertEqual(expected_out.dtype, out_dtype) - atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) - self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) - - cotangent = random_dense((m, n), k1, out_dtype, limit=1) - shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent) - grad_lhs = shard_grad_lhs - grad_rhs = [shard_grad_rhs] - for i, group_offset in enumerate( - range(group_stride, num_groups, group_stride) - ): - shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent) - grad_lhs += shard_grad_lhs - grad_rhs.append(shard_grad_rhs) - grad_rhs = jnp.concatenate(grad_rhs, axis=0) - expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) - self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) - self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent) + grad_lhs += shard_grad_lhs + grad_rhs.append(shard_grad_rhs) + grad_rhs = jnp.concatenate(grad_rhs, axis=0) 
+ expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) if __name__ == "__main__": diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index c8def2627462..1fb0bc24701b 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -15,7 +15,6 @@ import functools import math import sys -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -32,13 +31,10 @@ else: pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.strategies as hps + jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=100) @@ -66,6 +62,7 @@ def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsTest(PallasBaseTest): @parameterized.product( @@ -490,7 +487,32 @@ def kernel(x, out): expected = dot(x[:], jnp.ones((1, d), jnp.bfloat16)) np.testing.assert_array_equal(output, expected) + # We need to manually run the test with the env variable + # `export LIBTPU_INIT_ARGS="--xla_jf_bounds_check=true"` + def test_disable_bounds_check(self): + if not jtu.if_cloud_tpu_at_least(2025, 4, 16): + self.skipTest("Requires libtpu built after 2025-04-16") + if jtu.get_tpu_version() < 4: + self.skipTest("Requires TPUv4+") + src_shape = (8, 128) + tgt_shape = (16, 256) + + def kernel(src, tgt): + tgt[:] = pl.load(src, tuple(pl.ds(0, d) for d in tgt.shape)) + + x = jnp.arange(np.prod(src_shape), dtype=jnp.float32).reshape(src_shape) + run = pl.pallas_call( + kernel, + jax.ShapeDtypeStruct(tgt_shape, jnp.float32), + compiler_params=pltpu.TPUCompilerParams(disable_bounds_checks=True), + ) + output = run(x) + np.testing.assert_array_equal( + output[tuple(slice(0, d) for d in src_shape)], x + ) + +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsInterpretTest(OpsTest): INTERPRET = True diff --git a/tests/pallas/tpu_paged_attention_kernel_test.py b/tests/pallas/tpu_paged_attention_kernel_test.py index 7fbccdb338d4..ac24fea1b45a 100644 --- a/tests/pallas/tpu_paged_attention_kernel_test.py +++ b/tests/pallas/tpu_paged_attention_kernel_test.py @@ -18,19 +18,176 @@ from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu import paged_attention from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +from jax.experimental.pallas.ops.tpu.paged_attention import util import jax.numpy as jnp import numpy as np -jax.config.parse_flags_with_absl() +def _generate_qkv_simplest( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries with one query head, kv pages, and attention.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len // 2]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=1, head_dim=1) + queries = jnp.asarray([[[1.2]]], dtype) + assert queries.shape == (1, 1, 1) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=1) + k_pages = jnp.asarray([[[[0.1], [0.2], [0.3], [0.4]]]], dtype) + v_pages = jnp.asarray([[[[4.0], [3.0], [2.0], [1.0]]]], dtype) + assert k_pages.shape == (1, 1, 4, 1) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [.12, .24, .36, .48] ]]] + # 
masked: [[[ [.12, .24, -inf, -inf] ]]] + # softmax: [[[ [.47, .53, 0, 0] ]]] + # softmax(q*k) * v: .47*4 + .53*3 + 0*... = 3.47 + attention = jnp.asarray([[[3.47]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention + + +def _generate_qkv_with_one_q_head( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries with one query head, kv pages, and attention.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len - 1]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=1, head_dim=1) + queries = jnp.asarray([[[1.7]]], dtype) + assert queries.shape == (1, 1, 1) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=1) + k_pages = jnp.asarray([[[[0.12], [0.23], [0.34], [0.45]]]], dtype) + v_pages = jnp.asarray([[[[4.32], [3.21], [2.10], [1.09]]]], dtype) + assert k_pages.shape == (1, 1, 4, 1) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [.204, .391, .578, .765] ]]] + # masked: [[[ [.204, .391, .578, -inf] ]]] + # softmax: [[[ [.273, .330, .397, 0] ]]] + # softmax(q*k) * v: .273*4.32 + .330*3.21 + .397*2.10 + 0*... = 3.0723 + attention = jnp.asarray([[[3.0723]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention + + +def _generate_qkv_with_two_q_heads( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries with two query heads, kv pages, and attention.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=2, head_dim=1) + queries = jnp.asarray([[[1.3], [9.7]]], dtype) + assert queries.shape == (1, 2, 1) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=1) + k_pages = jnp.asarray([[[[0.12], [0.23], [0.34], [0.45]]]], dtype) + v_pages = jnp.asarray([[[[4.32], [3.21], [2.10], [1.09]]]], dtype) + assert k_pages.shape == (1, 1, 4, 1) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [ .156, .299, .442, .585], + # [1.164, 2.231, 3.298, 4.365] ]]] + # softmax: [[[ [ .199, .230, .265, .306], + # [ .027, .079, .229, .665] ]]] + # softmax(q*k) * v: .199*4.32 + .230*3.21 + .265*2.10 + .306*1.09 = 2.488 + # softmax(q*k) * v: .027*4.32 + .079*3.21 + .229*2.10 + .665*1.09 = 1.576 + attention = jnp.asarray([[[2.488], [1.576]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention + + +def _generate_qkv_with_head_dim_two( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries, kv pages, and attention with head_dim=2.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len // 2]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=1, head_dim=2) + queries = jnp.asarray([[[1.2, 9.0]]], dtype) + assert queries.shape == (1, 1, 2) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=2) + k_pages = jnp.asarray( + [[[[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5]]]], dtype + ) + v_pages = jnp.asarray( + [[[[4.0, 5.0], [3.0, 6.0], [2.0, 7.0], [1.0, 8.0]]]], dtype + ) + assert k_pages.shape == (1, 1, 4, 2) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [ 1.92, 2.94, 3.96, 4.98] ]]] + # masked: [[[ [ 1.92, 2.94, -inf, -inf] ]]] + # softmax: [[[ [ .265, .735, 0, 0] ]]] + # softmax(q*k) * v: .265*4 + 0.735*3 + 0*... = 3.265 + # softmax(q*k) * v: .265*5 + 0.735*6 + 0*... 
= 5.735 + attention = jnp.asarray([[[3.265, 5.735]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention def _generate_qkv( + dtype: jnp.dtype, + case: int, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + match case: + case 0: + return _generate_qkv_simplest(dtype) + case 1: + return _generate_qkv_with_one_q_head(dtype) + case 2: + return _generate_qkv_with_two_q_heads(dtype) + case 3: + return _generate_qkv_with_head_dim_two(dtype) + case _: + raise ValueError(f"Unsupported case: {case}") + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class JaxGroupedQueryAttentionReferenceTest(jtu.JaxTestCase): + + @parameterized.product( + dtype=(jnp.float32, jnp.bfloat16), + case=(0, 1, 2, 3), + ) + def test_grouped_query_attention(self, dtype: jnp.dtype, case: int): + # generate queries, kv pages, and seq_lens + seq_lens, queries, k_pages, v_pages, expected = _generate_qkv(dtype, case) + jax.debug.print("seq_lens: {seq_lens}", seq_lens=seq_lens) + jax.debug.print("queries: {queries}", queries=queries) + jax.debug.print("k_pages: {k_pages}", k_pages=k_pages) + jax.debug.print("v_pages: {v_pages}", v_pages=v_pages) + jax.debug.print("expected: {expected}", expected=expected) + + # calculate grouped query attention + attention = util.grouped_query_attention_reference( + queries, k_pages, v_pages, seq_lens + ) + jax.debug.print("attention: {attention}", attention=attention) + + # compare the results + atol, rtol = (3e-3, 5e-3) if dtype == jnp.bfloat16 else (2e-4, 2e-4) + self.assertAllClose(attention, expected, atol=atol, rtol=rtol) + + +def _generate_random_qkv( seq_lens, page_size, max_seq_len, num_kv_heads, - num_heads, + num_q_heads, head_dim, prng_key, dtype=jnp.float32, @@ -55,7 +212,7 @@ def _generate_qkv( page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32) page_indices = jax.random.permutation(k3, page_indices, independent=True) page_indices = page_indices.reshape(batch_size, pages_per_sequence) - q = jax.random.normal(k4, (batch_size, num_heads, head_dim), dtype=dtype) + q = jax.random.normal(k4, (batch_size, num_q_heads, head_dim), dtype=dtype) return q, k_pages, v_pages, page_indices @@ -64,7 +221,7 @@ def _reconstruct_kv(page_indices, pages): pages = quantization_utils.unquantize_from_int8(pages, dtype=jnp.float32) batch_size = page_indices.shape[0] - num_heads, _, _, head_dim = pages.shape + num_kv_heads, _, _, head_dim = pages.shape def per_sequence_page_gather(pages, page_indices): return jnp.take(pages, page_indices, 1) @@ -72,32 +229,7 @@ def per_sequence_page_gather(pages, page_indices): gathered = jax.vmap(per_sequence_page_gather, in_axes=(None, 0))( pages, page_indices ) - return gathered.reshape(batch_size, num_heads, -1, head_dim) - - -def _grouped_query_attention_reference(q, k, v, lengths, attn_logits_soft_cap): - batch_size, num_heads, head_dim = q.shape - _, num_kv_heads, max_seq_len, _ = k.shape - assert k.shape == v.shape - assert num_heads % num_kv_heads == 0 - q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim) - - if isinstance(k, quantization_utils.QuantizedTensor): - k = quantization_utils.unquantize_from_int8(k, dtype=jnp.float32) - if isinstance(v, quantization_utils.QuantizedTensor): - v = quantization_utils.unquantize_from_int8(v, dtype=jnp.float32) - - logits = jnp.einsum( - "bhgd,bhtd->bhgt", q.astype(jnp.float32), k.astype(jnp.float32) - ) - if attn_logits_soft_cap is not None: - logits = jnp.tanh(logits / attn_logits_soft_cap) 
* attn_logits_soft_cap - mask = jnp.arange(max_seq_len)[None] < lengths[:, None] - mask_value = -0.7 * float(np.finfo(np.dtype("float32")).max) - logits = logits + jnp.where(mask, 0.0, mask_value)[:, None, None, :] - weights = jax.nn.softmax(logits, axis=-1) - o = jnp.einsum("bhgt,bhtd->bhgd", weights.astype(v.dtype), v) - return o.reshape(batch_size, num_heads, head_dim) + return gathered.reshape(batch_size, num_kv_heads, -1, head_dim) def _megacore_enabled(): @@ -149,7 +281,7 @@ def test_paged_attention( max_kv_len = 2048 block_size = 512 seq_lens = np.asarray([0, 3, 256, 513, 1023, 2048]) - q, k_pages, v_pages, page_indices = _generate_qkv( + q, k_pages, v_pages, page_indices = _generate_random_qkv( seq_lens, page_size, max_kv_len, @@ -172,8 +304,9 @@ def test_paged_attention( ) k = _reconstruct_kv(page_indices, k_pages) v = _reconstruct_kv(page_indices, v_pages) - o_ref = _grouped_query_attention_reference( - q, k, v, seq_lens, attn_logits_soft_cap) + o_ref = util.grouped_query_attention_reference( + q, k, v, seq_lens, attn_logits_soft_cap + ) if q_kv_head_ratio > 1: atol, rtol = 1e-2, 2e-2 @@ -188,4 +321,5 @@ def test_paged_attention( if __name__ == "__main__": + jax.config.config_with_absl() absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index f7d7daf1874f..3d4d441d7cd0 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -51,8 +51,7 @@ def test_basic_remote_vmem_dma(self, mem): # Implements very simple collective permute def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): - dev_id = pltpu.device_id() - other_dev_id = 1 - dev_id + other_dev_id = 1 - lax.axis_index('x') pltpu.semaphore_signal(ready_sem, device_id=other_dev_id, device_id_type=pltpu.DeviceIdType.LOGICAL) pltpu.semaphore_wait(ready_sem) @@ -236,7 +235,7 @@ def body(x): in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=x, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), )(x) device_mesh = mesh_utils.create_device_mesh( diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index 518c16ed2109..1ed139e9e867 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -18,8 +18,6 @@ contains only tests that use shard_map. """ -import functools - from absl.testing import absltest from absl.testing import parameterized @@ -1017,19 +1015,6 @@ def test_race_detection(self): input_arr = jax.device_put(input_arr, sharding) def kernel(src_dst_ids_ref, x_ref, o_ref, send_sem, recv_sem): - # Barrier with all devices before doing any DMAs. - barrier_sem = pltpu.get_barrier_semaphore() - @functools.partial(jax.lax.fori_loop, 0, num_devices, init_val=None) - def _(i, _): - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(jnp.int32(i),), - device_id_type=pltpu.DeviceIdType.MESH, - ) - return None - pltpu.semaphore_wait(barrier_sem, num_devices) - # Send the specified DMAs. 
my_id = lax.axis_index('x') src_dst_ids = src_dst_ids_ref[:] @@ -1076,7 +1061,6 @@ def run(src_dst_ids): ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], - compiler_params=pltpu.TPUCompilerParams(collective_id=0), interpret=mosaic_interpret.TPUInterpretParams( dma_execution_mode='eager', detect_races=True, diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index bc589855b836..c4bf07f39cef 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -18,23 +18,72 @@ contains only tests that do not use shard_map. """ +import functools + from absl.testing import absltest from absl.testing import parameterized - import jax from jax._src import test_util as jtu import jax._src.pallas.mosaic.interpret as mosaic_interpret from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp - import numpy as np jax.config.parse_flags_with_absl() +jax.config.update('jax_threefry_partitionable', True) + + +class CountStoreCallbacksContext(object): + """Wraps the I/O callback `store` into a callback that counts the number of calls to `store`.""" + + def __init__(self): + self._num_stores = 0 + self._saved = mosaic_interpret.store + + def __enter__(self): + def _store_callback(self, *args, **kwargs): + self._num_stores += 1 + return self._saved(*args, **kwargs) + + mosaic_interpret.store = functools.partial(_store_callback, self) + return self + + def __exit__(self, ty, value, traceback): + del ty, value, traceback + mosaic_interpret.store = self._saved + + @property + def num_stores(self): + return self._num_stores + + +class GridPointRecorderContext(object): + """Records grid points in the order in which they are traversed.""" + + def __init__(self): + self._grid_points = [] + + def __enter__(self): + return self + + def __exit__(self, ty, value, traceback): + ... + + def get_recorder(self): + def _recorder(grid_point): + self._grid_points.append(grid_point) + + return _recorder + + @property + def grid_points(self): + return self._grid_points class InterpretTest(jtu.JaxTestCase): + def setUp(self): super().setUp() self.num_devices = jax.device_count() @@ -49,17 +98,18 @@ def matmul_kernel(x_ref, y_ref, z_ref): @jax.jit def matmul(x: jax.Array, y: jax.Array): return pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), - grid=(2, 2), - in_specs=[ - pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), - pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) - ], - out_specs=pl.BlockSpec( - (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), - ), - interpret=mosaic_interpret.TPUInterpretParams(), + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)), + ], + out_specs=pl.BlockSpec( + (x.shape[0] // 2, y.shape[1] // 2), + lambda i, j: (i, j), + ), + interpret=mosaic_interpret.TPUInterpretParams(), )(x, y) k1, k2 = jax.random.split(jax.random.key(0)) @@ -68,11 +118,50 @@ def matmul(x: jax.Array, y: jax.Array): z = matmul(x, y) np.testing.assert_allclose(z, x @ y, atol=1e-4) + def test_scalar_prefetch_example(self): + def dynamic_slice_kernel(indices, x_ref, o_ref): + del indices + o_ref[...] = x_ref[...] 
+ + @functools.partial(jax.jit, static_argnums=(2,)) + def block_dynamic_slice(x, starts, sizes): + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(1, 1), + in_specs=[ + pl.BlockSpec( + sizes, lambda i, j, block_idx: (block_idx[0], block_idx[1]) + ) + ], + out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), + ) + + kernel = pl.pallas_call( + dynamic_slice_kernel, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), + interpret=mosaic_interpret.TPUInterpretParams(), + ) + block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) + return kernel(block_idx, x) + + shape = (512, 512) + x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) + result = block_dynamic_slice( + x, starts=jnp.array([128, 256]), sizes=(128, 128) + ) + ref = jax.lax.dynamic_slice( + x, start_indices=(128, 256), slice_sizes=(128, 128) + ) + diff = jnp.max(jnp.abs(result - ref)) + np.testing.assert_allclose(result, ref) + def test_dynamic_grid_and_aliasing(self): def kernel(s_ref, x_ref, o_ref): o_ref[...] = x_ref[...] + s_ref[0].astype(x_ref.dtype) iters = jax.random.randint(jax.random.key(0), (), 10, 20, dtype=jnp.int32) + @jax.jit def f(s, x): return pl.pallas_call( @@ -85,14 +174,57 @@ def f(s, x): ], out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), input_output_aliases={1: 0}, - interpret=mosaic_interpret.TPUInterpretParams() + interpret=mosaic_interpret.TPUInterpretParams(), )(s, x) s = jnp.array([1], dtype=jnp.int32) - x = jnp.arange(32 * 128.).reshape((32, 128)) + x = jnp.arange(32 * 128.0).reshape((32, 128)) y = f(s, x) + # NOTE: No matter how many times the kernel body is run, the kernel input + # buffer will only be written once by the pallas_call machinery, just + # before the first iteration. So the output will be x + 1 , despite the + # aliasing in HBM. 
np.testing.assert_allclose(y, x + 1.0) + def test_aliasing(self): + def kernel(x_ref, o_ref, s_ref): + @pl.when((pl.program_id(0) == 0) & (pl.program_id(1) == 0)) + def _(): + s_ref[0] = jnp.int32(0) + + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = x_ref[:] + s.astype(x_ref.dtype) + + x = jnp.zeros((4 * 8, 4 * 128)) + y = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(4, 4), + in_specs=[ + pl.BlockSpec(block_shape=(8, 128), index_map=lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (j, i) + ), + scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), + input_output_aliases={0: 0}, + interpret=mosaic_interpret.TPUInterpretParams(), + )(x) + + expected = np.zeros((4, 4)) + t = 0 + for i in range(4): + for j in range(4): + expected[j, i] = expected[i, j] + t + t += 1 + # NOTE: expected is + # [[0, 5, 10, 15], + # [1, 5, 15, 20], + # [2, 6, 10, 25], + # [3, 7, 11, 15]] + np.testing.assert_allclose(y[::8, ::128], expected) + @parameterized.parameters('eager', 'on_wait') def test_race_detection(self, dma_execution_mode): def kernel_without_race(x_ref, o_ref, t_ref, sem): @@ -109,7 +241,8 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): copy.wait() x = jnp.zeros((8, 128), jnp.float32) - y = pl.pallas_call(kernel_without_race, + y = pl.pallas_call( + kernel_without_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], scratch_shapes=[ @@ -117,12 +250,14 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.SemaphoreType.DMA, ], interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertFalse(mosaic_interpret.races.races_found) np.testing.assert_allclose(y, x + 1.0) - pl.pallas_call(kernel_with_race, + pl.pallas_call( + kernel_with_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], scratch_shapes=[ @@ -130,7 +265,8 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.SemaphoreType.DMA, ], interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) @@ -153,8 +289,8 @@ def matmul(x: jax.Array, y: jax.Array): z = jax.jit(matmul)(x, y) np.testing.assert_array_equal(z, jnp.full_like(z, jnp.inf)) - lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo") - self.assertNotIn("dot_general", lowered) + lowered = jax.jit(matmul).lower(x, y).as_text(dialect='stablehlo') + self.assertNotIn('dot_general', lowered) @parameterized.parameters('nan', 'zero') def test_uninitialized_memory(self, uninitialized_memory): @@ -175,7 +311,8 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): pltpu.VMEM((8, 128), jnp.int16), ], interpret=mosaic_interpret.TPUInterpretParams( - uninitialized_memory=uninitialized_memory), + uninitialized_memory=uninitialized_memory + ), )() if uninitialized_memory == 'nan': self.assertTrue(jnp.isnan(x).all()) @@ -186,6 +323,169 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): np.testing.assert_equal(np.array(y), 0) np.testing.assert_equal(np.array(z), 0) + def test_correct_number_of_stores(self): + def kernel(x_ref, s_ref, o_ref): + s = s_ref[0] + x_ref[:] += jax.lax.full_like(x_ref, s) + s_ref[0] = s + 1 + o_ref[:] = 
x_ref[:] + + def kernel_call(x, s): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.float32), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + interpret=mosaic_interpret.TPUInterpretParams(), + )(x, s) + + with CountStoreCallbacksContext() as store_callbacks_counter: + result = jax.jit(kernel_call)( + jnp.zeros((16, 256), jnp.float32), jnp.zeros((1,), jnp.int32) + ) + np.testing.assert_allclose(result[::8, ::256], [[1.0], [5.0]]) + self.assertEqual(store_callbacks_counter.num_stores, 5) + + def test_randomization_of_parallel_dimensions(self): + def kernel(s_ref, o_ref): + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = jax.lax.full_like(o_ref, s) + + def kernel_call_dimensions_arbitrary_parallel(s, grid_point_recorder): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=mosaic_interpret.TPUInterpretParams( + random_seed=12345, grid_point_recorder=grid_point_recorder + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('arbitrary', 'parallel') + ), + )(s) + + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit( + kernel_call_dimensions_arbitrary_parallel, static_argnums=1 + )( + jnp.zeros((1,), jnp.int32), + grid_point_recorder.get_recorder(), + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [ 2.0, 3.0, 0.0, 1.0], + [ 6.0, 7.0, 4.0, 5.0], + [10.0, 11.0, 8.0, 9.0], + [14.0, 15.0, 12.0, 13.0], + ], + ) + np.testing.assert_array_equal( + grid_point_recorder.grid_points, + [ + [0, 2], + [0, 3], + [0, 0], + [0, 1], + [1, 2], + [1, 3], + [1, 0], + [1, 1], + [2, 2], + [2, 3], + [2, 0], + [2, 1], + [3, 2], + [3, 3], + [3, 0], + [3, 1], + ], + ) + + def kernel_call_dimensions_parallel_arbitrary(s, grid_point_recorder): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=mosaic_interpret.TPUInterpretParams( + random_seed=12345, grid_point_recorder=grid_point_recorder + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel', 'arbitrary') + ), + )(s) + + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit( + kernel_call_dimensions_parallel_arbitrary, static_argnums=1 + )( + jnp.zeros((1,), jnp.int32), + grid_point_recorder.get_recorder(), + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [ 8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + [ 0.0, 1.0, 2.0, 3.0], + [ 4.0, 5.0, 6.0, 7.0], + ], + ) + np.testing.assert_array_equal( + grid_point_recorder.grid_points, + [ + [2, 0], + [2, 1], + [2, 2], + [2, 3], + [3, 0], + [3, 1], + [3, 2], + [3, 3], + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [1, 0], + [1, 1], + [1, 2], + [1, 3], + ], + ) + + def test_dynamic_parallel_dimension_raises(self): + def kernel(o_ref): + o_ref[0] = 42.0 + + @jax.jit + def kernel_call_dynamic_parallel_dimension(): + dim_size = jax.random.randint( + jax.random.key(0), (), 10, 20, dtype=jnp.int32 + ) + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1,), jnp.float32), + grid=(dim_size,), + in_specs=[], + out_specs=pl.BlockSpec((1,), lambda _: (0,)), + 
interpret=mosaic_interpret.TPUInterpretParams(), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel',) + ), + )() + + with self.assertRaises(jax.errors.ConcretizationTypeError): + kernel_call_dynamic_parallel_dimension() + -if __name__ == "__main__": +if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 8e72c49e2598..f29182e56314 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -26,25 +26,21 @@ from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False - - -if CAN_USE_HYPOTHESIS: - hp.settings.register_profile( - 'deterministic', - database=None, - derandomize=True, - deadline=None, - max_examples=200, - print_blob=True, - verbosity=hp.Verbosity.verbose, - ) - hp.settings.load_profile('deterministic') + +import hypothesis as hp +import hypothesis.strategies as hps + + +hp.settings.register_profile( + 'deterministic', + database=None, + derandomize=True, + deadline=None, + max_examples=200, + print_blob=True, + verbosity=hp.Verbosity.verbose, +) +hp.settings.load_profile('deterministic') jax.config.parse_flags_with_absl() @@ -720,20 +716,20 @@ def _wait_on_prev_dma(): pl.BlockSpec(memory_space=memory_space), pl.BlockSpec(memory_space=memory_space), ], - out_specs=[pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space)], + out_specs=[ + pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space), + ], grid=(outer_steps, 2), - scratch_shapes=[ - pltpu.VMEM((tm, tn), jnp.float32)] + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict(collective_id=0, - # must set scoped vmem flag *larger* than below! e.g.: - # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.TPUCompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! e.g.: + # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -1010,15 +1006,13 @@ def _loop_epilogue(): grid=(outer_steps, 2), scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict( - collective_id=0, - # must set scoped vmem flag *larger* than below! - # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.TPUCompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -1273,15 +1267,13 @@ def _prefetch_accumulator(): grid=(outer_steps, 2), scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict( - collective_id=0, - # must set scoped vmem flag *larger* than below! - # e.g. 
flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.TPUCompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -1362,7 +1354,9 @@ def mul_kernel(iters_ref, x_ref, y_ref): out_specs=pl.BlockSpec(memory_space=pltpu.ANY), grid=(num_cores,), ), - compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel',) + ), ) x = jax.random.uniform(jax.random.key(0), (640, 640)) np.testing.assert_allclose(func(jnp.array([5]), x), x * 2) @@ -1396,7 +1390,9 @@ def matmul_kernel(x_ref, y_ref): ], out_specs=pl.BlockSpec(memory_space=pltpu.ANY), grid=(num_cores,), - compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel',) + ), ) np.testing.assert_allclose(func(x), x * 2) @@ -1445,109 +1441,110 @@ def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn): ], out_specs=pl.BlockSpec(memory_space=pltpu.ANY), grid=(num_cores,), - compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=('parallel',) + ), ) np.testing.assert_allclose(func(x, y), x @ y, atol=7e-5) -if CAN_USE_HYPOTHESIS: - - @partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) - def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): +@partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) +def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): - m, k = x.shape - _, n = y.shape + m, k = x.shape + _, n = y.shape - def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): + def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): - grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) + grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) - def run(acc_scratch_ref): - pltpu.emit_pipeline( - partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), - in_specs=[ - pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), - pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), - ], - out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), - grid=grid, - core_axis=0, - dimension_semantics=( - pltpu.PARALLEL, - pltpu.PARALLEL, - pltpu.ARBITRARY, - ), - )(x_hbm_ref, y_hbm_ref, o_hbm_ref) - - accum_dtype = ( - jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 - ) - pl.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) - - num_cores = jax.devices()[0].num_cores - return pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), - ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), - grid=(num_cores,), - )(x, y) - - class PaddedPipelineEmitterTest(parameterized.TestCase): + def run(acc_scratch_ref): + pltpu.emit_pipeline( + partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), + in_specs=[ + pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), + pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), + grid=grid, + core_axis=0, + dimension_semantics=( + pltpu.PARALLEL, + pltpu.PARALLEL, + pltpu.ARBITRARY, + ), + )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + + accum_dtype = ( + jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 + ) + pl.run_scoped(run, pltpu.VMEM((bm, bn), 
accum_dtype)) + + num_cores = jax.devices()[0].num_cores + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + )(x, y) + +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class PaddedPipelineEmitterTest(parameterized.TestCase): - def setUp(self): - super().setUp() - if not jtu.is_device_tpu_at_least(4): - self.skipTest('Only TPU v4+ allowed.') + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only TPU v4+ allowed.') - @parameterized.named_parameters( - ('float32', 'float32'), ('bfloat16', 'bfloat16'), ('int8', 'int8') - ) - @hp.given( - hps.integers(1, 1024), - hps.integers(1, 1024), - hps.integers(1, 1024), - hps.sampled_from([8, 16, 32, 128, 256, 512]), - hps.sampled_from([128, 256, 512]), - hps.sampled_from([128, 256, 512]), - hps.integers(0, 4), - ) - def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): - if dtype == 'int8' and jtu.is_device_tpu_at_least(6): - self.skipTest('Not implemented for TPU v6.') - - def align_up_to(x, y): - return (x + y - 1) // y * y - - hp.assume(bm <= m) - hp.assume(bn <= n) - hp.assume(bk <= k) - if dtype == 'bfloat16': - hp.assume(bm >= 16) - if dtype == 'int8': - if not jtu.is_device_tpu_at_least(5): - self.skipTest('Only TPU v5+ allowed for int8.') - hp.assume(bm >= 32) - # TODO(apaszke): Relax DMA restrictions and remove this. - packing = 4 // jnp.dtype(dtype).itemsize - if packing != 1: - m = align_up_to(m, 8 * packing) - k = align_up_to(k, 8 * packing) - k1, k2 = jax.random.split(jax.random.key(seed)) - x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) - y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) - - out = matmul(x, y, bm=bm, bk=bk, bn=bn) - expected = x @ y - atol = rtol = 2.3e-5 - if dtype == 'bfloat16': - out = out.astype('float32') - expected = expected.astype('float32') - atol = rtol = 1e-2 - np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) + @parameterized.named_parameters( + ('float32', 'float32'), ('bfloat16', 'bfloat16'), ('int8', 'int8') + ) + @hp.given( + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.sampled_from([8, 16, 32, 128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.integers(0, 4), + ) + def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): + if dtype == 'int8' and jtu.is_device_tpu_at_least(6): + self.skipTest('Not implemented for TPU v6.') + + def align_up_to(x, y): + return (x + y - 1) // y * y + + hp.assume(bm <= m) + hp.assume(bn <= n) + hp.assume(bk <= k) + if dtype == 'bfloat16': + hp.assume(bm >= 16) + if dtype == 'int8': + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Only TPU v5+ allowed for int8.') + hp.assume(bm >= 32) + # TODO(apaszke): Relax DMA restrictions and remove this. 
+ packing = 4 // jnp.dtype(dtype).itemsize + if packing != 1: + m = align_up_to(m, 8 * packing) + k = align_up_to(k, 8 * packing) + k1, k2 = jax.random.split(jax.random.key(seed)) + x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) + y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) + + out = matmul(x, y, bm=bm, bk=bk, bn=bn) + expected = x @ y + atol = rtol = 2.3e-5 + if dtype == 'bfloat16': + out = out.astype('float32') + expected = expected.astype('float32') + atol = rtol = 1e-2 + np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) if __name__ == '__main__': diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index ca8edf7a269e..feeaa3cfceb9 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -303,6 +303,7 @@ def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): mesh=mesh, in_specs=partition, out_specs=partition, + check_rep=False, ) jax_gen = generate(key_jax) pl_gen = generate(key_pallas) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 55831ff6af1d..ce9348b594b0 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -145,8 +145,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) out = self.pallas_call( body, @@ -225,7 +224,7 @@ def kernel(s_refs, src, to_store, dst, *scratch_refs): assert s2.shape == (3,) assert s3 is None store_idx = s_ref[pl.program_id(0)] - pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...]) + dst[pl.dslice(store_idx, 1), :] = to_store[...] # Pass a pytree of scalar return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store) @@ -281,7 +280,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) def f(x): @@ -423,7 +422,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = s[None] @@ -457,7 +456,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = jnp.tile(s[None], [2, 1]) @@ -1136,11 +1135,43 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): np.testing.assert_array_equal(y, x) np.testing.assert_array_equal(sem_val, 0) + def test_set_dma_priority(self): + if not jtu.if_cloud_tpu_at_least(2025, 4, 5): + self.skipTest('Needs a newer libTPU') + if jtu.get_tpu_version() < 5: + self.skipTest('Target does not support DMA prefetch between HBM and VMEM') + def kernel(x1, x2, y1, y2, scratch1, scratch2, sem1, sem2): + copy1 = pltpu.async_copy(x1, scratch1, sem1, priority=1) + copy2 = pltpu.async_copy(x2, scratch2, sem2, priority=0) + copy1.wait() + copy2.wait() + copy1 = pltpu.async_copy(scratch1, y1, sem1, priority=0) + copy2 = pltpu.async_copy(scratch2, y2, sem2, priority=1) + copy1.wait() + copy2.wait() + + shape = (8, 128) + dtype = jnp.int32 + x1 = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + x2 = x1 + 1 + y1, y2 = self.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + 
scratch_shapes=[pltpu.VMEM(shape, dtype)] * 2 + + [pltpu.SemaphoreType.DMA] * 2, + out_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + ), + out_shape=[jax.ShapeDtypeStruct(shape, dtype)] * 2, + )(x1, x2) + np.testing.assert_array_equal(y1, x1) + np.testing.assert_array_equal(y2, x2) + def test_hbm_hbm_dma(self): def kernel(x_hbm_ref, y_hbm_ref): def body(sem): - pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], - sem).wait() + pltpu.async_copy(x_hbm_ref.at[:8, :], y_hbm_ref.at[:, :128], sem).wait() pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( @@ -1546,6 +1577,7 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): )(x) np.testing.assert_array_equal(y, x) + @jtu.thread_unsafe_test() # Uses a lot of TPU memory. def test_large_array_indexing(self): n = 6 dtype = jnp.bfloat16 @@ -2300,6 +2332,7 @@ def kernel(x_ref, y_ref): np.testing.assert_array_equal(y, x[8:16, :128]) +@jtu.thread_unsafe_test_class() # debug print test is not thread safe class PallasCallPrintTest(PallasBaseTest): def test_debug_print(self): @@ -2383,6 +2416,7 @@ def kernel(x_ref, o_ref): class PallasCallTraceTest(PallasBaseTest): + @jtu.thread_unsafe_test() # stdout redirection is not thread safe def test_trace_start_stop_match(self): def kernel(o_ref): with jax.named_scope('scope1'): @@ -2402,6 +2436,7 @@ def kernel(o_ref): self.assertEqual(num_start, 1) self.assertEqual(num_stop, 1) + @jtu.thread_unsafe_test() # stdout redirection is not thread safe def test_run_scoped(self): def kernel(o_ref): def scope1(): @@ -2570,8 +2605,7 @@ def body(scalar_ref, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) pallas_call = self.pallas_call( body, @@ -2668,19 +2702,19 @@ class PrettyPrintingTest(PallasBaseTest): @parameterized.parameters( ( lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)), - 'dma_start c[d,:,:] -> e[...] f', + 'dma_start(p0) c[d,:,:] -> e[...] f', ), ( lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)), - 'dma_start c[0,d:d+8,:] -> e[...] f', + 'dma_start(p0) c[0,d:d+8,:] -> e[...] f', ), ( lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)), - 'dma_start c[d,2:6,:100] -> e[...] f', + 'dma_start(p0) c[d,2:6,:100] -> e[...] f', ), ( lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)), - 'dma_start c[d,2:,4:104] -> e[...] f', + 'dma_start(p0) c[d,2:,4:104] -> e[...] f', ), ) def test_dma_custom_pretty_print(self, indexer, expected): diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index bffcebc5254b..f86d54575519 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -13,14 +13,15 @@ # limitations under the License. 
import random + from absl.testing import absltest from absl.testing import parameterized import jax from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu.ragged_paged_attention import ( + dynamic_validate_inputs, ragged_paged_attention, ref_ragged_paged_attention, - validate_inputs_on_runtime, ) import jax.numpy as jnp @@ -50,6 +51,8 @@ def _test_ragged_paged_attention( vmem_limit_bytes=32 * 1024 * 1024, max_num_batched_tokens=512, max_num_seq=8, + sliding_window: int | None = None, + soft_cap: float | None = None, ): if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Expect TPUv4+") @@ -71,59 +74,58 @@ def _test_ragged_paged_attention( cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) prng_key = jax.random.key(1234) - k0, k1, k2, k3 = jax.random.split(prng_key, 4) + k0, k1, k2 = jax.random.split(prng_key, 3) q = jax.random.normal( k0, (max_num_batched_tokens, num_q_heads, head_dim), dtype=dtype, ) - k_pages = jax.random.normal( + kv_pages = jax.random.normal( k1, - (num_pages, page_size, num_kv_heads, head_dim), - dtype=dtype, - ) - v_pages = jax.random.normal( - k2, - (num_pages, page_size, num_kv_heads, head_dim), + (num_pages, page_size, num_kv_heads * 2, head_dim), dtype=dtype, ) page_indices = jax.random.randint( - k3, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 + k2, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 ) num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) - validate_inputs_on_runtime( + dynamic_validate_inputs( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, + sliding_window=sliding_window, + soft_cap=soft_cap, ) + actual_num_q_tokens = cu_q_lens[num_seqs[0]] output = ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs=num_seqs, - num_kv_pages_per_block=num_kv_pages_per_block, + num_kv_pages_per_block=min(num_kv_pages_per_block, pages_per_seq), num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, - )[: cu_q_lens[num_seqs[0]]] + sliding_window=sliding_window, + soft_cap=soft_cap, + )[: actual_num_q_tokens] expected = ref_ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs=num_seqs, + sliding_window=sliding_window, + soft_cap=soft_cap, ) tols = { "float32": 0.15, @@ -262,7 +264,7 @@ def test_ragged_paged_attention_mixed(self, dtype): @parameterized.product( num_seqs=[1, 5, 16], # TODO(jevinjiang): Support more num_heads! - num_heads=[(32, 8), (32, 16), (12, 2), (4, 4)], + num_heads=[(32, 8), (32, 16), (12, 2), (4, 4), (8, 1)], dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], @@ -296,6 +298,126 @@ def test_ragged_paged_attention_complex( num_queries_per_block=num_queries_per_block, ) + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + sliding_window=[None, 5, 128], + ) + def test_ragged_paged_attention_sliding_window( + self, + num_kv_pages_per_block, + num_queries_per_block, + sliding_window: int | None, + ): + num_seqs = 5 + num_heads = (4, 4) + dtype = jnp.float32 + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + # TODO(jevinjiang): Support non-128 head_dim! 
+ head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + sliding_window=sliding_window, + ) + + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + soft_cap=[None, 50.0], + ) + def test_ragged_paged_attention_logit_soft_capping( + self, + num_kv_pages_per_block, + num_queries_per_block, + soft_cap: float | None, + ): + num_heads = (12, 2) + num_seqs = 2 + dtype = jnp.float32 + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + soft_cap=soft_cap, + ) + + def test_ragged_paged_attention_sliding_window_should_be_positive(self): + dtype = jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + sliding_window=0, + ) + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + sliding_window=-1, + ) + + def test_ragged_paged_attention_soft_cap_cannot_be_zero(self): + dtype = jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must not be 0.0"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + soft_cap=0.0, + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_splash_attention_kernel_sharded_test.py b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py new file mode 100644 index 000000000000..db14b44938e9 --- /dev/null +++ b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py @@ -0,0 +1,223 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for partitioning splash_attention.""" + +import functools +import math +from absl.testing import absltest, parameterized +import jax +from jax import random +from jax._src import test_util as jtu +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +from jax.experimental.shard_map import shard_map +import jax.numpy as jnp +from jax.sharding import PartitionSpec +import numpy as np + +partial = functools.partial + +jax.config.parse_flags_with_absl() + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu(): + self.skipTest("Test requires TPU.") + + if len(jax.devices()) < 4: + self.skipTest("This test requires at least 4 devices.") + + def _assert_allclose(self, x, y, **kwargs): + if x.dtype == np.dtype(jnp.bfloat16): + x = x.astype(np.float32) + if y.dtype == np.dtype(jnp.bfloat16): + y = y.astype(np.float32) + self.assertEqual(x.dtype, y.dtype) + self.assertTupleEqual(x.shape, y.shape) + np.testing.assert_allclose(x, y, **kwargs) + + +def generate_mask(shape, num_heads, seed) -> np.ndarray: + assert num_heads >= 2 + assert shape > (64, 64) + + masks = [ + mask_lib.make_causal_mask(shape), + mask_lib.make_local_attention_mask(shape, window_size=(64, 64)), + ] + masks += [mask_lib.make_random_mask(shape, 0.8, seed)] * (num_heads - 2) + return np.stack(masks, axis=0) + + +class SplashAttentionShardingTest(PallasBaseTest): + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4, 16], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha( + self, topology, num_heads, dtype, is_dynamic_mask + ): + k1, k2, k3 = random.split(random.key(0), 3) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." + ) + + if len(jax.devices()) < num_devices: + self.skipTest( + f"This test requires {num_devices} devices, but has only" + f" {len(jax.devices())} devices available." 
+ ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_rep=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + out = f(kernel, q, k, v) + out_ref = jax.vmap(splash.attention_reference)(mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha_bwd( + self, topology, num_heads, dtype, is_dynamic_mask + ): + assert num_heads % 2 == 0 + k1, k2, k3, k4 = random.split(random.key(0), 4) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." 
+ ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_rep=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + f_ref = jax.vmap(splash.attention_reference) + + out, out_vjp = jax.vjp(f, kernel, q, k, v) + out_ref, out_vjp_ref = jax.vjp(f_ref, mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + do = random.uniform(k4, out.shape, dtype=out.dtype) + _, dq, dk, dv = out_vjp(do) + _, dq_ref, dk_ref, dv_ref, _ = out_vjp_ref(do.astype(jnp.float32)) + + self.assertAllClose(dq, dq_ref, atol=5e-2) + self.assertAllClose(dk, dk_ref, atol=5e-2) + self.assertAllClose(dv, dv_ref, atol=5e-2) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_splash_attention_kernel_test.py b/tests/pallas/tpu_splash_attention_kernel_test.py index dfe0bcc0da3b..a494a62745d1 100644 --- a/tests/pallas/tpu_splash_attention_kernel_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_test.py @@ -32,11 +32,9 @@ import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("these tests require hypothesis") +import hypothesis as hp +import hypothesis.strategies as hps + jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=5) @@ -303,14 +301,6 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]: return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0)) -def to_dynamic_mask(mask: mask_lib.MultiHeadMask) -> jax.Array: - q_seq_len, kv_seq_len = mask.masks[0].shape - full_mask_slice = (slice(0, q_seq_len), slice(0, kv_seq_len)) - dynamic_mask = jnp.stack([m[full_mask_slice] for m in mask.masks], axis=0) - - return dynamic_mask - - @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -337,6 +327,7 @@ def _assert_allclose(self, x, y, **kwargs): np.testing.assert_allclose(x, y, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class SplashAttentionTest(PallasBaseTest): @parameterized.product( is_mqa=(False, True), @@ -384,7 +375,7 @@ def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data): masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: @@ -460,7 +451,7 @@ def test_splash_attention_fwd( masks = 
data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask) @@ -522,9 +513,9 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, 1)) mask = jnp.array(masks[0].get_mask()[:, :]) attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy(), - label="logit_cap") + label="logit_cap") attn_ref = partial(splash.attention_reference, mask, - attn_logits_soft_cap=attn_logits_soft_cap) + attn_logits_soft_cap=attn_logits_soft_cap) attn_custom = partial(splash.attention_reference_custom, mask, attn_logits_soft_cap=attn_logits_soft_cap) attn_custom_vanilla = partial(splash.attention_reference_custom, mask, @@ -532,7 +523,7 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): attn_logits_soft_cap=attn_logits_soft_cap) o_ref, attn_vjp_ref = jax.vjp(attn_ref, q, k, v, segment_ids) q32, k32, v32 = jax.tree.map(lambda x: x.astype(jnp.float32), - (q, k, v)) + (q, k, v)) o_custom = attn_custom(q32, k32, v32, segment_ids) _, attn_vjp = jax.vjp(attn_custom, q32, k32, v32, segment_ids) _, attn_vanilla_vjp = jax.vjp(attn_custom_vanilla, q32, k32, v32, @@ -628,10 +619,10 @@ def test_splash_attention_bwd( masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if use_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw( block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True, - use_fused_bwd_kernel=use_fused_bwd_kernel) + use_fused_bwd_kernel=use_fused_bwd_kernel) ) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask, backward_impl="custom") diff --git a/tests/pallas/tpu_splash_attention_mask_test.py b/tests/pallas/tpu_splash_attention_mask_test.py index f39b4d839340..7c4b53529169 100644 --- a/tests/pallas/tpu_splash_attention_mask_test.py +++ b/tests/pallas/tpu_splash_attention_mask_test.py @@ -44,6 +44,15 @@ def _make_local_attention_mask(*args, **kwargs): return mask_lib.make_local_attention_mask(*args, **kwargs) +def _make_lazy_chunked_causal_mask(shape, chunk_size): + mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + return mask[:, :] + + +def _make_chunked_causal_mask(shape, chunk_size): + return mask_lib.make_chunk_attention_mask(shape=shape, chunk_size=chunk_size) + + class SplashAttentionMaskTest(jtu.JaxTestCase): @parameterized.parameters([_make_lazy_causal_mask, _make_causal_mask]) @@ -412,6 +421,181 @@ def test_lazy_local_mask_chunking( block_size, ) + @parameterized.parameters( + [_make_lazy_chunked_causal_mask, _make_chunked_causal_mask] + ) + def test_chunked_causal_mask(self, make_chunked_mask): + """Tests the chunked causal mask logic for various shapes and chunk sizes.""" + with self.subTest("unit"): + expected = np.array([[1]], dtype=np.bool_) + actual = make_chunked_mask(shape=(1, 1), chunk_size=1) + self.assertArraysEqual(actual, expected) + actual = make_chunked_mask(shape=(1, 1), chunk_size=2) + self.assertArraysEqual(actual, expected) + + with self.subTest("square_exact_chunks"): + # Chunk 0: [0, 1], Chunk 1: [2, 3] + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = 
make_chunked_mask(shape=(4, 4), chunk_size=2) + self.assertArraysEqual(actual, expected) + + with self.subTest("square_uneven_chunks"): + expected = np.array( + [ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(5, 5), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("wide_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 6), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("tall_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [0, 0, 0, 1], + [0, 0, 0, 1], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(6, 4), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("chunk_size_1"): + # Should only allow self-attention q==k and chunk_size == 1 + expected = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 4), chunk_size=1) + self.assertArraysEqual(actual, expected) + + with self.subTest("chunk_size_greater_equal_seqlen"): + # Should behave like a normal causal mask + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + ], + dtype=np.bool_, + ) + # Test chunk_size == seqlen + actual_eq = make_chunked_mask(shape=(4, 4), chunk_size=4) + self.assertArraysEqual(actual_eq, expected) + # Test chunk_size > seqlen + actual_gt = make_chunked_mask(shape=(4, 4), chunk_size=5) + self.assertArraysEqual(actual_gt, expected) + + @parameterized.product( + block_size=[(128, 128), (256, 128), (128, 256)], + shape=[(512, 512), (512, 1024), (1024, 512)], + chunk_size=[64, 128, 256, 512, 1024], + ) + def test_lazy_chunked_causal_mask_chunking( + self, + block_size: tuple[int, int], + shape: tuple[int, int], + chunk_size: int, + ): + """Compares lazy chunked mask evaluation against the dense version block-by-block.""" + q_len, kv_len = shape + # Adjust block size if it exceeds shape dimensions + adjusted_block_size = ( + min(block_size[0], q_len), + min(block_size[1], kv_len), + ) + + if ( + q_len % adjusted_block_size[0] != 0 + or kv_len % adjusted_block_size[1] != 0 + ): + self.skipTest( + f"Shape {shape} not divisible by block_size {adjusted_block_size}" + ) + + dense_mask = _make_chunked_causal_mask(shape=shape, chunk_size=chunk_size) + lazy_mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + self._compare_masks( + dense_mask, + lazy_mask, + adjusted_block_size, + ) + + def test_chunked_causal_mask_invalid_chunk_size(self): + """Tests that invalid chunk_size raises ValueError.""" + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=0) + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=-1) + with self.assertRaises(ValueError): + mask_lib.make_chunk_attention_mask(shape=(10, 10), chunk_size=0) + + def test_chunked_causal_mask_minimal_equality_hash(self): + """Tests for __eq__ and __hash__ of ChunkedCausalMask.""" + shape1, chunk_size1 = (128, 256), 16 + shape2, chunk_size2 = (128, 128), 32 # Different shape/chunk_size + + # Create three masks: two identical, one with different shape/chunk_size. 
+ mask1 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask2 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask_diff_shape = mask_lib.ChunkedCausalMask( + shape=shape2, chunk_size=chunk_size1 + ) + mask_diff_chunk = mask_lib.ChunkedCausalMask( + shape=shape1, chunk_size=chunk_size2 + ) + other_obj = object() + + # Test __eq__ + self.assertEqual(mask1, mask2) + self.assertNotEqual(mask1, mask_diff_shape) + self.assertNotEqual(mask1, mask_diff_chunk) + self.assertNotEqual(mask1, other_obj) + + # Test __hash__ of identical masks + self.assertEqual(hash(mask1), hash(mask2)) + + mask_set = {mask1, mask2, mask_diff_chunk} + self.assertLen(mask_set, 2) # mask1 and mask2 are duplicates + self.assertIn(mask1, mask_set) + self.assertIn(mask_diff_chunk, mask_set) + self.assertNotIn(mask_diff_shape, mask_set) + def test_using_logical_operators_raises_exception(self): mask_1 = mask_lib.NumpyMask( mask_lib.make_random_mask((256, 256), 0.5, seed=1) @@ -2166,7 +2350,9 @@ def test_dynamic_mask(self, is_dkv: bool): self.assertArraysEqual(mask_info.block_mask, _expected_block_mask) self.assertArraysEqual( - mask_info.partial_mask_blocks, + mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ), _expected_partial_mask_blocks, ) self.assertArraysEqual(mask_info.mask_next, _expected_mask_next) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7f9ea598d51b..2787de4c6e17 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -21,7 +21,7 @@ import tempfile import warnings -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax from jax._src import api from jax._src import compilation_cache as cc @@ -65,7 +65,11 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -93,6 +97,8 @@ def testPGLEProfilerGetFDOProfileLarge(self): compiler_options = { 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', } # TODO(b/37664749): Remove this flag once the bug is fixed. compiler_options['xla_gpu_enable_command_buffer'] = '' @@ -321,7 +327,11 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. 
+ 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -468,5 +478,55 @@ def check_if_cache_hit(event): self.assertLen(w, 1) self.assertIn("PERSISTENT CACHE WRITE with key jit_h-", str(w[0].message)) + @parameterized.parameters([True, False]) + @jtu.thread_unsafe_test() + def testAutoPgleWithCommandBuffers(self, enable_compilation_cache): + with (config.pgle_profiling_runs(1), + config.enable_compilation_cache(enable_compilation_cache), + config.enable_pgle(True), + tempfile.TemporaryDirectory() as dump_dir, + tempfile.TemporaryDirectory() as cache_dir): + if enable_compilation_cache: + cc.reset_cache() + cc.set_cache_dir(cache_dir) + compiler_options = { + 'xla_dump_to': dump_dir, + # FUSION, see https://github.com/openxla/xla/issues/22459 + 'xla_gpu_enable_command_buffer': 1, + 'xla_gpu_graph_min_graph_size': 1, + } + @partial( + jax.jit, + compiler_options=compiler_options, + ) + def f(x): + return x * 2 + + x = jnp.arange(1) + expected = x * 2 + + # This is ugly, but it does not seem possible to get the AutoPGLE-recompiled + # executable text (.lower(x).compile().as_text() or similar). + def get_new_hlo(): + additions = set(os.listdir(dump_dir)) - get_new_hlo.seen_files + get_new_hlo.seen_files |= additions + new_hlos = list(filter(lambda f: f.endswith("_gpu_after_optimizations.txt"), additions)) + assert len(new_hlos) == 1 + with open(os.path.join(dump_dir, new_hlos[0]), "r") as ifile: + return ifile.read() + + get_new_hlo.seen_files = set() + + # Run 1 + self.assertArraysEqual(f(x), expected) + self.assertNotIn("command_buffer", get_new_hlo()) # b/376647494 workaround + # Run 2 + self.assertArraysEqual(f(x), expected) + self.assertIn("command_buffer", get_new_hlo()) # workaround disabled + + api.clear_caches() + pjit._pgle_profiler_dict.clear() + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pickle_test.py b/tests/pickle_test.py index 185eebd90726..a3dc5be6e11c 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -28,6 +28,7 @@ from jax.interpreters import pxla from jax._src import test_util as jtu from jax._src.lib import xla_client as xc +from jax._src.sharding_impls import GSPMDSharding import numpy as np @@ -182,7 +183,7 @@ def test_pickle_pmap_sharding(self): self.assertEqual(s, pickle.loads(pickle.dumps(s))) def test_pickle_gspmd_sharding(self): - s = jax.sharding.GSPMDSharding.get_replicated(jax.devices()) + s = GSPMDSharding.get_replicated(jax.devices()) self.assertEqual(s, pickle.loads(pickle.dumps(s))) @unittest.skipIf(cloudpickle is None, "Requires cloudpickle") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 293b37a9fbc7..8f30475eee32 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -14,7 +14,7 @@ from collections import OrderedDict, namedtuple import re -from functools import partial +from functools import partial, wraps import logging import json import math @@ -59,10 +59,10 @@ from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType from jax._src.interpreters import pxla -from jax._src.lib.mlir import dialects from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension +from jax._src.lib import jaxlib_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -940,6 +940,18 @@ def testWithCustomPRNGKey(self): # Make sure this doesn't crash pjit(lambda x: x, in_shardings=None, out_shardings=None)(key) + def test_lower_with_wrapper_error(self): + @jax.jit + def f(x): + 
return x + + self.assertAllClose(1., f(1.)) + self.assertAllClose(1., f.lower(1.).compile()(1.)) + wrapped_f = wraps(f)(lambda x: f(x + 1)) + + with self.assertRaisesRegex(AttributeError, "has no attribute 'lower'"): + wrapped_f.lower(1.) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testLowerCompile(self): @partial(pjit, @@ -1240,9 +1252,12 @@ def test_pretty_print_pjit_id(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - pjit[name= jaxpr={ lambda ; a:f32[1] b:f32[1]. let in () }] a a - c:f32[1] = add a a - in (c,) } + b:f32[1] = pjit[ + name= + jaxpr={ lambda ; a:f32[1] c:f32[1]. let in (a,) } + ] a a + d:f32[1] = add a b + in (d,) } """).strip(), ) @@ -1289,8 +1304,11 @@ def test_pretty_print_with_literal_outvar(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - b:i32[] = pjit[name= jaxpr={ lambda ; a:f32[1]. let in (2,) }] a - in (b, a) } + b:i32[] c:f32[1] = pjit[ + name= + jaxpr={ lambda ; a:f32[1]. let in (2, a) } + ] a + in (b, c) } """).strip(), ) @@ -1336,19 +1354,19 @@ def f(x): self.assertEqual( jaxpr.pretty_print(use_color=False), textwrap.dedent(""" - let f = { lambda ; a:f32[1]. let in () } in - let f1 = { lambda ; b:f32[2]. let in () } in + let f = { lambda ; a:f32[1]. let in (a,) } in + let f1 = { lambda ; b:f32[2]. let in (b,) } in { lambda ; c:f32[1] d:f32[2]. let e:f32[2] = pjit[ name=g jaxpr={ lambda ; c:f32[1] d:f32[2]. let - pjit[name=f jaxpr=f] c - pjit[name=f jaxpr=f] c - g:f32[1] = mul c c - pjit[name=f jaxpr=f1] d - pjit[name=f jaxpr=f1] d - h:f32[2] = mul d d - e:f32[2] = add g h + g:f32[1] = pjit[name=f jaxpr=f] c + h:f32[1] = pjit[name=f jaxpr=f] c + i:f32[1] = mul g h + j:f32[2] = pjit[name=f jaxpr=f1] d + k:f32[2] = pjit[name=f jaxpr=f1] d + l:f32[2] = mul j k + e:f32[2] = add i l in (e,) } ] c d in (e,) } @@ -1394,6 +1412,18 @@ def test_zero_literal_equality(self): self.assertIn("stablehlo.constant dense<0.000000e+00>", ir) self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir) + def test_device_put_copy_donate(self): + if jaxlib_extension_version < 327: + raise unittest.SkipTest("Copy not supported in device put.") + x = np.arange(1000) + y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False) + z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False) + a = jax.jit(lambda y: y * 2, donate_argnums=0)(y) + self.assertDeleted(y) + self.assertNotDeleted(z) + self.assertArraysEqual(a, x * 2) + + @jtu.pytest_mark_if_available('multiaccelerator') class CustomPartitionerTest(jtu.JaxTestCase): @@ -2477,6 +2507,20 @@ def test_pjit_committed_array_different_devices_variadic_args(self): r"\[1\].*"): pjit(lambda *x: x)(a, b) + def test_jit_no_forwarding(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @partial(jax.jit, donate_argnums=(0,)) + def f(x): + return x, x * 2 + + x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P())) + jaxpr = jax.make_jaxpr(f)(x) + y = core.jaxpr_as_fun(jaxpr)(x) + self.assertTrue(x.is_deleted()) + self.assertFalse(y[0].is_deleted()) + self.assertFalse(y[1].is_deleted()) + def test_pjit_pytree_inp_device_assignment_mismatch(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0]) @@ -3423,9 +3467,8 @@ def f(x, y): f(x_, y) self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn("seen f32[8]({}), 
but now given f32[8]({Auto: ('x',)})", msg) + self.assertIn("different input types", msg) + self.assertIn("at x, now f32[8]({Auto: ('x',)}) and before f32[8]({})", msg) def test_pjit_function_cache_cpp(self): def f(x): @@ -3438,6 +3481,7 @@ def f(x): pjit(f)(inp) self.assertEqual(count(), 1) + @jtu.thread_unsafe_test() # count_pjit_cpp_cache_miss is not thread-safe def test_pjit_no_global_cache_hit_axis_resources(self): mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P('x')) @@ -4995,11 +5039,13 @@ def g(x, y): return x * y with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x')))) with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y'))))) @parameterized.named_parameters( @@ -5098,14 +5144,14 @@ def f(x, y): @parameterized.named_parameters( ('fail1', P('x', None), P(None, 'x'), "dot_general operation.*produces an illegally sharded result", - TypeError), + core.ShardingTypeError), ('fail2', P('x', 'y'), P('x', 'y'), "dot_general requires contracting dimensions to have consistent sharding", - TypeError), + core.ShardingTypeError), ('contracting1', P('x', 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ('other_half_tp', P(None, 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ) @jtu.with_user_mesh((2, 2), ('x', 'y')) def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh): @@ -5127,14 +5173,14 @@ def test_dot_general_batch_error(self, mesh): arr2 = jax.device_put(np.ones((8, 2, 4)), NamedSharding(mesh, P('y', 'z', 'x'))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs batch dimensions to' ' have the consistent sharding'): jax.lax.dot_general( arr1, arr2, dimension_numbers=(([2], [1]), ([0], [0]))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs batch dimensions to' ' have the consistent sharding'): jnp.einsum('abc,acz->abz', arr1, arr2) @@ -5472,6 +5518,18 @@ def h2(x, y): ('4', (1, 4, 1, 6, 1), (1, 4, 6), P(None, 'x', None, None, None), P(None, 'x', None), False), ('5', (4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False), + ('6', (1024, 4096), (1024, 2048, 2, 1, 1, 1, 1), + P('x', None), P('x', None, None, None, None, None, None), False), + ('7', (1024, 4096, 32), (1024, 2048, 2, 1, 1, 32), + P('x', None, None), P('x', None, None, None, None, None), False), + ('8', (1024, 4096), (1024, 1, 1, 4096), + P('x', None), P('x', None, None, None), False), + ('9', (1024, 4096), (1024, 1, 1, 4096), + P(None, 'x'), P(None, None, None, 'x'), False), + ('10', (1024, 2048, 2, 1, 1, 1), (1024, 4096), + P('x', None, None, None, None, None), P('x', None), False), + ('11', (1024, 2048, 2, 1, 1, 1), (1024, 4096), + P(None, 'x', None, None, None, None), P(None, 'x'), False), ) @jtu.with_user_mesh((2,), ('x',)) def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, @@ -5483,6 +5541,8 @@ def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, @partial(jax.jit, static_argnums=1) def f(x, 
new_sharding): y = lax.reshape(x, dst_shape, out_sharding=new_sharding) + self.assertEqual(y.aval.sharding.spec, dst_spec) + self.assertEqual(y.shape, dst_shape) y = y * 2 self.assertEqual(y.aval.sharding.spec, dst_spec) return y @@ -5569,7 +5629,7 @@ def f(x): return y if error_msg: - with self.assertRaisesRegex(ValueError, error_msg): + with self.assertRaisesRegex(core.ShardingTypeError, error_msg): f(arr) else: out = f(arr) @@ -5608,7 +5668,7 @@ def f(pred, on_true, on_false): arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x'))) with self.assertRaisesRegex( - TypeError, "select cases must have the same shardings"): + core.ShardingTypeError, "select cases must have the same shardings"): f(arr1 == arr2, arr1, arr3) def test_explicit_mode_no_context_mesh(self): @@ -5778,10 +5838,10 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) - with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) @jtu.with_user_mesh((2, 2), ('x', 'y')) @@ -5842,13 +5902,13 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((2, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((0, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) f(arr, ((4, 4, 1),), None) @@ -5879,7 +5939,7 @@ def f(x, y, method='jnp'): self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) with self.assertRaisesRegex( - TypeError, "All operands should have the same sharding"): + core.ShardingTypeError, "All operands should have the same sharding"): arr3 = jax.device_put(np.arange(4.).reshape(4, 1), NamedSharding(mesh, P('x'))) f(arr1, arr3) @@ -6147,7 +6207,7 @@ def f(x, sizes=(4, 4), axis=0): f(arr) self.check_wsc_in_lowered(f.lower(arr).as_text()) - with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "split on sharded dims"): f(arr, sizes=(1, 1), axis=1) def g(x): @@ -6360,8 +6420,8 @@ def f(x): def test_intermediate_einsum(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) s = NamedSharding(mesh, P('data')) arr1 = jax.device_put(np_inp1, s) @@ -6387,9 +6447,9 @@ def test_intermediate_einsum_auto_complete_spec(self, mesh): shape1 = (8, 32, 2*16) shape2 = (8, 32, 2, 8) shape3 = (8, 32, 2, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) - np_inp3 = np.arange(math.prod(shape3)).reshape(shape3) + np_inp1 = 
np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) + np_inp3 = np.ones(math.prod(shape3)).reshape(shape3) arr1 = jax.device_put(np_inp1, s) arr2 = jax.device_put(np_inp2, s) @@ -6436,8 +6496,8 @@ def f(condition, x, y): def test_intermediate_einsum_conflict_error(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) arr1 = jax.device_put( np_inp1, NamedSharding(mesh, P(None, None, None, 'data'))) @@ -6452,7 +6512,7 @@ def f(x, y, z): # Errors out on the intermediate einsum: `bthj,bthD->bthjD` # because of a conflict with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general operation.*produces an illegally sharded result'): f(arr1, arr2, arr3) @@ -6974,6 +7034,18 @@ def f(x): self.assertArraysEqual(out, np.cumsum(np_inp)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None))) + @jax.jit + def f(x): + x = jnp.expand_dims(x, 1) + self.assertEqual(x.aval.sharding.spec, P('x', None)) + out = jnp.cumsum(x, axis=1) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + arr2 = jax.device_put(np.arange(8), P('x')) + out = f(arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + def test_device_put_under_use_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.zeros((4, 4), dtype=jnp.int32) @@ -7090,7 +7162,31 @@ def f(x): self.assertEqual(out.shape, expected_shape) self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec)) - def test_auto_axes_computation_follows_data_error(self): + @jtu.with_user_mesh((2,), ('x',)) + def test_dynamic_slice(self, mesh): + np_inp = np.arange(16., dtype=np.float32) + s = NamedSharding(mesh, P('x')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = lax.dynamic_slice_in_dim(x, jnp.array(1, dtype=np.int32), 2) + self.assertEqual(y.aval.sharding.spec, P('x')) + return y + + out = f(arr) + self.assertEqual(out.sharding, s) + + def g(x): + return jnp.sum(f(x)) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + def test_auto_axes_computation_follows_data(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) s = NamedSharding(mesh, P('x')) arr = jax.device_put(np.arange(8), s) @@ -7099,8 +7195,9 @@ def test_auto_axes_computation_follows_data_error(self): def f(x): return x * 2 - with self.assertRaisesRegex(ValueError, "Context mesh.*cannot be empty"): - auto_axes(f, out_shardings=s)(arr) + out = auto_axes(f, out_shardings=s)(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, arr * 2) def test_divisbility_aval_error(self): abstract_mesh = mesh_lib.AbstractMesh( @@ -7135,6 +7232,7 @@ def f(x): out = f(np.arange(8)) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + @jtu.thread_unsafe_test() def test_set_mesh(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) try: @@ -7164,6 +7262,174 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) + @jtu.with_user_mesh((2,), ('x',)) + def test_rng_bit_generator(self, mesh): + def f(key): + out = lax.rng_bit_generator(key, shape=(4, 8), out_sharding=P('x')) + self.assertEqual(out[0].aval.sharding.spec, P(None)) + 
self.assertEqual(out[1].aval.sharding.spec, P('x', None)) + return out + + key = np.array((1, 2, 3, 4)).astype(np.uint32) + out1 = f(key) + jit_f = jax.jit(f) + out2 = jit_f(key) + self.assertEqual(out1[0].shape, (4,)) + self.assertEqual(out1[1].shape, (4, 8)) + self.assertEqual(out2[0].sharding, NamedSharding(mesh, P())) + self.assertEqual(out2[1].sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out1[0].sharding, out2[0].sharding) + self.assertEqual(out1[1].sharding, out2[1].sharding) + self.assertArraysEqual(out1[0], out2[0]) + self.assertArraysEqual(out1[1], out2[1]) + + @jtu.with_user_mesh((2,), ('x',)) + def test_fold_in(self, mesh): + key = jax.random.key(72) + key = jax.device_put(key, NamedSharding(mesh, P())) + + @jax.jit + def f(key): + f1 = jax.random.fold_in(key, 1) + self.assertEqual(jax.random.key_data(f1).aval.sharding.spec, P(None)) + return f1 + + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + @parameterized.named_parameters( + ("bits", partial(jax.random.bits, shape=(8, 12)), P('x', 'y')), + ("uniform", partial(jax.random.uniform, shape=(8, 12)), P('x', 'y')), + ("normal", partial(jax.random.normal, shape=(8, 12)), P('x', 'y')), + ("randint", partial(jax.random.randint, shape=(8, 12), minval=0, maxval=10), + P('x', 'y')), + ("permutation_1d", partial(jax.random.permutation, x=8), P('x')), + ("permutation_2d", partial(jax.random.permutation, + x=np.arange(8 * 12).reshape(8, 12)), + P('x', 'y')), + ) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_functions(self, fun, out_spec, mesh): + @jax.jit + def f(key): + out = fun(key, out_sharding=out_spec) + self.assertEqual(out.aval.sharding.spec, out_spec) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, out_spec)) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + if out_spec == P('x', 'y'): + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + assert out_spec == P('x') + self.assertIn('<@mesh, [{"x"}]>', lowered_text) + else: + if out_spec == P('x', 'y'): + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + else: + assert out_spec == P('x') + self.assertIn( + 'mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}', + lowered_text) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_truncated_normal(self, mesh): + @jax.jit + def f(key, lower): + out = jax.random.truncated_normal(key, lower, 2., shape=(8, 12), + out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key, -1.) 
+ self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key, -1.).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + + def test_random_normal_wo_mesh_context_error(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) + return out + + key = jax.random.key(1) + with self.assertRaisesRegex( + ValueError, + 'Length of device assignment.*is not equal to the size of the mesh'): + f(key) + + def test_random_normal_wo_mesh_context(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(arr, key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return arr + out + + key = jax.random.key(1) + out = f(jax.device_put(np.arange(8 * 12.).reshape(8, 12), s), key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + def test_auto_axes_no_context_mesh(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @partial(auto_axes, axes='x', + out_shardings=NamedSharding(mesh, P('x', 'y'))) + def h(y): + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) + z = jnp.sin(y) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) + return z + + out = jax.jit(h)(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + out = h(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + def test_scan_with_random_key_inside_jit(self): + mesh = jtu.create_mesh((2,), ('x',)) + sharding = NamedSharding(mesh, P(None, 'x')) + + @jax.jit + def scan(xs): + def step(carry, x): + next_carry = jax.vmap(jax.random.fold_in)(carry, x) + next_carry = jnp.where(x % 2 == 0, carry, next_carry) + return next_carry, None + rng = jnp.broadcast_to(jax.random.key(0), xs.shape[1:]) + rng, _ = jax.lax.scan(step, rng, xs) + return rng + + xs = jnp.arange(8).reshape(2, 4) + scan(xs) + + xs = jax.device_put(xs, sharding) + scan(xs) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): @@ -7865,12 +8131,6 @@ def f(x, y): @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyTest(jtu.JaxTestCase): - # TODO(bartchr): Once JAX is released with SDY, remove setUp. 
- def setUp(self): - if not dialects.sdy: - raise unittest.SkipTest('Shardy is not available.') - super().setUp() - def test_lowering_input_output_sharding(self): mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) @@ -7939,8 +8199,8 @@ def test_array_sharding_repr_with_priority(self): sharding = sharding_impls.SdyArraySharding( mesh_shape=(('data', 4), ('model', 8), ('expert', 2)), dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_closed=True), - sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)]) + sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_open=False), + sharding_impls.SdyDimSharding(axes=['model'], is_open=True, priority=2)]) self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])") def test_array_sharding_repr_with_logical_ids(self): @@ -7953,7 +8213,7 @@ def test_array_sharding_repr_with_logical_ids(self): def test_dimension_sharding_repr(self): dim_sharding = sharding_impls.SdyDimSharding( - axes=['data', 'model'], is_closed=False, priority=2) + axes=['data', 'model'], is_open=True, priority=2) self.assertEqual(repr(dim_sharding), "SdyDimSharding({'data', 'model', ?}p2)") @@ -8020,5 +8280,44 @@ def f(x, y, static_arg0=1, static_arg1=2): self.assertArraysEqual(result, expected_result) self.assertEqual(result.sharding, NamedSharding(mesh, P(None, None, 'x'))) + def test_custom_partition_shardy_migration(self): + if jtu.is_cloud_tpu(): + raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + return x + + return ( + mesh, + lower_fn, + arg_shapes[0].sharding, + (arg_shapes[0].sharding,), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return arg_shapes[0].sharding + + def propagate_user_sharding(mesh, user_shape): + return user_shape.sharding + + @custom_partitioning + def f(x): + return x + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + propagate_user_sharding=propagate_user_sharding, + ) + + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + x = jax.device_put(np.arange(32 * 16).reshape(32, 16), + NamedSharding(mesh, P(None, 'x'))) + with self.assertRaisesRegex(ValueError, "provide sharding_rule to migrate " + "to Shardy"): + jax.jit(f)(x) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index af2d03e2945d..a07a9e271907 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3189,7 +3189,7 @@ class EagerPmapMixin: def setUp(self): super().setUp() stack = contextlib.ExitStack() - stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True)) + stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True)) stack.enter_context(jtu.ignore_warning( message="Some donated buffers were not usable", category=UserWarning)) self.addCleanup(stack.close) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 05b4c8d7c0ff..e4aeb8d66f9e 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -28,6 +28,7 @@ from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util +from jax._src.lib import jaxlib_extension_version from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -585,6 +586,56 @@ def fun(x): 
self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) + @parameterized.parameters("int2", "int4", "uint2", "uint4") + def test_subbyte_operands(self, dtype: str): + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") + def get(x): + return x + def f(x): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype=dtype), + x, + ) + return y + x = np.arange(8, dtype=dtype) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)(x) + + @parameterized.parameters("int2", "int4", "uint2", "uint4") + def test_subbyte_results(self, dtype: str): + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") + def get(): + return np.arange(8, dtype=dtype) + + def f(): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype) + ) + return y + + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)() + class PureCallbackTest(jtu.JaxTestCase): @@ -990,26 +1041,11 @@ def f(x): def test_vmap_method_raise(self): @jax.vmap def f(x): - # Setting vectorized to None disables the current default behavior of - # falling back on sequential. 
- return jax.pure_callback(np.sin, x, x, vectorized=None) + return jax.pure_callback(np.sin, x, x) with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"): f(jnp.arange(4.)) - def test_deprecated_vectorized(self): - def f(x, **kwargs): - return jax.pure_callback(np.sin, x, x, **kwargs) - - with self.assertWarnsRegex(DeprecationWarning, "The default behavior"): - jax.vmap(f)(jnp.arange(4.0)) - - with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): - f(jnp.arange(4.0), vectorized=True) - - with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): - f(jnp.arange(4.0), vectorized=False) - def test_vmap_method_expand_dims(self): def callback(x, y): self.assertTupleEqual(x.shape, (4,)) diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index e41c4329b95b..3035e68d234c 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -67,6 +67,8 @@ def testTorchToJaxFailure(self): y, client, client) @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) + @jtu.ignore_warning(message="jax.dlpack.to_dlpack was deprecated.*", + category=DeprecationWarning) def testJaxToTorch(self, shape, dtype): if not config.enable_x64.value and dtype in [ jnp.int64, diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 844892adc052..1dd6ef657561 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -21,6 +21,7 @@ import jax import jax.ad_checkpoint from jax import lax +from jax import vmap from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu @@ -381,6 +382,199 @@ def fwd( c, jnp.array([[0, 0, 1, 0], [0, 2, 3, 4]], dtype=jnp.int32) ) + @parameterized.named_parameters( + dict( + testcase_name='_batch_0_data_shard_axis_0_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=0, + input_config=0, + ), + dict( + testcase_name='_batch_0_data_shard_axis_1_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=1, + input_config=0, + ), + dict( + testcase_name='_batch_1_data_shard_axis_0_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=0, + input_config=1, + ), + dict( + testcase_name='_batch_1_data_shard_axis_1_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=1, + input_config=1, + ), + ) + def test_ragged_all_to_all_vmap( + self, + axis_name, + vmap_axis_name, + mesh_axes, + vmap_batch_axis, + data_shard_axis, + input_config, + ): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + + def get_data_sharding(axis): + if axis == 0: + return P(axis_name, None, None) + elif axis == 1: + return P(None, axis_name, None) + else: + raise ValueError("Invalid data_shard_axis") + + data_sharding = get_data_sharding(data_shard_axis) + + if input_config == 0: + operand_data = jnp.array([[[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 1]], + [[1, 2], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 
0], [1, 2]], + [[0, 0], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [2, 1]], + [[1, 1], [2, 1]]], dtype=jnp.int32) + elif input_config == 1: + operand_data = jnp.array([[[1, 2, 3], [1, 2, 3]], + [[4, 5, 6], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 2]], + [[1, 1], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 0], [0, 0]], + [[1, 2], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [1, 1]], + [[2, 1], [2, 1]]], dtype=jnp.int32) + else: + raise ValueError("Invalid input config") + + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.array([[[0, 1], [0, 1]], + [[0, 1], [0, 1]]], dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_rep=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + ) + + res = vmap( + fwd, in_axes=vmap_batch_axis, out_axes=0, axis_name=vmap_axis_name + )( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ).reshape( + (2, 2, 4) + ) + expected_res = jnp.array([[[1, 4, 0, 0], [2, 3, 5, 0]], + [[1, 4, 0, 0], [2, 3, 5, 0]]], dtype=jnp.int32) + self.assertAllClose(res, expected_res) + + def test_ragged_all_to_all_vmap_unsupported_axis_index_groups(self): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + + axis_name = 'x' + mesh_axes = dict(x=2) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + data_sharding = P(axis_name, None, None) + operand_data = jnp.zeros((2, 2, 3), dtype=jnp.int32) + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + send_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + output_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + recv_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) 
+ output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_rep=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + axis_index_groups=[[0, 1]], + ) + + with self.assertRaisesWithLiteralMatch( + NotImplementedError, 'Please open a feature request!'): + vmap(fwd, in_axes=0, out_axes=0, axis_name='b')(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes) + def test_ragged_all_to_all_errors(self): operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32) output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index b6f8b4f132bf..b4d2853abd65 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -46,7 +46,7 @@ @jtu.with_config(jax_legacy_prng_key='allow') -class LaxRandomTest(jtu.JaxTestCase): +class RandomTestBase(jtu.JaxTestCase): def _CheckCollisions(self, samples, nbits): fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev @@ -110,6 +110,11 @@ def _CheckChiSquared(self, samples, pmf, *, pval=None): def make_key(self, seed): return random.PRNGKey(seed, impl='threefry2x32') + +class CommonRandomTest(RandomTestBase): + """ + Tests of common functionality that should be run with all PRNG impls. 
+ """ @jtu.sample_product( num=(None, 6, (6,), (2, 3), (2, 3, 4)), ) @@ -164,6 +169,60 @@ def testRngRandint(self, dtype): self.assertTrue(np.all(lo <= samples)) self.assertTrue(np.all(samples < hi)) + def test_eval_shape_big_random_array(self): + def f(x): + return random.normal(self.make_key(x), (int(1e12),)) + with jax.enable_checks(False): # check_jaxpr will materialize array + jax.eval_shape(f, 0) # doesn't error + + @jtu.sample_product( + type_=["int", "np.array", "jnp.array"], + seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)], + ) + def test_prng_jit_invariance(self, seed, type_): + if type_ == "int" and seed == (1 << 64) - 1: + self.skipTest("Expected failure: Python int too large.") + if not config.enable_x64.value and seed > np.iinfo(np.int32).max: + self.skipTest("Expected failure: Python int too large.") + type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_] + args_maker = lambda: [type_(seed)] + f = lambda s: random.key_data(self.make_key(s)) + self._CompileAndCheck(f, args_maker) + + def test_prng_errors(self): + seed = np.iinfo(np.int64).max + 1 + with self.assertRaises(OverflowError): + self.make_key(seed) + with self.assertRaises(OverflowError): + jax.jit(self.make_key)(seed) + + def test_random_split_doesnt_device_put_during_tracing(self): + key = self.make_key(1).block_until_ready() + with jtu.count_device_put() as count: + jax.jit(random.split)(key) + self.assertLessEqual(count(), 1) # 1 for the argument device_put + + def test_large_prng(self): + # https://github.com/jax-ml/jax/issues/11010 + def f(): + return random.uniform( + self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) + + # TODO(jakevdp): key reuse checks for this OOM because of slice masking. + # Can we fix this? + with jax.debug_key_reuse(False): + # just lower, don't run, takes too long + jax.jit(f).lower() + + +class DistributionsTest(RandomTestBase): + """ + Tests of distribution statistics that need only be run with the default PRNG. + + We limit this to the default PRNG to avoid repeated execution of very costly + tests. So long as the input bits are valid (as tested in BasicRandomTest) then + the distribution logic tested here will apply correctly. 
+ """ @jtu.sample_product(dtype=float_dtypes) def testNormal(self, dtype): key = lambda: self.make_key(0) @@ -396,7 +455,6 @@ def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape): counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat) assert (counts <= 1).all() - def testBernoulliShape(self): key = self.make_key(0) with jax.numpy_rank_promotion('allow'): @@ -1071,39 +1129,6 @@ def testChoiceShapeIsNotSequenceError(self): with self.assertRaises(TypeError): random.choice(key, 5, 2, replace=True) - def test_eval_shape_big_random_array(self): - def f(x): - return random.normal(self.make_key(x), (int(1e12),)) - with jax.enable_checks(False): # check_jaxpr will materialize array - jax.eval_shape(f, 0) # doesn't error - - @jtu.sample_product( - type_=["int", "np.array", "jnp.array"], - seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)], - ) - def test_prng_jit_invariance(self, seed, type_): - if type_ == "int" and seed == (1 << 64) - 1: - self.skipTest("Expected failure: Python int too large.") - if not config.enable_x64.value and seed > np.iinfo(np.int32).max: - self.skipTest("Expected failure: Python int too large.") - type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_] - args_maker = lambda: [type_(seed)] - f = lambda s: random.key_data(self.make_key(s)) - self._CompileAndCheck(f, args_maker) - - def test_prng_errors(self): - seed = np.iinfo(np.int64).max + 1 - with self.assertRaises(OverflowError): - self.make_key(seed) - with self.assertRaises(OverflowError): - jax.jit(self.make_key)(seed) - - def test_random_split_doesnt_device_put_during_tracing(self): - key = self.make_key(1).block_until_ready() - with jtu.count_device_put() as count: - jax.jit(random.split)(key) - self.assertLessEqual(count(), 1) # 1 for the argument device_put - @jtu.sample_product(dtype=int_dtypes + uint_dtypes) def test_randint_bounds(self, dtype): min = np.iinfo(dtype).min @@ -1131,18 +1156,6 @@ def test_randint_out_of_range(self): self.assertGreater((r == 0).sum(), 0) self.assertGreater((r == 255).sum(), 0) - def test_large_prng(self): - # https://github.com/jax-ml/jax/issues/11010 - def f(): - return random.uniform( - self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) - - # TODO(jakevdp): key reuse checks for this OOM because of slice masking. - # Can we fix this? 
- with jax.debug_key_reuse(False): - # just lower, don't run, takes too long - jax.jit(f).lower() - @jtu.sample_product(shape=[(3, 4)], logits_shape_base=[(3, 4), (3, 1), (1, 4)], axis=[-3, -2, -1, 0, 1, 2]) @@ -1461,7 +1474,7 @@ def _double_threefry_fold_in(key, data): tag='fry2') @jtu.with_config(jax_default_prng_impl='threefry2x32') -class LaxRandomWithCustomPRNGTest(LaxRandomTest): +class CustomPRNGTest(CommonRandomTest): def make_key(self, seed): return prng_internal.random_seed(seed, impl=double_threefry_prng_impl) @@ -1522,7 +1535,7 @@ def test_grad_of_prng_key(self): @jtu.with_config(jax_default_prng_impl='rbg') -class LaxRandomWithRBGPRNGTest(LaxRandomTest): +class RBGPRNGTest(CommonRandomTest): def make_key(self, seed): return random.PRNGKey(seed, impl='rbg') @@ -1634,7 +1647,7 @@ def test_randint_out_of_range(self): @jtu.with_config(jax_default_prng_impl='unsafe_rbg') -class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest): +class UnsafeRBGPRNGTest(RBGPRNGTest): def make_key(self, seed): return random.PRNGKey(seed, impl="unsafe_rbg") @@ -1648,24 +1661,6 @@ def test_vmap_split_mapped_key_values(self): self.assertArraysEqual(random.key_data(vmapped_keys), random.key_data(ref_keys)) -def _sampler_unimplemented_with_custom_prng(*args, **kwargs): - raise SkipTest('sampler only implemented for default RNG') - -for test_prefix in [ - 'testPoisson', - 'testPoissonBatched', - 'testPoissonShape', - 'testPoissonZeros', -]: - for attr in dir(LaxRandomTest): - if attr.startswith(test_prefix): - setattr(LaxRandomWithCustomPRNGTest, attr, - _sampler_unimplemented_with_custom_prng) - setattr(LaxRandomWithRBGPRNGTest, attr, - _sampler_unimplemented_with_custom_prng) - setattr(LaxRandomWithUnsafeRBGPRNGTest, attr, - _sampler_unimplemented_with_custom_prng) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/random_test.py b/tests/random_test.py index a51e387dca76..d75f3a9c5e2e 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -338,7 +338,7 @@ def testRandomDistributionValues(self, case, make_key): Any refactoring of random distributions that leads to non-trivial differences in this test should follow the procedure outlined at - https://jax.readthedocs.io/en/latest/api_compatibility.html#numerics-and-randomness + https://docs.jax.dev/en/latest/api_compatibility.html#numerics-and-randomness This includes: * Announcing the change in the CHANGELOG.md @@ -602,10 +602,26 @@ def assertKeysEqual(self, key1, key2): self.assertEqual(key1.dtype, key2.dtype) self.assertArraysEqual(random.key_data(key1), random.key_data(key2)) + def make_keys(self, *shape, seed=28): + seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) + return jax.vmap(random.key)(seeds).reshape(shape) + def test_construction(self): key = random.key(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) + def test_numpy_construction(self): + key = random.wrap_key_data(np.array([42, 173], dtype=np.uint32), + impl='threefry2x32') + self.assertIsInstance(key, prng_internal.PRNGKeyArray) + self.assertIsInstance(key._base_array, jax.Array) + self.assertEqual(key._base_array.device, jax.devices()[0]) + self.assertEqual(key.device, jax.devices()[0]) + + def test_device_property(self): + key = random.key(42) + self.assertEqual(key.device, key._base_array.device) + def test_random_clone(self): # Here we test value semantics and compatibility with jit/vmap # key reuse semantics are tested in key_reuse_test.py @@ -632,10 +648,6 @@ def test_construction_upgrade_flag(self): key 
= random.PRNGKey(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) - def make_keys(self, *shape, seed=28): - seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) - return jax.vmap(random.key)(seeds).reshape(shape) - def test_key_as_seed(self): key = self.make_keys() with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"): @@ -657,6 +669,11 @@ def test_non_integer_seed(self): with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"): random.key(seed) + def test_nbytes_property(self): + key = self.make_keys() + self.assertEqual(key.nbytes, key._base_array.nbytes) + self.assertEqual(key.nbytes, key.itemsize * key.size) + def test_dtype_property(self): k1, k2 = self.make_keys(), self.make_keys() self.assertEqual(k1.dtype, k2.dtype) diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 564b4a9a1f9e..140beb3c6e71 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -18,7 +18,6 @@ from absl.testing import absltest import jax -from jax._src import mesh from jax._src import test_util as jtu from jax.experimental import roofline import jax.lax as lax @@ -29,6 +28,8 @@ jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) +_VERY_LARGE_NUMBER = 512 * 1024 + def create_inputs( *shardings: P, @@ -465,11 +466,7 @@ def collective_matmul(a, b): ) def test_unary_ops(self, f, dtype): data = jnp.zeros((3, 8), dtype=dtype) - out, result = roofline.roofline( - f, - in_specs=(P()), - out_specs=P(), - )(data) + out, result = roofline.roofline(f)(data) with self.subTest("flops"): self.assertEqual(result.unfused_flops, 3 * 8) with self.subTest("hbm_bytes"): @@ -495,12 +492,9 @@ def test_binary_ops(self): lambda a, b: jnp.minimum(a, b), lambda a, b: jnp.maximum(a, b), ]: - out, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + out, result = roofline.roofline(f)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) self.assertEqual( result.unfused_hbm_bytes, @@ -515,12 +509,7 @@ def test_broadcast(self): (2.0, jnp.ones((3, 8))), (jnp.zeros((3, 8)), 2.0), ]: - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(left, right) + _, result = roofline.roofline(lambda a, b: a + b)(left, right) self.assertEqual(result.unfused_flops, 3 * 8) def test_nested(self): @@ -531,27 +520,21 @@ def g(x): return g(x) + g(y) - _, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int)) + _, result = roofline.roofline(f)( + jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * (11 * 4)) def test_no_mesh(self): - _, result = roofline.roofline( - lambda a, b: a + b, - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) def test_no_specs(self): - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) 
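# Unfused cost accounting exercised here: an elementwise add over a (3, 8)
# array costs one flop per output element, i.e. 3 * 8 flops, and its unfused
# HBM traffic counts every operand plus the output (compare test_dot_general
# below, whose expected bytes are bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5)).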
self.assertEqual(result.unfused_flops, 3 * 8) def test_no_mesh_and_no_specs(self): @@ -561,62 +544,77 @@ def test_no_mesh_and_no_specs(self): self.assertEqual(result.unfused_flops, 3 * 8) def test_dot_general(self): - _, result = roofline.roofline( - lambda a, b: a @ b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int)) + _, result = roofline.roofline(lambda a, b: a @ b)( + jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int) + ) self.assertEqual(result.unfused_flops, 2 * 3 * 7 * 5) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5) ) - def get_conv_output_dim(self, i, k, pad_low, pad_high, stride): + def get_conv_output_dim(self, i, k, pad_low, pad_high, stride) -> int: return jnp.floor((i - k + pad_low + pad_high) / stride) + 1 - @jtu.parameterized.named_parameters( - dict( - testcase_name="simple", - window_strides=(1, 1), - padding=((0, 0), (0, 0)), - ), - dict( - testcase_name="padding", - window_strides=(1, 1), - padding=((1, 2), (3, 4)), - ), - dict( - testcase_name="window_strides", - window_strides=(2, 2), - padding=((0, 0), (0, 0)), - ), - dict( - testcase_name="window_strides_and_padding", - window_strides=(3, 3), - padding=((1, 2), (3, 4)), - ), + def get_conv_num_output_channels( + self, batch_group_count: int, feature_group_count: int + ) -> int: + if batch_group_count > 1: + return batch_group_count + elif feature_group_count > 1: + return feature_group_count + else: + return 1 + + @jtu.parameterized.product( + window_strides=[(1, 1), (2, 2)], + padding=[((0, 0), (0, 0)), ((1, 2), (3, 4))], + # batch must be divisible by batch_group_count, so we only include factors + # of batch_group_count. + batch=[6, 12], + batch_group_count=[1, 3], + # num_input_channels must be divisible by feature_group_count, so we only + # include factors of feature_group_count. 
+ num_input_channels=[6, 12], + feature_group_count=[1, 3], ) def test_conv_general_dilated_unfused_hbm_bytes( - self, window_strides: Sequence[int, int], padding: Sequence[int, int] + self, + window_strides: Sequence[int, int], + padding: Sequence[int, int], + batch: int, + batch_group_count: int, + num_input_channels: int, + feature_group_count: int, ): + if batch_group_count > 1 and feature_group_count > 1: + self.skipTest( + "batch_group_count and feature_group_count cannot both be > 1" + ) + + num_output_channels = self.get_conv_num_output_channels( + batch_group_count, feature_group_count + ) + + num_input_features = int(num_input_channels / feature_group_count) iw, ih = 100, 200 kw, kh = 7, 7 - input_data = jnp.zeros((1, 1, iw, ih), dtype=int) - kernel_data = jnp.ones((1, 1, kw, kh), dtype=int) + input_data = jnp.zeros((batch, num_input_channels, iw, ih), dtype=int) + kernel_data = jnp.ones( + (num_output_channels, num_input_features, kw, kh), dtype=int + ) conv = lambda a, b: lax.conv_general_dilated( - lhs=a, rhs=b, window_strides=window_strides, padding=padding + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + batch_group_count=batch_group_count, + feature_group_count=feature_group_count, ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) - expected_input_size = 1 * 1 * iw * ih - expected_kernel_size = 1 * 1 * kw * kh + expected_input_size = batch * num_input_channels * iw * ih + expected_kernel_size = num_output_channels * num_input_features * kw * kh ow = self.get_conv_output_dim( iw, kw, padding[0][0], padding[0][1], window_strides[0] @@ -624,12 +622,14 @@ def test_conv_general_dilated_unfused_hbm_bytes( oh = self.get_conv_output_dim( ih, kh, padding[1][0], padding[1][1], window_strides[1] ) - expected_output_size = 1 * 1 * ow * oh + expected_output_shape = jnp.array( + (batch / batch_group_count, num_output_channels, ow, oh) + ) + expected_output_size = jnp.prod((expected_output_shape)) # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) - # TODO(b/394648206): add subtest for unfused_flops once they are supported. self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) @jtu.parameterized.named_parameters( @@ -642,24 +642,22 @@ def test_conv_general_dilated_unfused_hbm_bytes( padding="SAME_LOWER", ), ) - def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str): - input_data = jnp.zeros((1, 1, 10, 20), dtype=int) + def test_conv_general_dilated_padding_string( + self, padding: str + ): + input_data = jnp.zeros((1, 1, 3, 3), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( lhs=a, rhs=b, window_strides=(1, 1), padding=padding ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) - expected_input_size = 1 * 1 * 10 * 20 + # Test hbm bytes. + expected_input_size = 1 * 1 * 3 * 3 expected_kernel_size = 1 * 1 * 3 * 3 # Because of same{_lower} padding, output shape should equal to input shape. - # This may not be true for other `{feature, batch}`_group_count`s.c + # This may not be true for other `{feature, batch}`_group_count`s. 
expected_output_size = expected_input_size # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( @@ -667,19 +665,28 @@ def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: st ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) - def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): + # Test flops. + # For spatial_valid_position_counts, we have 3x3 output with the following + # flops for each element: + # 4 6 4 + # 6 9 6 + # 4 6 4 + # Non_spatial_dims_factor = 1 because `{batch, feature}_group_count` are + # both equal to 1. + # Each FMA is 2 flops. + self.assertEqual( + result.unfused_flops, + 2 * (4 + 6 + 4 + 6 + 9 + 6 + 4 + 6 + 4), + ) + + def test_conv_general_dilated_padding_string_valid(self): input_data = jnp.zeros((1, 1, 10, 20), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( lhs=a, rhs=b, window_strides=(1, 1), padding="VALID" ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) expected_input_size = 1 * 1 * 10 * 20 expected_kernel_size = 1 * 1 * 3 * 3 @@ -690,19 +697,92 @@ def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): * self.get_conv_output_dim(10, 3, 0, 0, 1) * self.get_conv_output_dim(20, 3, 0, 0, 1) ) + # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) + # Output shape is [1x1x8x18] and each output element requires (3x3) FMAs, + # and each FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * expected_output_size * 3 * 3 + ) + + + @jtu.parameterized.named_parameters( + dict( + testcase_name="padding", + input_spatial_dim=1, + window_strides=[1], + padding=[(_VERY_LARGE_NUMBER - 1, _VERY_LARGE_NUMBER - 1)], + lhs_dilation=[1], + ), + dict( + testcase_name="input", + input_spatial_dim=_VERY_LARGE_NUMBER, + window_strides=[_VERY_LARGE_NUMBER - 1], + padding=[(0, 0)], + lhs_dilation=[_VERY_LARGE_NUMBER], + ), + ) + def test_conv_general_dilated_flops_very_large( + self, input_spatial_dim, window_strides, padding, lhs_dilation + ): + input_data = jnp.zeros((1, 1, input_spatial_dim), dtype=int) + kernel_data = jnp.ones((1, 1, _VERY_LARGE_NUMBER), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + self.assertEqual(result.unfused_flops, 2 * _VERY_LARGE_NUMBER) + + def test_conv_general_dilated_flops_feature_group_count(self): + feature_group_count = 120 + input_data = jnp.zeros((1, feature_group_count, 10, 20), dtype=int) + kernel_data = jnp.ones((feature_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + feature_group_count=feature_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [1x120x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. 
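# Worked arithmetic for the expected constant, assuming the VALID output-size
# rule out = in - k + 1 encoded by get_conv_output_dim with zero padding and
# unit stride: a 10x20 input with a 3x3 kernel yields an
# (10 - 3 + 1) x (20 - 3 + 1) = 8 x 18 output, each output element costs
# 3 * 3 FMAs at 2 flops per FMA, and the 120 groups contribute a factor of
# 120, giving 2 * 120 * 8 * 18 * 3 * 3.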
+ self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + + def test_conv_general_dilated_flops_batch_group_count(self): + batch_group_count = 120 + input_data = jnp.zeros((batch_group_count, 1, 10, 20), dtype=int) + kernel_data = jnp.ones((batch_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + batch_group_count=batch_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [120x1x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + def test_reduce_sum_no_axis(self): - _, result = roofline.roofline( - lambda x: jnp.sum(x), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x))(jnp.zeros((11, 4))) self.assertEqual(result.unfused_flops, 11 * 4 - 1) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (11 * 4 + 1) @@ -715,12 +795,9 @@ def test_reduce_sum_with_axis(self): ([0, 1], 11 * 4 - 1, 11 * 4 + 1), ([], 0, 11 * 4 + 11 * 4), ]: - _, result = roofline.roofline( - lambda x: jnp.sum(x, axis=axis), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x, axis=axis))( + jnp.zeros((11, 4)) + ) self.assertEqual(result.unfused_flops, expected_flops) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * expected_memory diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 141839a19a08..b53ffcd5b977 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -174,7 +174,7 @@ def update_global_scale(config, new_global_scale): config.global_scale = new_global_scale return config -def generate_nvfp4_quantized_tensors(dot_config, output_type): +def generate_nvfp4_quantized_tensors(dot_config, output_type, enable_grad_clip=False): k1, k2 = jax.random.split(jax.random.key(0), 2) a_shape, b_shape, dimension_numbers = dot_config @@ -194,6 +194,11 @@ def generate_nvfp4_quantized_tensors(dot_config, output_type): amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32) + # To emulate calibrated amax + amax_sf = 0.9 if enable_grad_clip else 1.0 + amax_a *= amax_sf + amax_b *= amax_sf + # Update global scales data_max = jnp.finfo(block_scale_configs_nvfp4[0].data_type).max.astype( jnp.float32 @@ -508,6 +513,68 @@ def fn(a): self.assertArraysAllClose(out_q, out_q_ref, rtol=1e-5, atol=1e-5) self.assertArraysAllClose(scale, scale_ref, rtol=1e-5, atol=1e-5) + @jtu.sample_product( + enable_grad_clip=[True, False], + configs=[ + # a_shape, b_shape, dimension_numbers + ((1, 128, 128), (1, 128, 128), (([2], [2]), ([0], [0]))), + ((30, 64), (100, 64), (([1], [1]), ([], []))), + ] + ) + @jtu.run_on_devices("cuda") + def test_nvfp4_gradient_clip(self, enable_grad_clip, configs): + output_type = jnp.float32 + (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = ( + generate_nvfp4_quantized_tensors(configs, output_type, enable_grad_clip) + ) + a_gs = block_scale_configs[0].global_scale + b_gs = block_scale_configs[1].global_scale + dimension_numbers = configs[2] + + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=block_scale_configs + ) + + def fwd(a, b, use_normalized=False): + y = scaled_dot_general( + a, b, 
dimension_numbers, + preferred_element_type=output_type + ) + return jnp.sum(y) + + j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1])) + _, (x_grad, w_grad) = j_train(a_raw, b_raw) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + # Use a large value to ensure no clipping + threshold_a = prev_amax_a if enable_grad_clip else 1e9 + threshold_b = prev_amax_b if enable_grad_clip else 1e9 + + # Verify gradients are clipped to 0 where |input| > global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) > threshold_a, x_grad, 0), + jnp.zeros_like(x_grad), + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) > threshold_b, w_grad, 0), + jnp.zeros_like(w_grad), + ) + if enable_grad_clip: + # Verify gradients are preserved where |input| <= global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) <= prev_amax_a, x_grad, 0), + x_grad, + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) <= prev_amax_b, w_grad, 0), + w_grad, + ) + @jtu.sample_product( configs=[ # a_shape, b_shape, dimension_numbers, is_training @@ -567,6 +634,16 @@ def fwd(a, b, is_ref=False, use_normalized=False): out_ref, _ = j_train_fwd_ref(a_dq, b_dq) self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + def _grad_clip(amax, x, grad): + return jnp.where(jnp.abs(x) <= amax, grad, 0) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + x_grad_ref = _grad_clip(prev_amax_a, a_raw, x_grad_ref) + w_grad_ref = _grad_clip(prev_amax_b, b_raw, w_grad_ref) self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) else: diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f8d5a11e842f..1dd59daee873 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -44,6 +44,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu from jax._src import tree_util +from jax.custom_derivatives import SymbolicZero import jax.numpy as jnp from jax.experimental.custom_partitioning import custom_partitioning @@ -666,11 +667,6 @@ def test_check_rep_false_doesnt_hit_rep_rules(self): def f(): prim.bind() - with self.assertRaises(NotImplementedError): - f() - with self.assertRaises(NotImplementedError): - jax.jit(f)() - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False) def f2(): prim.bind() @@ -685,6 +681,19 @@ def f3(): f3() jax.jit(f3)() + def test_multiple_result_primitive_with_none_sharding(self): + # https://github.com/jax-ml/jax/issues/27673 + xs = jnp.arange(20).reshape(2, 10) + mesh = jtu.create_mesh((2,), ("i",)) + y = shard_map( + lambda x: jnp.split(x.squeeze(), 2), + mesh=mesh, + in_specs=(None,), + out_specs=P("i"), + )(xs) + expected = jnp.repeat(xs, 2, axis=0).reshape(2, 2, 10) + self.assertArraysEqual(y, expected) + def test_vmap_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -967,11 +976,12 @@ def f(x, y, z): def body(c, _): c, *cs = c return (*cs, c), None + x = lax.pvary(x, ('x', 'y')) + y = lax.pvary(y, 'y') out, _ = jax.lax.scan(body, (x, y, z), None, length=3) return [jnp.expand_dims(a, 0) for a in out] x = jnp.arange(4) - # doesn't 
crash, because out_spec assumes no replication (and there is none) shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P(('x', 'y')))(x, x, x) @@ -1003,6 +1013,59 @@ def body(c, _): shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + def test_while_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + + def f(x, y, z): + x, y, z = x.sum(), y.sum(), z.sum() + def cond(c): + i, *_ = c + return i < 5 + def body(c): + i, c, *cs = c + return (i + 1, *cs, c) + x = lax.pvary(x, ('x', 'y')) + y = lax.pvary(y, 'y') + _, *out = jax.lax.while_loop(cond, body, (0, x, y, z)) + return [jnp.expand_dims(a, 0) for a in out] + + x = jnp.arange(4) + + # doesn't crash, because out_spec assumes no replication (and there is none) + shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P(('x', 'y')))(x, x, x) + + # does crash, because output incorrectly promises replication + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P('x'))(x, x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P('y'))(x, x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P(None))(x, x, x) + + def g(x, y, z): + x, y, z = x.sum(), y.sum(), z.sum() + def cond(c): + i, *_ = c + return i < 1 + def body(c): + i, *cs = c + return (i + 1, *cs) + _, *out = jax.lax.while_loop(cond, body, (0, x, y, z)) + return [jnp.expand_dims(a, 0) for a in out] + + # doesn't crash, because everything matches + shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x) + + # does crash, because the second guy is wrong + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + def test_cond_rep_rule(self): mesh = jtu.create_mesh((2, 2,), ('x', 'y')) x = jnp.arange(4) @@ -1015,17 +1078,19 @@ def false_fun(x, y): return jax.lax.cond(True, true_fn, false_fun, x, y) shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) def f(x, y): def true_fn(x, y): - return x + return lax.pvary(x, 'y') def false_fun(x, y): - return y + return lax.pvary(y, 'x') return jax.lax.cond(True, true_fn, false_fun, x, y) shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) @@ -1037,6 +1102,7 @@ def false_fun(x, y): return jax.lax.cond(jnp.any(x > 0), true_fn, false_fun, x, y) shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) @@ -1048,8 +1114,7 @@ def false_fun(x, y): return jax.lax.cond(jnp.any(y > 0), true_fn, false_fun, x, y) shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) - with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, 
mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) # https://github.com/jax-ml/jax/issues/24418 def f(a): @@ -1456,38 +1521,6 @@ def f(x): y = shard_f(x) self.assertEqual(x_spec, y.sharding.spec) - @parameterized.parameters([True, False]) - def test_rewrite_process_custom_vjp_call_match_less_replicated(self, jit): - @jax.custom_vjp - def foo(x, y): - del y - return 2. * x - - def foo_fwd(x, y): - return foo(x, y), y - - def foo_bwd(y, _): - return y, None # diff! x_bar less replicated than primal/tangent - - foo.defvjp(foo_fwd, foo_bwd) - - mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x, y: foo(x, y) * y, mesh, - in_specs=(P(), P('x')), out_specs=P('x')) - if jit: - g = jax.jit(g) - - x = jnp.arange(4.) - y = jnp.arange(4 * 4.) - - z = g(x, y) - self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False) - - z_, x_bar = jax.value_and_grad(lambda x, y: g(x, y).sum())(x, y) - self.assertAllClose(z.sum(), z_, check_dtypes=False) - self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0), - check_dtypes=False) - @parameterized.parameters([True, False]) def test_rewrite_custom_vjp_call_jaxpr(self, jit): @jax.custom_vjp @@ -1555,7 +1588,7 @@ def f(x): jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e2, = e.params['jaxpr'].eqns - self.assertEqual(str(e2.primitive), 'psum2') + self.assertEqual(str(e2.primitive), 'psum_invariant') self.assertEqual(e2.params['axes'], ('x',)) def test_fanin_psum_transposes_to_fanout(self): @@ -1568,7 +1601,7 @@ def f(x): jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.])) e, = jaxpr.jaxpr.eqns e1, = e.params['jaxpr'].eqns - self.assertEqual(str(e1.primitive), 'pbroadcast') + self.assertEqual(str(e1.primitive), 'pvary') def test_psum_with_implicit_fanout_self_transposes(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1580,8 +1613,8 @@ def f(x): jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e1, e2 = e.params['jaxpr'].eqns - self.assertEqual(str(e1.primitive), 'psum2') - self.assertEqual(str(e2.primitive), 'pbroadcast') + self.assertEqual(str(e1.primitive), 'psum_invariant') + self.assertEqual(str(e2.primitive), 'pvary') def test_transpose_float0(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1632,6 +1665,18 @@ def example(x, y): dx, dy = example(x, y) self.assertEqual(dy.dtype, jax.dtypes.float0) + def test_pvary(self): + mesh = jtu.create_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x')) + def f(x): + y = jax.lax.pvary(x, 'x') + self.assertEqual(y.aval.vma, {'x'}) + return y + + f(jnp.arange(8.)) + jax.grad(lambda x: f(x).sum())(jnp.arange(8.)) + def test_rewrite_binops(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1642,7 +1687,7 @@ def f(x, y): jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e = e.params['jaxpr'].eqns[0] - self.assertEqual(e.primitive.name, 'pbroadcast') + self.assertEqual(e.primitive.name, 'pvary') self.assertEqual(e.params['axes'], ('x',)) def test_rewrite_scan(self): @@ -1650,16 +1695,17 @@ def test_rewrite_scan(self): @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): - x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None, - length=2) + def g(x, _): + return lax.pvary(jax.lax.psum(x, 'x'), 'x'), None + x, _ = jax.lax.scan(g, x, None, length=2) return x jaxpr = jax.make_jaxpr(f)(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e, = e.params['jaxpr'].eqns e1, e2 = e.params['jaxpr'].eqns - 
self.assertEqual(e1.primitive.name, 'psum2') - self.assertEqual(e2.primitive.name, 'pbroadcast') + self.assertEqual(e1.primitive.name, 'psum_invariant') + self.assertEqual(e2.primitive.name, 'pvary') def test_check_rep_false_grads(self): if jtu.is_device_tpu(5, 'e'): @@ -2185,17 +2231,12 @@ def g(x): return x * x def h(x): - return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + return shard_map(h, mesh, in_specs=P('i', None), out_specs=P('i', None), + check_rep=False, auto=frozenset({'j'}))(x) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2412,16 +2453,22 @@ def f(x): f(keys) # doesn't crash + def test_grad_remat(self): + mesh = jtu.create_mesh((1, 1), ('i', 'j')) + args = [jnp.arange(6.).reshape(3, 2), jnp.arange(6.).reshape(3, 2, 1)] + + @partial(jax.remat, policy=lambda *_, **__: True) + @partial(shard_map, mesh=mesh, in_specs=(P('j'), P('i')), + out_specs=P('i', 'j')) + def f(x, y): + return jnp.dot(x, y) + jax.grad(lambda x, y: f(x, y).sum())(*args) + def test_vmap_grad_shmap_spmd_axis_name_residuals(self): # https://github.com/jax-ml/jax/pull/21032 mesh = jtu.create_mesh((4, 2), ('i', 'j')) - @partial( - shard_map, - mesh=mesh, - in_specs=P('j'), - out_specs=P('j'), - ) + @partial(shard_map, mesh=mesh, in_specs=P('j'), out_specs=P('j')) def f(x): return jnp.sin(x) @@ -2434,12 +2481,7 @@ def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial(jax.remat, policy=lambda *_, **__: True) - @partial( - shard_map, - mesh=mesh, - in_specs=P('j'), - out_specs=P('j'), - ) + @partial(shard_map, mesh=mesh, in_specs=P('j'), out_specs=P('j')) def f(x): return jnp.sin(x) @@ -2454,8 +2496,8 @@ def test_grad_shmap_residuals_axis_names_in_mesh_order(self): @partial( shard_map, mesh=mesh, - in_specs=P('j'), - out_specs=P('j'), + in_specs=P(('i', 'k')), + out_specs=P(('i', 'k')), ) def f(x): return jnp.sin(x) @@ -2465,22 +2507,45 @@ def f(x): ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs) if config.use_shardy_partitioner.value: self.assertIn( - 'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text() + 'out_shardings=[<@mesh, [{"i", "k"}]>]', ir.as_text() ) else: self.assertIn( - "{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text() + "{jax.result_info = \"[('i', 'k')]\"}", ir.as_text() ) + def test_dynamic_slice_transpose(self): + mesh = jtu.create_mesh((2,), ('x',)) + arr = np.arange(16., dtype=np.float32) + + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) + def f(x): + return lax.dynamic_slice_in_dim(x, jnp.array(1, dtype=np.int32), 2) + + f(arr) # doesn't crash + jax.jit(f)(arr) # doesn't crash + + def g(x): + return jnp.sum(f(x)) + + jax.grad(g)(arr) # doesn't crash + jax.jit(jax.grad(g))(arr) # doesn't crash + + @parameterized.parameters([P()], [P('x')], [P(('x', 'y'))]) + def test_print_inside_shard_map(self, specs): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + x = jnp.arange(4.) 
+ + @partial(shard_map, mesh=mesh, in_specs=specs, out_specs=specs) + def f(x): + print(x) + return 2 * x + f(x) # doesn't crash + def test_vmap_spmd_axis_name_error(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) - @partial( - shard_map, - mesh=mesh, - in_specs=P('i'), - out_specs=P('i'), - ) + @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jnp.sin(x) @@ -2488,13 +2553,8 @@ def f(x): with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"): jax.vmap(f, spmd_axis_name='i')(xs) - @partial( - shard_map, - mesh=mesh, - in_specs=P('j'), - out_specs=P(('i', 'j')), - check_rep=False, - ) + @partial(shard_map, mesh=mesh, in_specs=P('j'), out_specs=P(('i', 'j')), + check_rep=False) def g(x): return jnp.sin(x) @@ -2650,13 +2710,7 @@ def f(x, reduce_along, use_jit): @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec) def g(x): result = lax.psum(x, axis_name=reduce_along) - def check_rep(result): - self.assertEqual( - jax.experimental.shard_map.get_replication(result), - set(reduce_along)) - return result - result = check_rep(result) - result = jax.vmap(check_rep)(result) + self.assertEqual(result.aval.vma, x.aval.vma - set(reduce_along)) return result if use_jit: return jax.jit(g)(x) @@ -2673,18 +2727,213 @@ def test_pmin(self): mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(8., dtype=np.float32) y = shard_map(lambda x: jax.lax.pmin(x, 'i'), - mesh=mesh, in_specs=P('i'), out_specs=P() - )(x) # don't crash + mesh=mesh, in_specs=P('i'), out_specs=P())(x) # don't crash self.assertArraysEqual(y, np.array([0, 1], dtype=np.float32)) def test_pmax(self): mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(8., dtype=np.float32) y = shard_map(lambda x: jax.lax.pmax(x, 'i'), - mesh=mesh, in_specs=P('i'), out_specs=P() - )(x) # don't crash + mesh=mesh, in_specs=P('i'), out_specs=P())(x) # don't crash self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) + def test_pmax_vma_in_types(self): + mesh = jtu.create_mesh((4,), ('i',)) + x = jnp.arange(8., dtype=np.float32) + f = jax.jit(shard_map(lambda x: jax.lax.pmax(x, 'i'), mesh=mesh, + in_specs=P(), out_specs=P())) + jaxpr = f.trace(x).jaxpr + self.assertIn("pvary[axes=('i',)", str(jaxpr)) + f(x) # doesn't crash + + def test_mul_with_vma_in_types(self): + mesh = jtu.create_mesh((2,), ('x',)) + x = np.arange(8.) + + def f(x): + self.assertEqual(x.aval.vma, frozenset({'x'})) + out = x * 2 + self.assertEqual(out.aval.vma, frozenset({'x'})) + return out + + f = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))) + jaxpr = f.trace(x).jaxpr + self.assertIn("pvary[axes=('x',)", str(jaxpr)) + out = f(x) + self.assertArraysEqual(out, x * 2) + + # TODO(yashkatariya): Enable grad test which requires adding psum_p support. + # def g(x, y): + # return jnp.sum(f(x, y)) + # print(jax.jit(jax.grad(g)).trace(x, y).jaxpr) + + def test_all_gather_with_vma_in_types(self): + mesh = jtu.create_mesh((2,), ('x',)) + x = np.arange(8.) 
+ + def f(x): + self.assertEqual(x.aval.vma, frozenset()) + out = jax.lax.all_gather(x, 'x') + self.assertEqual(out.aval.vma, frozenset({'x'})) + return out + + f = jax.jit(shard_map(f, mesh=mesh, in_specs=P(), out_specs=P('x'))) + jaxpr = f.trace(x).jaxpr + self.assertIn("pvary[axes=('x',)", str(jaxpr)) + + f(x) # doesn't crash + + def test_rep_none_canonicalization(self): + # https://github.com/jax-ml/jax/issues/26621 + if config.use_shardy_partitioner.value: + self.skipTest('complex values fail under shardy') + N = 8 + xs = jnp.ones((8, N), dtype=jnp.int32) + variables = jax.random.normal(jax.random.key(1), (N, N), jnp.complex64) + mesh = jtu.create_mesh((2,), ('i',)) + in_specs = (P(), P("i"),) + out_specs = P("i") + + variables = jax.lax.with_sharding_constraint(variables, NamedSharding(mesh, P())) + xs = jax.lax.with_sharding_constraint(xs, NamedSharding(mesh, P('i'))) + + def fun(v, xs): + # Commenting this single line below makes everything work + v = jax.scipy.linalg.expm(v) + v = v.sum() + return v * xs.sum(axis=-1).astype(v.dtype) + + res = fun(variables, xs) + fun_shard_map = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs) + res = fun_shard_map(variables, xs) # don't crash + + def test_rep_none_canonicalization_again(self): + # https://github.com/jax-ml/jax/issues/24762 + mesh = jtu.create_mesh((2,), ('i',)) + def f(x): + return jnp.insert(x, 0, 0)[None] + f = shard_map(f, mesh, P('i'), P('i')) + f(jnp.zeros(100)) # don't crash + + def test_custom_jvp_symbolic_zeros(self): + # https://github.com/jax-ml/jax/issues/26763 + mesh = jtu.create_mesh((4,), ('i',)) + @jax.custom_jvp + def f(a: jax.Array, b: jax.Array) -> jax.Array: + return a + b + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + a, b = primals + a_dot, b_dot = tangents + y = f(a, b) + y_dot = jnp.zeros_like(y) + if not isinstance(a_dot, SymbolicZero): + y_dot += a_dot + if not isinstance(b_dot, SymbolicZero): + y_dot += b_dot + return y, y_dot + x = jax.random.normal(jax.random.key(0), (jax.device_count(), 20)) + A = jax.random.normal(jax.random.key(1), (jax.device_count(), 20)) + + g = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i')) + jax.jvp(lambda x: g(x, A), (x,), (x,)) # don't crash + + def test_cond_pvary_errors(self): + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return y + return jax.lax.cond(True, true_fn, false_fun, x, y) + x = jnp.arange(4.) + with self.assertRaisesRegex( + TypeError, + r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + + def test_cond_pvary_errors_pytree(self): + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + + def f(x, y): + def true_fn(x, y): + return x, y + def false_fun(x, y): + return y, x + return jax.lax.cond(True, true_fn, false_fun, x, y) + x = jnp.arange(4.) + with self.assertRaisesRegex( + TypeError, + r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + + def test_scan_pvary_errors(self): + mesh = jtu.create_mesh((1, 1), ('i', 'j')) + x = jnp.arange(3.) + y = jnp.arange(3.) 
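+ # The scan body swaps the two carry components, so a carry whose halves have mismatched vma is rejected; g below shows the fix using jax.lax.pvary.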
+ + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) + def f(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () # swap the carry + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + + with self.assertRaisesRegex( + TypeError, + r"This might be fixed by applying `jax.lax.pvary\(..., \('i',\)\)` to" + r' the initial'): + f(x, y) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) + def g(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () + y = jax.lax.pvary(y, 'i') # fix the issue + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + + g(x, y) # doesn't crash + + def test_scan_pvary_errors2(self): + mesh = jtu.create_mesh((1, 1), ('i', 'j')) + x = jnp.arange(3.) + y = jnp.arange(3.) + z = jnp.arange(3.) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j'))) + def f(x, y, z): + def body(carry, _): + c1, c2, c3 = carry + return (c3, c1, c2), () # swap the carry + + # x = jax.lax.pvary(x, 'j') + # y = jax.lax.pvary(y, ('i', 'j')) + carry, _ = jax.lax.scan(body, (x, y, z), (), length=2) + return carry + + with self.assertRaisesRegex( + TypeError, + r'This might be fixed by:\n \* applying `jax.lax.pvary\(...,' + r" \('j',\)\)`"): + f(x, y, z) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j'))) + def g(x, y, z): + def body(carry, _): + c1, c2, c3 = carry + return (c3, c1, c2), () # swap the carry + + x = jax.lax.pvary(x, 'j') # fix the issue + y = jax.lax.pvary(y, ('i', 'j')) + carry, _ = jax.lax.scan(body, (x, y, z), (), length=2) + return carry + + g(x, y, z) # doesn't crash + class FunSpec(NamedTuple): name: str @@ -3042,7 +3291,7 @@ def g(*args): else: slices = map(jnp.stack, zip(*expected_slices)) expected = jax.tree.unflatten(treedef, slices) - tol = 1e-2 if jtu.test_device_matches(['tpu']) else None + tol = 1e-2 if jtu.test_device_matches(['gpu', 'tpu']) else None self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol) @jtu.pytest_mark_if_available('multiaccelerator') @@ -3083,6 +3332,7 @@ def f(x): infer_sharding_from_operands=infer_sharding_from_operands, partition=partition, propagate_user_sharding=propagate_user_sharding, + sharding_rule='i -> i', ) @jax.jit diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index e839bacbe5fc..1224717570d1 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -36,7 +36,7 @@ from jax.experimental.sparse import util as sparse_util import jax.numpy as jnp import jax.random -from jax.util import split_list +from jax._src.util import split_list import numpy as np jax.config.parse_flags_with_absl() diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py deleted file mode 100644 index 9ecf30eb6229..000000000000 --- a/tests/sparse_nm_test.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math - -import numpy as np -from absl.testing import absltest -from absl.testing import parameterized - -import jax -import jax.numpy as jnp -from jax import dtypes -from jax._src import config -from jax._src import test_util as jtu -from jax.experimental.sparse import nm - -jax.config.parse_flags_with_absl() - - -class SpmmTest(jtu.JaxTestCase): - def setUp(self): - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Only works on GPU") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") - super().setUp() - - # ----- Test different input shapes - @parameterized.product( - tile_m=(32, 128), - tile_n=(32, 128), - tile_k=(32, 128), - batch=(None, 5), - sparse_idx=(0, 1), - ) - @jtu.run_on_devices("gpu") - def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): - # Build keyword arguments - kwargs = { - "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), - "sparse_operand_idx": sparse_idx, - } - if batch: - kwargs["dimension_numbers"] = (((2,), (2,)), ((0,), (0,))) - - # Build input data - batch_dims = (batch,) if batch else tuple() - lhs = ( - (np.arange((batch or 1) * tile_m * tile_k) % 11) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_m, tile_k)) - ) - rhs = ( - (np.arange((batch or 1) * tile_n * tile_k) % 13) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_n, tile_k)) - ) - - # Build sparsity mask and metadata - sp = [lhs, rhs][sparse_idx] - mask = np.tile([True, False], math.prod(sp.shape) // 2).reshape(sp.shape) - sparse = sp[mask].reshape(sp.shape[:-1] + (sp.shape[-1] // 2,)) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - if sparse_idx == 0: - dot_sparse = nm.nm_spmm(sparse, rhs, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", (lhs * mask), rhs) - else: - dot_sparse = nm.nm_spmm(lhs, sparse, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", lhs, (rhs * mask)) - - # Verify the result - jtu.check_eq(dot_sparse, dot_dense.astype(dtypes.bfloat16)) - - # ----- Test different input types - @parameterized.product( - lhs_type=[jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16], - rhs_type=[jnp.bfloat16], - output_type=[jnp.bfloat16, jnp.float32], - ) - @jtu.run_on_devices("gpu") - def test_types(self, lhs_type, rhs_type, output_type): - tile_m, tile_n, tile_k = 64, 32, 128 - - # Build input data - lhs = ( - (np.arange(tile_m * tile_k) % 17) - .astype(lhs_type) - .reshape((tile_m, tile_k)) - ) - rhs = ( - (np.arange(tile_k * tile_n) % 19) - .astype(rhs_type) - .reshape((tile_k, tile_n)) - ) - - # Build sparsity mask and metadata - mask = np.tile([True, False], tile_m * tile_k // 2).reshape(lhs.shape) - sparse = lhs[mask].reshape(tile_m, tile_k // 2) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - dot_sparse = nm.nm_spmm(sparse, rhs, meta, output_dtype=output_type) - dot_dense = (lhs * mask) @ rhs - - # Verify the result - jtu.check_close(dot_sparse, dot_dense.astype(output_type), rtol=0.01) - - # ----- Test validation - @jtu.run_on_devices("gpu") - def test_validate_nm_pack(self): - with self.assertRaisesRegex(TypeError, "Mask should be bool"): - nm.nm_pack(jnp.zeros(16, jnp.int8)) - with self.assertRaisesRegex( - TypeError, "Inner dimension size should be divisible by 16" - ): - nm.nm_pack(jnp.array([False] * 8)) - - @jtu.run_on_devices("gpu") - def test_validate_nm_spmm(self): - batch, tile_m, tile_n, tile_k = 2, 64, 32, 128 - lhs = jnp.zeros((batch, tile_m, tile_k // 2), 
dtype=jnp.bfloat16) - rhs = jnp.zeros((batch, tile_k, tile_n), dtype=jnp.bfloat16) - meta = jnp.zeros((batch, tile_m, tile_k // 16), dtype=jnp.uint16) - - if config.enable_x64.value: - with self.assertRaisesRegex(TypeError, "Unsupported lhs input type"): - nm.nm_spmm(jnp.zeros(lhs.shape, dtype=jnp.int64), rhs, meta) - with self.assertRaisesRegex(TypeError, "Unsupported rhs input type"): - nm.nm_spmm(lhs, jnp.zeros(rhs.shape, dtype=jnp.int64), meta) - with self.assertRaisesRegex(TypeError, "Unsupported output type"): - nm.nm_spmm(lhs, rhs, meta, output_dtype=jnp.int64) - - # Check dimension numbers - nm_spmm_with_dnums = lambda c, b: nm.nm_spmm( - lhs, rhs, meta, dimension_numbers=(c, b) - ) - with self.assertRaisesRegex( - TypeError, "Only single contracting dimension is supported" - ): - nm_spmm_with_dnums(((0, 2), (0, 1)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for lhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((2,), (0,))) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for rhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((0,), (1,))) - with self.assertRaisesRegex( - TypeError, "Only single non-contracting dimension is supported" - ): - nm_spmm_with_dnums(((2,), (1,)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Batch dimension sizes do not match" - ): - nm.nm_spmm( - lhs, - rhs.reshape(1, tile_k, tile_n * batch), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - # Check metadata - nm_spmm_with_meta = lambda m: nm.nm_spmm( - lhs, rhs, m, dimension_numbers=(((2,), (1,)), ((0,), (0,))) - ) - with self.assertRaisesRegex(TypeError, "Metadata must be uint16"): - nm_spmm_with_meta(jnp.zeros(meta.shape, dtype=jnp.uint8)) - with self.assertRaisesRegex( - TypeError, "Metadata shape must match the operand shape" - ): - nm_spmm_with_meta(meta.reshape(1, batch * tile_m, tile_k // 16)) - with self.assertRaisesRegex( - TypeError, - "Metadata must be exactly 8 times less than the contracting dimension" - " for 2:4 structured sparsity", - ): - nm_spmm_with_meta(jnp.repeat(meta, 2, axis=-1)) - with self.assertRaisesRegex( - TypeError, "Contracting dimension must be the minor one" - ): - nm.nm_spmm(lhs, rhs, meta, dimension_numbers=(((1,), (1,)), ((0,), (0,)))) - with self.assertRaisesRegex( - TypeError, "Contracting dimension sizes should have 2:4 ratio" - ): - nm.nm_spmm( - lhs, - jnp.repeat(rhs, 2, axis=1), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index eb8d70be1f05..219875d4b7d0 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -38,7 +38,7 @@ from jax._src import test_util as jtu from jax.interpreters import mlir import jax.numpy as jnp -from jax.util import split_list +from jax._src.util import split_list import numpy as np import scipy.sparse diff --git a/tests/state_test.py b/tests/state_test.py index 60a7d8bc9f8a..d9bf66eb3f50 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -28,6 +28,7 @@ from jax import lax from jax._src import core from jax._src import config +from jax._src import dtypes from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu @@ -36,13 +37,9 @@ import jax.numpy as jnp from jax._src.lax.control_flow import for_loop -try: - import hypothesis as hp - import hypothesis.extra.numpy as hnp - import 
hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.extra.numpy as hnp +import hypothesis.strategies as hps from jax._src.state.discharge import (run_state, run_state_reference, discharge_state) @@ -477,27 +474,17 @@ def g(r, rdot): op=[ lambda x_ref, indexer: [x_ref[indexer]], lambda x_ref, indexer: [ - ref_swap(x_ref, indexer, - jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, - *indexer)])], + ref_swap(x_ref, indexer, jnp.ones_like(x_ref[indexer]))], lambda x_ref, indexer: ( - ref_addupdate(x_ref, indexer, - jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, - *indexer)]) - or [jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)]]) + ref_addupdate(x_ref, indexer, jnp.ones_like(x_ref[indexer])) + or [jnp.ones_like(x_ref[indexer])]), ], ) def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims, idx_bdims, out_bdim, op): - - float_ = (jnp.dtype('float64') if config.enable_x64.value else - jnp.dtype('float32')) - int_ = (jnp.dtype('int64') if config.enable_x64.value else - jnp.dtype('int32')) + intx = dtypes.canonicalize_dtype(jnp.int64) + floatx = dtypes.canonicalize_dtype(jnp.float64) axis_size = 7 - out_shape = tuple(d for d, b in zip(ref_shape, indexed_dims) if not b) - if any(indexed_dims): - out_shape = (*idx_shape, *out_shape) def maybe_insert(shape, idx): if idx is None: @@ -505,13 +492,13 @@ def maybe_insert(shape, idx): return tuple_insert(shape, idx, axis_size) batched_ref_shape = maybe_insert(ref_shape, ref_bdim) - ref_aval = shaped_array_ref(ref_shape, float_) - bat_ref_aval = shaped_array_ref(batched_ref_shape, float_) + ref_aval = shaped_array_ref(ref_shape, floatx) + bat_ref_aval = shaped_array_ref(batched_ref_shape, floatx) - idx_avals = [core.ShapedArray(idx_shape, int_) + idx_avals = [core.ShapedArray(idx_shape, intx) for _ in idx_bdims] bat_idx_avals = [ - core.ShapedArray(maybe_insert(idx_shape, idx_bdim), int_) + core.ShapedArray(maybe_insert(idx_shape, idx_bdim), intx) for idx_bdim in idx_bdims] def f(x_ref, *idxs): @@ -531,6 +518,7 @@ def f(x_ref, *idxs): wrap_init(f_batched, 1 + len(bat_idx_avals)), [bat_ref_aval, *bat_idx_avals]) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs) + # vmap-of-discharge stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) @@ -792,7 +780,7 @@ def body(i, st): lax.fori_loop(0, 5, body, init_val=()) return a_ref[...], b_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, True]) # Effects on y_ref were discharged away but not the effects on x_ref @@ -806,294 +794,303 @@ def body(i, st): self.assertLen(jaxpr.outvars, 3) -if CAN_USE_HYPOTHESIS: - - def index_arrays(size, idx_shape): - valid_idx = hps.integers(min_value=-size, max_value=size - 1) - return hnp.arrays(np.int32, idx_shape, elements=valid_idx) - - Shape = tuple[int, ...] - - class IndexParam(NamedTuple): - ref_aval: shaped_array_ref - ref_shape: Shape - indexed_dims: list[bool] - idx_avals: tuple[core.ShapedArray, ...] 
- idx_shape: Shape - slice_aval: core.ShapedArray - slice_shape: Shape - - @hps.composite - def index_params(draw): - ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape') - indexed_dims = draw(hps.lists(hps.booleans(), - min_size=len(ref_shape), - max_size=len(ref_shape))) - idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5)) - if any(indexed_dims): - sliced_shape = (s for s, b in zip(ref_shape, indexed_dims) if not b) +def index_arrays(size, idx_shape): + valid_idx = hps.integers(min_value=-size, max_value=size - 1) + return hnp.arrays(np.int32, idx_shape, elements=valid_idx) + +Shape = tuple[int, ...] + +class IndexParam(NamedTuple): + ref_aval: shaped_array_ref + ref_shape: Shape + indexed_dims: list[bool] + idx_avals: tuple[core.ShapedArray, ...] + idx_shape: Shape + slice_aval: core.ShapedArray + slice_shape: Shape + +@hps.composite +def index_params(draw): + ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape') + indexed_dims = draw(hps.lists(hps.booleans(), + min_size=len(ref_shape), + max_size=len(ref_shape))) + idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5)) + if not any(indexed_dims): + slice_shape = ref_shape + else: + sliced_shape = tuple(s for s, b in zip(ref_shape, indexed_dims) if not b) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(indexed_dims)[0]) == 1) + ) + if not int_indexers_contiguous: slice_shape = (*idx_shape, *sliced_shape) else: - slice_shape = ref_shape - ref_aval = shaped_array_ref(ref_shape, np.float32) - idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in - range(sum(indexed_dims))) - slice_aval = core.ShapedArray(slice_shape, np.float32) - return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape, - slice_aval, slice_shape) - - class VmappableIndexParam(NamedTuple): - index_param: IndexParam - ref_bdim: int | None - non_slice_idx_bdims: tuple[int | None, ...] - slice_bdim: int - bat_ref_aval: shaped_array_ref - bat_ref_shape: Shape - bat_non_slice_idx_avals: tuple[core.ShapedArray, ...] - bat_non_slice_idx_shapes: tuple[Shape, ...] - bat_slice_aval: core.ShapedArray - bat_slice_shape: Shape - - def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None, - val: Any) -> tuple[Any, ...]: - if idx is None: - return t - return tuple_insert(t, idx, val) - - @hps.composite - def vmappable_index_params(draw, *, op_type: str): - axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size') - index_param: IndexParam = draw(index_params()) - non_slice_idx_bdims = tuple( - draw(hps.one_of( - hps.none(), - hps.integers(min_value=0, max_value=len(index_param.idx_shape)))) - for b in index_param.indexed_dims if b) - bat_non_slice_idx_shapes = tuple( - maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size) - for idx_bdim in non_slice_idx_bdims) - if op_type == "swap": - # In a swap, the ref *must* be batched - ref_bdim = draw(hps.integers(min_value=0, - max_value=len(index_param.ref_shape))) - if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims): - # If it's a swap, if indices are batched, val must be batched. 
- slice_bdim = draw(hps.integers( - min_value=0, max_value=len(index_param.slice_shape))) - else: - slice_bdim = draw(hps.one_of(hps.none(), hps.integers( - min_value=0, max_value=len(index_param.slice_shape)))) - elif op_type == "get": - # In a get, the indices must be batched or ref is batched - if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims): - ref_bdim = draw(hps.integers(min_value=0, - max_value=len(index_param.ref_shape))) - else: - ref_bdim = draw(hps.one_of(hps.none(), - hps.integers(min_value=0, max_value=len(index_param.ref_shape)))) + insert_pos = indexed_dims.index(True) + slice_shape = ( + *sliced_shape[:insert_pos], + *idx_shape, + *sliced_shape[insert_pos:], + ) + ref_aval = shaped_array_ref(ref_shape, np.float32) + idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in + range(sum(indexed_dims))) + slice_aval = core.ShapedArray(slice_shape, np.float32) + return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape, + slice_aval, slice_shape) + +class VmappableIndexParam(NamedTuple): + index_param: IndexParam + ref_bdim: int | None + non_slice_idx_bdims: tuple[int | None, ...] + slice_bdim: int + bat_ref_aval: shaped_array_ref + bat_ref_shape: Shape + bat_non_slice_idx_avals: tuple[core.ShapedArray, ...] + bat_non_slice_idx_shapes: tuple[Shape, ...] + bat_slice_aval: core.ShapedArray + bat_slice_shape: Shape + +def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None, + val: Any) -> tuple[Any, ...]: + if idx is None: + return t + return tuple_insert(t, idx, val) + +@hps.composite +def vmappable_index_params(draw, *, op_type: str): + axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size') + index_param: IndexParam = draw(index_params()) + non_slice_idx_bdims = tuple( + draw(hps.one_of( + hps.none(), + hps.integers(min_value=0, max_value=len(index_param.idx_shape)))) + for b in index_param.indexed_dims if b) + bat_non_slice_idx_shapes = tuple( + maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size) + for idx_bdim in non_slice_idx_bdims) + if op_type == "swap": + # In a swap, the ref *must* be batched + ref_bdim = draw(hps.integers(min_value=0, + max_value=len(index_param.ref_shape))) + if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims): + # If it's a swap, if indices are batched, val must be batched. 
slice_bdim = draw(hps.integers( min_value=0, max_value=len(index_param.slice_shape))) + else: + slice_bdim = draw(hps.one_of(hps.none(), hps.integers( + min_value=0, max_value=len(index_param.slice_shape)))) + elif op_type == "get": + # In a get, the indices must be batched or ref is batched + if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims): + ref_bdim = draw(hps.integers(min_value=0, + max_value=len(index_param.ref_shape))) + else: + ref_bdim = draw(hps.one_of(hps.none(), + hps.integers(min_value=0, max_value=len(index_param.ref_shape)))) + slice_bdim = draw(hps.integers( + min_value=0, max_value=len(index_param.slice_shape))) + + bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size) + bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32) + bat_non_slice_idx_avals = tuple( + core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes) + bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size) + bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32) + return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims, + slice_bdim, bat_ref_aval, bat_ref_shape, + bat_non_slice_idx_avals, bat_non_slice_idx_shapes, + bat_slice_aval, bat_slice_shape) + +class GetVmapParams(NamedTuple): + vmap_index_param: VmappableIndexParam + bat_ref: np.ndarray + bat_idxs: tuple[np.ndarray, ...] + +@hps.composite +def get_vmap_params(draw): + vmap_index_param: VmappableIndexParam = draw( + vmappable_index_params(op_type="get")) + bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) + bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) + bat_idxs = tuple( + draw(index_arrays(size, next(bat_idx_shapes_))) + for size, indexed in zip( + vmap_index_param.index_param.ref_shape, + vmap_index_param.index_param.indexed_dims) + if indexed) + assert next(bat_idx_shapes_, None) is None + return GetVmapParams(vmap_index_param, bat_ref, bat_idxs) + +class SetVmapParams(NamedTuple): + vmap_index_param: VmappableIndexParam + bat_ref: np.ndarray + bat_val: np.ndarray + bat_idxs: tuple[np.ndarray, ...] 
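+# SetVmapParams bundles the batched ref, the batched value to write, and any batched indices consumed by the swap/addupdate vmap tests.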
+ +@hps.composite +def set_vmap_params(draw): + vmap_index_param: VmappableIndexParam = draw(vmappable_index_params( + op_type="swap")) + bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) + bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) + bat_idxs = tuple( + draw(index_arrays(size, next(bat_idx_shapes_))) + for size, indexed in zip( + vmap_index_param.index_param.ref_shape, + vmap_index_param.index_param.indexed_dims) + if indexed) + assert next(bat_idx_shapes_, None) is None + bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape)) + return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs) + +Indexer = tuple[Union[int, slice, np.ndarray]] + +def _unpack_idx(idx: Indexer + ) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]: + indexed_dims = [type(i) != slice for i in idx] + non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b] + return non_slice_idx, indexed_dims + +def _pack_idx(non_slice_idx: Sequence[int | np.ndarray], + indexed_dims: Sequence[bool]) -> Indexer: + idx_ = iter(non_slice_idx) + idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims) + assert next(idx_, None) is None + return idx + +@jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe +class StateHypothesisTest(jtu.JaxTestCase): + + @hp.given(get_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_get_vmap(self, get_vmap_param: GetVmapParams): + + indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + return [ref_get(ref, idx)] + ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = get_vmap_param.vmap_index_param.ref_bdim + idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims + out_bdim = get_vmap_param.vmap_index_param.slice_bdim + non_slice_idx = get_vmap_param.bat_idxs + idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals + ref = get_vmap_param.bat_ref + + f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, *idx_bdims), + out_axes=[out_bdim, ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, *non_slice_idx) - bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size) - bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32) - bat_non_slice_idx_avals = tuple( - core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes) - bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size) - bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32) - return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims, - slice_bdim, bat_ref_aval, bat_ref_shape, - 
bat_non_slice_idx_avals, bat_non_slice_idx_shapes, - bat_slice_aval, bat_slice_shape) - - class GetVmapParams(NamedTuple): - vmap_index_param: VmappableIndexParam - bat_ref: np.ndarray - bat_idxs: tuple[np.ndarray, ...] - - @hps.composite - def get_vmap_params(draw): - vmap_index_param: VmappableIndexParam = draw( - vmappable_index_params(op_type="get")) - bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) - bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) - bat_idxs = tuple( - draw(index_arrays(size, next(bat_idx_shapes_))) - for size, indexed in zip( - vmap_index_param.index_param.ref_shape, - vmap_index_param.index_param.indexed_dims) - if indexed) - assert next(bat_idx_shapes_, None) is None - return GetVmapParams(vmap_index_param, bat_ref, bat_idxs) - - class SetVmapParams(NamedTuple): - vmap_index_param: VmappableIndexParam - bat_ref: np.ndarray - bat_val: np.ndarray - bat_idxs: tuple[np.ndarray, ...] - - @hps.composite - def set_vmap_params(draw): - vmap_index_param: VmappableIndexParam = draw(vmappable_index_params( - op_type="swap")) - bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) - bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) - bat_idxs = tuple( - draw(index_arrays(size, next(bat_idx_shapes_))) - for size, indexed in zip( - vmap_index_param.index_param.ref_shape, - vmap_index_param.index_param.indexed_dims) - if indexed) - assert next(bat_idx_shapes_, None) is None - bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape)) - return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs) - - Indexer = tuple[Union[int, slice, np.ndarray]] - - def _unpack_idx(idx: Indexer - ) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]: - indexed_dims = [type(i) != slice for i in idx] - non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b] - return non_slice_idx, indexed_dims - - def _pack_idx(non_slice_idx: Sequence[int | np.ndarray], - indexed_dims: Sequence[bool]) -> Indexer: - idx_ = iter(non_slice_idx) - idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims) - assert next(idx_, None) is None - return idx - - @jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe - class StateHypothesisTest(jtu.JaxTestCase): - - @hp.given(get_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_get_vmap(self, get_vmap_param: GetVmapParams): - - indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - return [ref_get(ref, idx)] - ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = get_vmap_param.vmap_index_param.ref_bdim - idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims - out_bdim = get_vmap_param.vmap_index_param.slice_bdim - non_slice_idx = get_vmap_param.bat_idxs - idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals - ref = get_vmap_param.bat_ref - - f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, 
ref, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, *idx_bdims), - out_axes=[out_bdim, ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) - - - @hp.given(set_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_set_vmap(self, set_vmap_param: SetVmapParams): - if jtu.test_device_matches(["gpu"]): - self.skipTest("Scatter is nondeterministic on GPU") - indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, val, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - ref_set(ref, idx, val) - return [] - ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = set_vmap_param.vmap_index_param.ref_bdim - idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims - non_slice_idx = set_vmap_param.bat_idxs - idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals - ref = set_vmap_param.bat_ref - val = set_vmap_param.bat_val - bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval - val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval - val_bdim = set_vmap_param.vmap_index_param.slice_bdim - - f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) - - - @hp.given(set_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_addupdate_vmap(self, set_vmap_param: SetVmapParams): - - indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, val, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - ref_addupdate(ref, idx, val) - return [] - ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = set_vmap_param.vmap_index_param.ref_bdim - idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims - non_slice_idx = set_vmap_param.bat_idxs - idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals - ref = set_vmap_param.bat_ref - val = set_vmap_param.bat_val - bat_val_aval = 
set_vmap_param.vmap_index_param.bat_slice_aval - val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval - val_bdim = set_vmap_param.vmap_index_param.slice_bdim - - f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) + + + @hp.given(set_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_set_vmap(self, set_vmap_param: SetVmapParams): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Scatter is nondeterministic on GPU") + indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, val, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + ref_set(ref, idx, val) + return [] + ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = set_vmap_param.vmap_index_param.ref_bdim + idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims + non_slice_idx = set_vmap_param.bat_idxs + idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals + ref = set_vmap_param.bat_ref + val = set_vmap_param.bat_val + bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval + val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval + val_bdim = set_vmap_param.vmap_index_param.slice_bdim + + f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[]) + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) + + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) + + + @hp.given(set_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_addupdate_vmap(self, set_vmap_param: SetVmapParams): + + indexed_dims = 
set_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, val, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + ref_addupdate(ref, idx, val) + return [] + ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = set_vmap_param.vmap_index_param.ref_bdim + idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims + non_slice_idx = set_vmap_param.bat_idxs + idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals + ref = set_vmap_param.bat_ref + val = set_vmap_param.bat_val + bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval + val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval + val_bdim = set_vmap_param.vmap_index_param.slice_bdim + + f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[]) + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) + + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) class StateControlFlowTest(jtu.JaxTestCase): @@ -1139,7 +1136,7 @@ def false_fun(): y_ref[...] = 2. lax.cond(pred, true_fun, false_fun) return x_ref[...], y_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True]) # Effects on y_ref were discharged away but not the effects on x_ref @@ -1631,216 +1628,218 @@ def _body(ref): jtu.check_grads(f, (0.5,), order=3) -if CAN_USE_HYPOTHESIS: - - class FuncSpec(NamedTuple): - fun: Callable[..., Any] - name: str - min_rank: int = 0 - max_rank: int = 4 - min_dim: int = 0 - max_dim: int = 4 - - def call(self, *args): - return run_state(self.fun)(*args) - - def ref(self, *args): - return run_state_reference(self.fun)(*args) - - def sin_stateful(refs): - x_ref, y_ref = refs - y_ref[...] = jnp.sin(x_ref[...]) - - sin_spec = FuncSpec(sin_stateful, "sin") - - def cos_stateful(refs): - x_ref, y_ref = refs - y_ref[...] = jnp.cos(x_ref[...]) - - cos_spec = FuncSpec(cos_stateful, "cos") - - def mul2_stateful(refs): - x_ref, y_ref = refs - y_ref[...] = x_ref[...] - y_ref[...] = y_ref[...] + x_ref[...] - - mul2_spec = FuncSpec(mul2_stateful, "mul2") - - def mul2_stateful_with_constant(refs): +class FuncSpec(NamedTuple): + fun: Callable[..., Any] + name: str + min_rank: int = 0 + max_rank: int = 4 + min_dim: int = 0 + max_dim: int = 4 + + def call(self, *args): + return run_state(self.fun)(*args) + + def ref(self, *args): + return run_state_reference(self.fun)(*args) + +def sin_stateful(refs): + x_ref, y_ref = refs + y_ref[...] 
= jnp.sin(x_ref[...]) + +sin_spec = FuncSpec(sin_stateful, "sin") + +def cos_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = jnp.cos(x_ref[...]) + +cos_spec = FuncSpec(cos_stateful, "cos") + +def mul2_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = x_ref[...] + y_ref[...] = y_ref[...] + x_ref[...] + +mul2_spec = FuncSpec(mul2_stateful, "mul2") + +def mul2_stateful_with_constant(refs): + x_ref, y_ref = refs + y_ref[...] = (2. * np.ones(x_ref.shape, x_ref.dtype)) * x_ref[...] + +mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c") + +def crazy_identity_stateful(refs): + x_ref, y_ref = refs + x = x_ref[...] + x_ref[...] = (x + x) / 2 + y_ref[...] = x_ref[...] + y = y_ref[...] + y_ref[...] = (y + y) / 2 + +crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id") + +def func_spec(depth: int = 4): + raw_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec, + mul2_constant_spec, crazy_identity_spec]) + if depth > 0: + return hps.one_of([raw_specs, nest_spec(depth - 1), add_spec(depth - 1), + compose_spec(depth - 1)]) + return raw_specs + +@hps.composite +def compose_spec(draw, depth): + f1 = draw(func_spec(depth)) + f2 = draw(func_spec(depth)) + def wrapped_impl(*args): + f1.fun(*args) + f2.fun(*args) + return FuncSpec(wrapped_impl, + f"({f2.name} . {f1.name})", + min_rank=max(f1.min_rank, f2.min_rank), + max_rank=min(f1.max_rank, f2.max_rank), + min_dim=max(f1.min_dim, f2.min_dim), + max_dim=min(f1.max_dim, f2.max_dim)) + +@hps.composite +def nest_spec(draw, depth): + f = draw(func_spec(depth)) + def wrapped_impl(refs): x_ref, y_ref = refs - y_ref[...] = (2. * np.ones(x_ref.shape, x_ref.dtype)) * x_ref[...] - - mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c") - - def crazy_identity_stateful(refs): + x, y = x_ref[...], y_ref[...] + x, y = run_state(f.fun)((x, y)) + x_ref[...], y_ref[...] = x, y + return FuncSpec(wrapped_impl, + f"nest({f.name})", + min_rank=f.min_rank, + max_rank=f.max_rank, + min_dim=f.min_dim, + max_dim=f.max_dim) + + +@hps.composite +def add_spec(draw, depth): + f1 = draw(func_spec(depth)) + f2 = draw(func_spec(depth)) + def wrapped_impl(refs): x_ref, y_ref = refs - x = x_ref[...] - x_ref[...] = (x + x) / 2 - y_ref[...] = x_ref[...] - y = y_ref[...] - y_ref[...] = (y + y) / 2 - - crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id") - - def func_spec(depth: int = 4): - raw_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec, - mul2_constant_spec, crazy_identity_spec]) - if depth > 0: - return hps.one_of([raw_specs, nest_spec(depth - 1), add_spec(depth - 1), - compose_spec(depth - 1)]) - return raw_specs - - @hps.composite - def compose_spec(draw, depth): - f1 = draw(func_spec(depth)) - f2 = draw(func_spec(depth)) - def wrapped_impl(*args): - f1.fun(*args) - f2.fun(*args) - return FuncSpec(wrapped_impl, - f"({f2.name} . {f1.name})", - min_rank=max(f1.min_rank, f2.min_rank), - max_rank=min(f1.max_rank, f2.max_rank), - min_dim=max(f1.min_dim, f2.min_dim), - max_dim=min(f1.max_dim, f2.max_dim)) - - @hps.composite - def nest_spec(draw, depth): - f = draw(func_spec(depth)) - def wrapped_impl(refs): - x_ref, y_ref = refs - x, y = x_ref[...], y_ref[...] - x, y = run_state(f.fun)((x, y)) - x_ref[...], y_ref[...] 
= x, y - return FuncSpec(wrapped_impl, - f"nest({f.name})", - min_rank=f.min_rank, - max_rank=f.max_rank, - min_dim=f.min_dim, - max_dim=f.max_dim) - - - @hps.composite - def add_spec(draw, depth): - f1 = draw(func_spec(depth)) - f2 = draw(func_spec(depth)) - def wrapped_impl(refs): - x_ref, y_ref = refs - x, y = x_ref[...], y_ref[...] - x1, y1 = run_state(f1.fun)((x, y)) - x2, y2 = run_state(f2.fun)((x, y)) - x_ref[...], y_ref[...] = x1 + x2, y1 + y2 - return FuncSpec(wrapped_impl, - f"({f2.name} + {f1.name})", - min_rank=max(f1.min_rank, f2.min_rank), - max_rank=min(f1.max_rank, f2.max_rank), - min_dim=max(f1.min_dim, f2.min_dim), - max_dim=min(f1.max_dim, f2.max_dim)) - - @jtu.thread_unsafe_test_class() # because of hypothesis - class RunStateHypothesisTest(jtu.JaxTestCase): - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_jvp(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - k1, k2 = random.split(random.PRNGKey(0)) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - t = random.normal(k2, x.shape) - y, y_t = jax.jvp(impl, (x,), (t,)) - y_ref, y_ref_t = jax.jvp(ref, (x,), (t,)) - self.assertAllClose(y, y_ref) - self.assertAllClose(y_t, y_ref_t) - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_linearize(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - - k1, k2 = random.split(random.PRNGKey(0)) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - y, impl_lin = jax.linearize(impl, x) - y_ref, ref_lin = jax.linearize(ref, x) - self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2) - t = random.normal(k2, x.shape) - self.assertAllClose(impl_lin(t), ref_lin(t), atol=1e-2, rtol=1e-2) - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_vjp(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - - key, k1, k2 = random.split(random.PRNGKey(0), 3) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - - # First order - y, impl_lin = jax.linearize(impl, x) - y_ref, ref_lin = jax.linearize(ref, x) - self.assertAllClose(y, y_ref) - t = random.normal(k2, x.shape) - self.assertAllClose(impl_lin(t), ref_lin(t)) - - y, impl_vjp = jax.vjp(impl, x) - y_ref, ref_vjp = jax.vjp(ref, x) - self.assertAllClose(y, y_ref) - t = random.normal(jax.random.clone(k2), x.shape) - y2 = random.normal(jax.random.clone(k1), y.shape) - self.assertAllClose(impl_vjp(t), ref_vjp(t)) - - # Second order - key, k1, k2 = random.split(key, 3) - t2 = random.normal(k2, t.shape) - - (x,), impl_lin2 = jax.linearize(impl_vjp, t2) - (x_ref,), ref_lin2 = jax.linearize(ref_vjp, t2) - self.assertAllClose(x, 
x_ref) - y2 = random.normal(k1, y.shape) - self.assertAllClose(impl_lin2(y2), ref_lin2(y2)) - - (x,), impl_vjp2 = jax.vjp(impl_vjp, t2) - (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2) - self.assertAllClose(x, x_ref) - y2 = random.normal(jax.random.clone(k1), y.shape) - self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,))) + x, y = x_ref[...], y_ref[...] + x1, y1 = run_state(f1.fun)((x, y)) + x2, y2 = run_state(f2.fun)((x, y)) + x_ref[...], y_ref[...] = x1 + x2, y1 + y2 + return FuncSpec(wrapped_impl, + f"({f2.name} + {f1.name})", + min_rank=max(f1.min_rank, f2.min_rank), + max_rank=min(f1.max_rank, f2.max_rank), + min_dim=max(f1.min_dim, f2.min_dim), + max_dim=min(f1.max_dim, f2.max_dim)) + +@jtu.thread_unsafe_test_class() # because of hypothesis +class RunStateHypothesisTest(jtu.JaxTestCase): + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_jvp(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + k1, k2 = random.split(random.PRNGKey(0)) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + t = random.normal(k2, x.shape) + y, y_t = jax.jvp(impl, (x,), (t,)) + y_ref, y_ref_t = jax.jvp(ref, (x,), (t,)) + self.assertAllClose(y, y_ref) + self.assertAllClose(y_t, y_ref_t) + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_linearize(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + + k1, k2 = random.split(random.PRNGKey(0)) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + y, impl_lin = jax.linearize(impl, x) + y_ref, ref_lin = jax.linearize(ref, x) + self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2) + t = random.normal(k2, x.shape) + self.assertAllClose(impl_lin(t), ref_lin(t), atol=1e-2, rtol=1e-2) + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_vjp(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + + key, k1, k2 = random.split(random.PRNGKey(0), 3) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + + # First order + y, impl_lin = jax.linearize(impl, x) + y_ref, ref_lin = jax.linearize(ref, x) + self.assertAllClose(y, y_ref) + t = random.normal(k2, x.shape) + self.assertAllClose(impl_lin(t), ref_lin(t)) + + y, impl_vjp = jax.vjp(impl, x) + y_ref, ref_vjp = jax.vjp(ref, x) + self.assertAllClose(y, y_ref) + t = random.normal(jax.random.clone(k2), x.shape) + y2 = random.normal(jax.random.clone(k1), y.shape) + self.assertAllClose(impl_vjp(t), ref_vjp(t)) + + if jtu.SKIP_SLOW_TESTS.value: + # Skip second order tests if JAX_SKIP_SLOW_TESTS=true + return + + # Second order + key, k1, k2 = random.split(key, 3) + t2 = random.normal(k2, t.shape) + + (x,), 
impl_lin2 = jax.linearize(impl_vjp, t2) + (x_ref,), ref_lin2 = jax.linearize(ref_vjp, t2) + self.assertAllClose(x, x_ref) + y2 = random.normal(k1, y.shape) + self.assertAllClose(impl_lin2(y2), ref_lin2(y2)) + + (x,), impl_vjp2 = jax.vjp(impl_vjp, t2) + (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2) + self.assertAllClose(x, x_ref) + y2 = random.normal(jax.random.clone(k1), y.shape) + self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,))) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py new file mode 100644 index 000000000000..6f51651af687 --- /dev/null +++ b/tests/unary_ops_accuracy_test.py @@ -0,0 +1,400 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test for result accuracy for unary ops.""" + +from typing import Any, Callable, NamedTuple, Union + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import lax +from jax._src.lib import xla_extension +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() + + +class TolerancePair(NamedTuple): + high: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT + low: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT + + +def make_unary_test_cases( + testcase_name: str, + op: Callable[..., Any], + x: np.ndarray, + tp: TolerancePair = None, + min_error_val: float = 0.0, +): + """Creates a single test case.""" + return [{ + "testcase_name": testcase_name, + "op": op, + "x": x, + "tp": tp, + "min_error_val": min_error_val, + }] + + +UNARY_OPS = { + "exp": make_unary_test_cases( + "exp", + lax.exp, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "exp2": make_unary_test_cases( + "exp2", + lax.exp2, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "expm1": make_unary_test_cases( + "expm1", + lax.expm1, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "log": make_unary_test_cases( + "log", + lax.log, + np.linspace(1e28, 2e28, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=2**-20, ulps=0), + ), + 1.0, + ), + "log1p": make_unary_test_cases( + "log1p", + lax.log1p, + np.linspace(-9e-8, -8e-8, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-11, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-14, ulps=0), + ), + 1.0, + ), + "tanh": make_unary_test_cases( + 
"tanh", + lax.tanh, + np.linspace(5.83, 5.86, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-12, rtol=0, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=0, ulps=0), + ), + ), + "cos": make_unary_test_cases( + "cos", + lax.cos, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sin": make_unary_test_cases( + "sin", + lax.sin, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "tan": make_unary_test_cases( + "tan", + lax.tan, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sqrt": make_unary_test_cases( + "sqrt", + lax.sqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "rsqrt": make_unary_test_cases( + "rsqrt", + lax.rsqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), +} + + +def generate_test_cases(op_names): + test_cases = [] + for op in op_names: + op_group = UNARY_OPS[op] + if op_group is None: + raise ValueError(f"No test cases found for op: {op}") + test_cases.extend(op_group) + return test_cases + + +class UnaryOpsAccuracyTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.stablehlo_version_at_least("1.10.0"): + self.skipTest("Test requires StableHLO v1.10.0 or higher.") + if not jtu.is_device_tpu(): + self.skipTest("Skipping test on non TPU devices.") + super().setUp() + + def test_result_accuracy_mode_attr(self): + with ir.Context() as context: + hlo.register_dialect(context) + attr = hlo.ResultAccuracyModeAttr.get("DEFAULT") + assert attr is not None + assert attr.value == "DEFAULT" + + def test_result_accuracy_attr(self): + with ir.Context() as context: + hlo.register_dialect(context) + attr = hlo.ResultAccuracyAttr.get( + atol=1e-5, rtol=0.0, ulps=1, mode="TOLERANCE" + ) + assert attr is not None + assert attr.mode == "TOLERANCE" + assert attr.atol == 1e-5 + assert attr.rtol == 0.0 + assert attr.ulps == 1 + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"]) + ) + def test_unary_ops_choose_impl(self, op, x, tp, **kwargs): + @jax.jit + def f_default(x): + y = op(x, accuracy=tp.high) + return y + + @jax.jit + def f_accurate(x): + y = op(x, accuracy=tp.low) + return y + + # Input values that would cause large differences between the two + # implementations. + diff = abs(f_default(x) - f_accurate(x)) + if jtu.get_tpu_version() >= 5 and op in [ + lax.tanh, + jnp.tanh, + lax.log, + jnp.log, + ]: + # From tpu version 5 and onwards, even with tighter tolerance, the high performant + # implementation for tanh is chosen because the chip implementation has improved accuracy. 
+      self.assertTrue(jnp.all(diff == 0))
+    else:
+      self.assertTrue(jnp.any(diff > 0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"])
+  )
+  def test_unary_vmap(self, op, x, tp, min_error_val):
+    @jax.jit
+    def f(x, y):
+      diff = lambda val: abs(
+          op(val, accuracy=tp.high) - op(val, accuracy=tp.low)
+      )
+      return diff(x), diff(y)
+
+    diff_x, diff_y = jax.vmap(f, in_axes=(None, 0), out_axes=0)(
+        min_error_val, x
+    )
+    # diff(min_error_val) should be 0.
+    self.assertTrue(jnp.all(diff_x == 0))
+    # diff(x) should be > 0.
+    if jtu.get_tpu_version() >= 5 and op in [
+        lax.tanh,
+        jnp.tanh,
+        lax.log,
+        jnp.log,
+    ]:
+      # From TPU version 5 onwards, the high-performance implementations of
+      # tanh and log are chosen even under tighter tolerances, because the
+      # chip's implementations have improved accuracy.
+      self.assertTrue(jnp.all(diff_y == 0))
+    else:
+      self.assertTrue(jnp.any(diff_y > 0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["exp", "expm1", "exp2"])
+  )
+  def test_diff_grad(self, op, x, tp, **kwargs):
+    @jax.jit
+    def f_default(x):
+      default_op = op(x, accuracy=tp.low)
+      return jnp.sum(default_op)
+
+    f_default_grad = jax.grad(f_default)
+
+    @jax.jit
+    def f_accurate(x):
+      high_op = op(x, accuracy=tp.high)
+      return jnp.sum(high_op)
+
+    f_accurate_grad = jax.grad(f_accurate)
+    # Accuracy should be carried through to the gradient, causing a large diff.
+    diff = abs(f_default_grad(x) - f_accurate_grad(x))
+    self.assertTrue(jnp.any(diff > 0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["log", "log1p", "tanh"])
+  )
+  def test_grad_unchanged(self, op, x, tp, **kwargs):
+    @jax.jit
+    def f(x):
+      return jnp.sum(op(x))
+
+    f_grad = jax.grad(f)
+
+    @jax.jit
+    def f_default(x):
+      default_op = op(x, accuracy=tp.low)
+      return jnp.sum(default_op)
+
+    f_default_grad = jax.grad(f_default)
+
+    @jax.jit
+    def f_accurate(x):
+      high_op = op(x, accuracy=tp.high)
+      return jnp.sum(high_op)
+
+    f_accurate_grad = jax.grad(f_accurate)
+    # Accuracy should be carried through to the gradient, causing a large
+    # diff. Whether f_default_grad and f_accurate_grad differ should follow
+    # diff(f_grad, f_default_grad).
+    expected_diff = abs(f_grad(x) - f_default_grad(x))
+    if jnp.all(expected_diff > 0):
+      # Don't expect f_accurate_grad and f_default_grad to be equal.
+      self.assertFalse(
+          jnp.all(abs(f_default_grad(x) - f_accurate_grad(x)) == 0)
+      )
+    elif jnp.all(expected_diff == 0):
+      # f_accurate_grad and f_default_grad should be equal.
+      diff = abs(f_default_grad(x) - f_accurate_grad(x))
+      self.assertTrue(jnp.all(diff == 0))
+    else:
+      raise ValueError(f"Unexpected diff: {expected_diff}")
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"])
+  )
+  def test_single_impl(self, op, x, tp, **kwargs):
+    @jax.jit
+    def f_tol(x):
+      return op(x, accuracy=tp.high)
+
+    @jax.jit
+    def f(x):
+      return op(x)
+
+    diff = abs(f_tol(x) - f(x))
+    self.assertTrue(jnp.all(diff == 0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"])
+  )
+  def test_default_grad(self, op, x, tp, **kwargs):
+    @jax.jit
+    def f_tol(x):
+      return jnp.sum(op(x, accuracy=tp.high))
+
+    @jax.jit
+    def f(x):
+      return jnp.sum(op(x))
+
+    self.assertTrue(jnp.all(abs(jax.grad(f_tol)(x) - jax.grad(f)(x)) == 0))
+
+  def test_invalid_accuracy(self):
+    with self.assertRaisesRegex(
+        ValueError, "At least one of atol, rtol, or ulps must be set."
+    ):
+      lax.exp(1.0, accuracy=lax.Tolerance(atol=0.0, rtol=0.0, ulps=0))
+    with self.assertRaisesRegex(ValueError, "Tolerances must be non-negative."):
+      lax.exp(1.0, accuracy=lax.Tolerance(atol=-4e-10, rtol=0.0, ulps=0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases([
+          "exp",
+          "expm1",
+          "exp2",
+          "log",
+          "log1p",
+          "tanh",
+          "cos",
+          "sin",
+          "tan",
+          "sqrt",
+          "rsqrt",
+      ])
+  )
+  def test_low_tol(self, op, x, **kwargs):
+    with self.assertRaisesRegex(
+        xla_extension.XlaRuntimeError, "impl_type.ok()"
+    ):
+      op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0))
+
+  def test_accuracy_jaxpr(self):
+    # Since accuracy is not set, the jaxpr should not contain "accuracy".
+    self.assertNotIn(
+        "accuracy",
+        str(
+            jax.make_jaxpr(lambda x: lax.exp(x, accuracy=None))(
+                np.arange(4.0, dtype=np.float32)
+            )
+        ),
+    )
+    # Set accuracy.
+    self.assertIn(
+        "accuracy",
+        str(
+            jax.make_jaxpr(
+                lambda x: lax.exp(
+                    x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0)
+                )
+            )(np.arange(4.0, dtype=np.float32))
+        ),
+    )
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/tests/version_test.py b/tests/version_test.py
index b78e61ae024c..14da82df2e3e 100644
--- a/tests/version_test.py
+++ b/tests/version_test.py
@@ -143,6 +143,7 @@ def testBuildVersionFromEnvironment(self):
                      JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None):
       with assert_no_subprocess_call():
         version = jax.version._get_version_for_build()
+      self.assertFalse(jax.version._is_prerelease())
       self.assertEqual(version, base_version)
       self.assertValidVersion(version)
 
@@ -150,6 +151,7 @@ def testBuildVersionFromEnvironment(self):
                      JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None):
       with assert_no_subprocess_call():
         version = jax.version._get_version_for_build()
+      self.assertFalse(jax.version._is_prerelease())
       self.assertEqual(version, base_version)
       self.assertValidVersion(version)
 
@@ -183,6 +185,20 @@ def testBuildVersionFromEnvironment(self):
     ):
       with assert_no_subprocess_call():
         version = jax.version._get_version_for_build()
+      self.assertTrue(jax.version._is_prerelease())
+      self.assertEqual(version, f"{base_version}rc0")
+      self.assertValidVersion(version)
+
+    with jtu.set_env(
+        JAX_RELEASE=None,
+        JAXLIB_RELEASE="1",
+        JAX_NIGHTLY=None,
+        JAXLIB_NIGHTLY=None,
+        WHEEL_VERSION_SUFFIX="rc0",
+    ):
+      with assert_no_subprocess_call():
+        version = jax.version._get_version_for_build()
+      self.assertTrue(jax.version._is_prerelease())
       self.assertEqual(version, f"{base_version}rc0")
       self.assertValidVersion(version)
 
diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py
index d141bc15c249..ba2120fcb7b9 100644
--- a/tests/xla_metadata_test.py
+++ b/tests/xla_metadata_test.py
@@ -190,6 +190,39 @@ def while_fn(a):
         if "stablehlo.add" in line:
           self.assertIn('mhlo.frontend_attributes = {a = "c"}', line)
 
+  def test_cond_annotates_branches(self):
+    sin = jnp.sin
+    cos = jnp.cos
+
+    @jax.jit
+    def f(x):
+      with set_xla_metadata(a="b"):
+        return jax.lax.cond(x < 0., sin, cos, x)
+
+    hlo_lines = f.lower(1.).as_text().split("\n")
+    sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line]
+    cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line]
+    self.assertIn('mhlo.frontend_attributes = {a = "b"}', sin_hlo)
+    self.assertIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo)
+
+  def test_cond_annotates_branches_and_none_unsets(self):
+    sin = jnp.sin
+
+    def cos(x):
+      with set_xla_metadata(a=None):
+        return jnp.cos(x)
+
+    @jax.jit
+    def f(x):
+      with set_xla_metadata(a="b"):
+        return jax.lax.cond(x < 0., sin, cos, x)
+
+    hlo_lines = f.lower(1.).as_text().split("\n")
+    sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line]
+    cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line]
+    self.assertIn('mhlo.frontend_attributes = {a = "b"}', sin_hlo)
+    self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo)
+
   def test_nested_jit(self):
     @jax.jit
     def f(x, y):
@@ -255,10 +288,10 @@ def f2(x, y):
       with set_xla_metadata(a="b"):
         return (x + y, y * 2.0)
 
-    f_vmap_jaxpr = jax.make_jaxpr(jax.vmap(f2, in_axes=(0, None)))
+    f2_vmap = jax.vmap(f2, in_axes=(0, None))
     self.assertIn(
         'mhlo.frontend_attributes = {a = "b"}',
-        f_vmap_jaxpr.lower(jnp.arange(5.0), 1.0).as_text(),
+        jax.jit(f2_vmap).lower(jnp.arange(5.0), 1.0).as_text(),
     )
 
   def test_multiple_instructions(self):
diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 73bf2eb3850d..0ac09a4a6594 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 #    curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 #    and update XLA_SHA256 with the result.
 
-XLA_COMMIT = "df971129bd82e381954da0185b534220e21798a4"
-XLA_SHA256 = "11e9a568320cf7e7d61819620fd369927527ecefb68d5d1154b1521456bbdb72"
+XLA_COMMIT = "0d1b60216ea13b0d261d59552a0f7ef20c4f76c5"
+XLA_SHA256 = "357b37cc7c439580344ce0305bad88ef841f29743a99ea8e2253e64a32e139c6"
 
 def repo():
     tf_http_archive(