diff --git a/.devops/nix/nixpkgs-instances.nix b/.devops/nix/nixpkgs-instances.nix index 90d683a713a..40cf58f196c 100644 --- a/.devops/nix/nixpkgs-instances.nix +++ b/.devops/nix/nixpkgs-instances.nix @@ -4,7 +4,7 @@ # the module `{ pkgs ... }: { /* config */ }` implicitly uses # `_module.args.pkgs` (defined in this case by flake-parts). perSystem = - { system, ... }: + { lib, system, ... }: { _module.args = { # Note: bringing up https://zimbatm.com/notes/1000-instances-of-nixpkgs @@ -33,7 +33,7 @@ "CUDA EULA" "cuDNN EULA" ] - ) (p.meta.licenses or [ p.meta.license ]); + ) (p.meta.licenses or (lib.toList p.meta.license)); }; # Ensure dependencies use ROCm consistently pkgsRocm = import inputs.nixpkgs { diff --git a/.devops/nix/package-gguf-py.nix b/.devops/nix/package-gguf-py.nix index cca2f36a5bd..de3ac841fb4 100644 --- a/.devops/nix/package-gguf-py.nix +++ b/.devops/nix/package-gguf-py.nix @@ -3,6 +3,7 @@ llamaVersion, numpy, tqdm, + requests, sentencepiece, pyyaml, poetry-core, @@ -20,6 +21,7 @@ buildPythonPackage { tqdm sentencepiece pyyaml + requests ]; src = lib.cleanSource ../../gguf-py; pythonImportsCheck = [ diff --git a/.devops/nix/scope.nix b/.devops/nix/scope.nix index 478e8c4228a..b4328a771e1 100644 --- a/.devops/nix/scope.nix +++ b/.devops/nix/scope.nix @@ -7,13 +7,6 @@ let pythonPackages = python3.pkgs; - buildPythonPackage = pythonPackages.buildPythonPackage; - numpy = pythonPackages.numpy; - tqdm = pythonPackages.tqdm; - sentencepiece = pythonPackages.sentencepiece; - pyyaml = pythonPackages.pyyaml; - poetry-core = pythonPackages.poetry-core; - pytestCheckHook = pythonPackages.pytestCheckHook; in # We're using `makeScope` instead of just writing out an attrset @@ -23,17 +16,18 @@ in lib.makeScope newScope (self: { inherit llamaVersion; gguf-py = self.callPackage ./package-gguf-py.nix { - inherit - buildPythonPackage + inherit (pythonPackages) numpy tqdm sentencepiece - poetry-core pyyaml pytestCheckHook + requests + buildPythonPackage + poetry-core ; }; - python-scripts = self.callPackage ./python-scripts.nix { inherit buildPythonPackage poetry-core; }; + python-scripts = self.callPackage ./python-scripts.nix { inherit (pythonPackages) buildPythonPackage poetry-core; }; llama-cpp = self.callPackage ./package.nix { }; docker = self.callPackage ./docker.nix { }; docker-min = self.callPackage ./docker.nix { interactive = false; }; diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index 9797c5e0f31..5d6c87ed6b9 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -54,6 +54,7 @@ RUN apt-get update \ build-essential \ git \ python3 \ + python3-dev \ python3-pip \ python3-wheel \ && pip install --break-system-packages --upgrade setuptools \ diff --git a/.github/workflows/build-cache.yml b/.github/workflows/build-cache.yml index 6a22e41c3b5..3de0be9fad5 100644 --- a/.github/workflows/build-cache.yml +++ b/.github/workflows/build-cache.yml @@ -16,7 +16,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Get latest Vulkan SDK version id: vulkan_sdk_version @@ -24,7 +24,7 @@ jobs: echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" - name: Setup Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-sdk with: path: ./vulkan_sdk @@ -47,10 +47,10 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-toolchain with: path: ./spacemit_toolchain @@ -73,10 +73,10 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-rocm with: path: C:\Program Files\AMD\ROCm diff --git a/.github/workflows/build-cmake-pkg.yml b/.github/workflows/build-cmake-pkg.yml index 510352a5ccf..259efa43c8f 100644 --- a/.github/workflows/build-cmake-pkg.yml +++ b/.github/workflows/build-cmake-pkg.yml @@ -7,7 +7,7 @@ jobs: linux: runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 4d3b687a516..8b6ebaf4a37 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -8,7 +8,7 @@ jobs: # runs-on: ubuntu-24.04 # steps: - # - uses: actions/checkout@v4 + # - uses: actions/checkout@v6 # - name: Setup Riscv # run: | # sudo dpkg --add-architecture riscv64 @@ -52,7 +52,7 @@ jobs: # runs-on: ubuntu-24.04 # steps: - # - uses: actions/checkout@v4 + # - uses: actions/checkout@v6 # - name: Setup Riscv # run: | # sudo dpkg --add-architecture riscv64 @@ -99,7 +99,7 @@ jobs: # runs-on: ubuntu-24.04 # steps: - # - uses: actions/checkout@v4 + # - uses: actions/checkout@v6 # - name: Setup Arm64 # run: | # sudo dpkg --add-architecture arm64 @@ -146,7 +146,7 @@ jobs: container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup LoongArch run: | rm -f /etc/apt/sources.list.d/* @@ -201,7 +201,7 @@ jobs: container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup LoongArch run: | rm -f /etc/apt/sources.list.d/* @@ -262,10 +262,10 @@ jobs: SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Use SpacemiT Toolchain Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-toolchain with: path: ./spacemit_toolchain diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 187c8614371..6c7ab711431 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -21,7 +21,8 @@ on: '**/*.m', '**/*.metal', '**/*.comp', - '**/*.glsl' + '**/*.glsl', + '**/*.wgsl' ] pull_request: @@ -42,7 +43,8 @@ on: '**/*.m', '**/*.metal', '**/*.comp', - '**/*.glsl' + '**/*.glsl', + '**/*.wgsl' ] concurrency: @@ -63,7 +65,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -99,7 +101,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -135,7 +137,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -189,7 +191,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -269,7 +271,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -291,7 +293,9 @@ jobs: cmake -B build \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) - name: Build (no OpenMP) @@ -301,8 +305,10 @@ jobs: cmake -B build \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ -DGGML_OPENMP=OFF + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) - name: Test @@ -317,7 +323,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dependencies id: depends @@ -347,7 +353,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 # - name: ccache # uses: ggml-org/ccache-action@v1.2.16 @@ -380,7 +386,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -414,7 +420,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -436,7 +442,7 @@ jobs: echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" - name: Use Vulkan SDK Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-sdk with: path: ./vulkan_sdk @@ -464,7 +470,7 @@ jobs: export GGML_VK_VISIBLE_DEVICES=0 export GGML_VK_DISABLE_F16=1 # This is using llvmpipe and runs slower than other backends - ctest -L main --verbose --timeout 4200 + ctest -L main --verbose --timeout 4800 ubuntu-24-cmake-webgpu: runs-on: ubuntu-24.04 @@ -472,7 +478,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -494,7 +500,7 @@ jobs: echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" - name: Use Vulkan SDK Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-sdk with: path: ./vulkan_sdk @@ -543,7 +549,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -585,7 +591,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dependencies id: depends @@ -616,7 +622,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dependencies id: depends @@ -644,7 +650,7 @@ jobs: continue-on-error: true steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: add oneAPI to apt shell: bash @@ -668,7 +674,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -693,7 +699,7 @@ jobs: continue-on-error: true steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: add oneAPI to apt shell: bash @@ -717,7 +723,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -749,7 +755,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -781,7 +787,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -813,7 +819,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Build id: cmake_build @@ -843,7 +849,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -853,7 +859,7 @@ jobs: save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Download xcframework artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: llama-xcframework path: build-apple/llama.xcframework/ @@ -885,7 +891,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -954,7 +960,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1053,7 +1059,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install dependencies env: @@ -1092,7 +1098,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1145,7 +1151,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1177,7 +1183,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Grab rocWMMA package id: grab_rocwmma @@ -1187,7 +1193,7 @@ jobs: 7z x data.tar - name: Use ROCm Installation Cache - uses: actions/cache@v4 + uses: actions/cache@v5 id: cache-rocm with: path: C:\Program Files\AMD\ROCm @@ -1239,7 +1245,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Xcode uses: maxim-lobanov/setup-xcode@v1 @@ -1269,7 +1275,7 @@ jobs: ./build-xcframework.sh - name: Upload xcframework artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: llama-xcframework path: build-apple/llama.xcframework/ @@ -1285,7 +1291,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 # Disabled due to size (400MB) and always 0 cache hits # - name: ccache @@ -1295,7 +1301,7 @@ jobs: # evict-old-files: 1d - name: Set up JDK - uses: actions/setup-java@v3 + uses: actions/setup-java@v5 with: java-version: 17 distribution: zulu @@ -1327,7 +1333,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install OpenCL Headers and Libs id: install_opencl @@ -1371,7 +1377,7 @@ jobs: id: update_presets if: ${{ matrix.build == 'arm64-snapdragon' }} run: | - cp docs/backend/hexagon/CMakeUserPresets.json . + cp docs/backend/snapdragon/CMakeUserPresets.json . - name: Build id: ndk_build @@ -1402,7 +1408,7 @@ jobs: runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -1460,7 +1466,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1486,7 +1492,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1512,7 +1518,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1530,7 +1536,7 @@ jobs: - name: Test id: ggml-ci run: | - LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_HIGH_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt ggml-ci-arm64-cpu-high-perf: runs-on: ubuntu-22.04-arm @@ -1538,7 +1544,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1556,7 +1562,7 @@ jobs: - name: Test id: ggml-ci run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_HIGH_PERF=1 GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt ggml-ci-arm64-cpu-high-perf-sve: runs-on: ubuntu-22.04-arm @@ -1564,7 +1570,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1590,7 +1596,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1604,7 +1610,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1618,7 +1624,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1632,7 +1638,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1645,7 +1651,7 @@ jobs: # steps: # - name: Clone # id: checkout - # uses: actions/checkout@v4 + # uses: actions/checkout@v6 # - name: Test # id: ggml-ci @@ -1659,7 +1665,7 @@ jobs: # steps: # - name: Clone # id: checkout - # uses: actions/checkout@v4 + # uses: actions/checkout@v6 # - name: Test # id: ggml-ci @@ -1673,7 +1679,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1686,7 +1692,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dawn Dependency id: dawn-depends @@ -1714,7 +1720,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1728,7 +1734,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1773,7 +1779,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Check environment run: | @@ -1875,7 +1881,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup ccache run: | @@ -1969,7 +1975,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup ccache run: | @@ -2043,7 +2049,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup ccache run: | @@ -2089,7 +2095,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Dependencies id: depends diff --git a/.github/workflows/check-vendor.yml b/.github/workflows/check-vendor.yml index 7b3016079cc..1671ed7b8bd 100644 --- a/.github/workflows/check-vendor.yml +++ b/.github/workflows/check-vendor.yml @@ -19,16 +19,16 @@ on: jobs: check-vendor: - runs-on: ubuntu-latest + runs-on: ubuntu-slim steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: '3.x' diff --git a/.github/workflows/close-issue.yml b/.github/workflows/close-issue.yml index cbfc4990dbc..ec3df08b2d6 100644 --- a/.github/workflows/close-issue.yml +++ b/.github/workflows/close-issue.yml @@ -10,12 +10,12 @@ permissions: jobs: close-issues: - runs-on: ubuntu-latest + runs-on: ubuntu-slim permissions: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v10 with: exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap" days-before-issue-stale: 30 diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index 5f733e684e5..fc3cec5ea19 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -26,7 +26,7 @@ jobs: # If you do not check out your code, Copilot will do this for you. steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -45,7 +45,7 @@ jobs: sudo chmod +x /usr/local/bin/git-clang-format - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index d9fe0686d35..8062177ba5a 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -49,7 +49,7 @@ jobs: - { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" } steps: - name: Check out the repo - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 # preserve git history, so we can determine the build number @@ -63,7 +63,7 @@ jobs: uses: docker/setup-buildx-action@v3 - name: Log in to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} @@ -208,7 +208,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/editorconfig.yml b/.github/workflows/editorconfig.yml index f02b7c2194b..702dc89f5b1 100644 --- a/.github/workflows/editorconfig.yml +++ b/.github/workflows/editorconfig.yml @@ -20,9 +20,9 @@ concurrency: jobs: editorconfig: - runs-on: ubuntu-latest + runs-on: ubuntu-slim steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: editorconfig-checker/action-editorconfig-checker@v2 with: version: v3.0.3 diff --git a/.github/workflows/gguf-publish.yml b/.github/workflows/gguf-publish.yml index 3ca4d305810..0e957664592 100644 --- a/.github/workflows/gguf-publish.yml +++ b/.github/workflows/gguf-publish.yml @@ -21,12 +21,12 @@ on: jobs: deploy: - runs-on: ubuntu-latest + runs-on: ubuntu-slim steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.9.x' - name: Install dependencies diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 0b0f300aa40..eab20c68811 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -7,11 +7,11 @@ jobs: permissions: contents: read pull-requests: write - runs-on: ubuntu-latest + runs-on: ubuntu-slim steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: repository: "ggml-org/llama.cpp" - - uses: actions/labeler@v5 + - uses: actions/labeler@v6 with: configuration-path: '.github/labeler.yml' diff --git a/.github/workflows/pre-tokenizer-hashes.yml b/.github/workflows/pre-tokenizer-hashes.yml index dff998e2393..7126b62b690 100644 --- a/.github/workflows/pre-tokenizer-hashes.yml +++ b/.github/workflows/pre-tokenizer-hashes.yml @@ -12,14 +12,14 @@ on: jobs: pre-tokenizer-hashes: - runs-on: ubuntu-latest + runs-on: ubuntu-slim steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/python-check-requirements.yml b/.github/workflows/python-check-requirements.yml index 46e80aecd0a..1219b874592 100644 --- a/.github/workflows/python-check-requirements.yml +++ b/.github/workflows/python-check-requirements.yml @@ -20,13 +20,13 @@ concurrency: jobs: python-check-requirements: - runs-on: ubuntu-latest + runs-on: ubuntu-slim name: check-requirements steps: - name: Check out source repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python environment - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.11" - name: Run check-requirements.sh script diff --git a/.github/workflows/python-lint.yml b/.github/workflows/python-lint.yml index ddfdf73b8fc..8d1dd7a7d5c 100644 --- a/.github/workflows/python-lint.yml +++ b/.github/workflows/python-lint.yml @@ -15,13 +15,13 @@ concurrency: jobs: flake8-lint: - runs-on: ubuntu-latest + runs-on: ubuntu-slim name: Lint steps: - name: Check out source repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python environment - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.11" - name: flake8 Lint diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml index 373bb601020..e801a9f42e6 100644 --- a/.github/workflows/python-type-check.yml +++ b/.github/workflows/python-type-check.yml @@ -24,14 +24,12 @@ jobs: name: pyright type-check steps: - name: Check out source repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python environment - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.11" - - name: Install Python dependencies - # TODO: use a venv - run: pip install -r requirements/requirements-all.txt + pip-install: -r requirements/requirements-all.txt - name: Type-check with Pyright uses: jakebailey/pyright-action@v2 with: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d8b3b95df0d..1914c084895 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -63,7 +63,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz name: llama-bin-macos-arm64.tar.gz @@ -74,7 +74,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -111,7 +111,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz name: llama-bin-macos-x64.tar.gz @@ -133,7 +133,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -173,7 +173,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz name: llama-bin-ubuntu-${{ matrix.build }}.tar.gz @@ -184,7 +184,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -226,7 +226,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz name: llama-bin-ubuntu-vulkan-x64.tar.gz @@ -242,7 +242,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -278,7 +278,7 @@ jobs: 7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-cpu-${{ matrix.arch }}.zip name: llama-bin-win-cpu-${{ matrix.arch }}.zip @@ -305,7 +305,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -360,7 +360,7 @@ jobs: 7z a -snl llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip @@ -375,7 +375,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install ccache uses: ggml-org/ccache-action@v1.2.16 @@ -416,7 +416,7 @@ jobs: 7z a -snl llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip @@ -431,7 +431,7 @@ jobs: 7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\* - name: Upload Cuda runtime - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip @@ -451,7 +451,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -511,7 +511,7 @@ jobs: 7z a -snl llama-bin-win-sycl-x64.zip ./build/bin/* - name: Upload the release package - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-sycl-x64.zip name: llama-bin-win-sycl-x64.zip @@ -531,7 +531,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Grab rocWMMA package id: grab_rocwmma @@ -542,7 +542,7 @@ jobs: - name: Cache ROCm Installation id: cache-rocm - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: C:\Program Files\AMD\ROCm key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }} @@ -617,7 +617,7 @@ jobs: 7z a -snl llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\* - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-bin-win-hip-${{ matrix.name }}-x64.zip name: llama-bin-win-hip-${{ matrix.name }}-x64.zip @@ -627,7 +627,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -672,7 +672,7 @@ jobs: zip -r -y llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-xcframework.zip name: llama-${{ steps.tag.outputs.name }}-xcframework.zip @@ -703,7 +703,7 @@ jobs: runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -763,7 +763,7 @@ jobs: tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz @@ -794,7 +794,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -804,7 +804,7 @@ jobs: - name: Download artifacts id: download-artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: path: ./artifact merge-multiple: true @@ -887,7 +887,7 @@ jobs: - name: Upload release id: upload_release - uses: actions/github-script@v3 + uses: actions/github-script@v8 with: github-token: ${{secrets.GITHUB_TOKEN}} script: | @@ -897,7 +897,7 @@ jobs: for (let file of await fs.readdirSync('./release')) { if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) { console.log('uploadReleaseAsset', file); - await github.repos.uploadReleaseAsset({ + await github.rest.repos.uploadReleaseAsset({ owner: context.repo.owner, repo: context.repo.repo, release_id: release_id, diff --git a/.github/workflows/server-webui.yml b/.github/workflows/server-webui.yml index 318003c5ccc..94899c93761 100644 --- a/.github/workflows/server-webui.yml +++ b/.github/workflows/server-webui.yml @@ -8,10 +8,6 @@ on: description: 'Commit SHA1 to build' required: false type: string - slow_tests: - description: 'Run slow tests' - required: true - type: boolean push: branches: - master @@ -37,14 +33,14 @@ jobs: continue-on-error: true steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - name: Setup Node.js id: node - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: "22" cache: "npm" @@ -101,119 +97,3 @@ jobs: if: ${{ always() && steps.playwright.conclusion == 'success' }} run: npm run test:e2e working-directory: tools/server/webui - - server-build: - runs-on: ubuntu-latest - - strategy: - matrix: - sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken - build_type: [RelWithDebInfo] - include: - - build_type: Release - sanitizer: "" - fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken - - steps: - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get -y install \ - build-essential \ - xxd \ - git \ - cmake \ - curl \ - wget \ - language-pack-en \ - libssl-dev - - - name: Clone - id: checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - - - name: Python setup - id: setup_python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt - - - name: Setup Node.js for WebUI - uses: actions/setup-node@v4 - with: - node-version: "22" - cache: "npm" - cache-dependency-path: "tools/server/webui/package-lock.json" - - - name: Install WebUI dependencies - run: npm ci - working-directory: tools/server/webui - - - name: Build WebUI - run: npm run build - working-directory: tools/server/webui - - - name: Build (no OpenMP) - id: cmake_build_no_openmp - if: ${{ matrix.sanitizer == 'THREAD' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ - -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Build (sanitizers) - id: cmake_build_sanitizers - if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ - -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Build (sanitizers) - id: cmake_build - if: ${{ matrix.sanitizer == '' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Tests - id: server_integration_tests - if: ${{ matrix.sanitizer == '' }} - env: - GITHUB_ACTIONS: "true" - run: | - cd tools/server/tests - ./tests.sh - - - name: Tests (sanitizers) - id: server_integration_tests_sanitizers - if: ${{ matrix.sanitizer != '' }} - run: | - cd tools/server/tests - LLAMA_SANITIZE=1 ./tests.sh - - - name: Slow tests - id: server_integration_tests_slow - if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} - run: | - cd tools/server/tests - SLOW_TESTS=1 ./tests.sh diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index ab7c520e115..99d05226ba5 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -36,7 +36,7 @@ jobs: strategy: matrix: - sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken + sanitizer: [ADDRESS, UNDEFINED] # THREAD is very slow build_type: [RelWithDebInfo] include: - build_type: Release @@ -45,7 +45,7 @@ jobs: - build_type: Release sanitizer: "" extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1" - fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken + fail-fast: false steps: - name: Dependencies @@ -64,7 +64,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} @@ -72,35 +72,47 @@ jobs: - name: Build id: cmake_build run: | - cmake -B build -DLLAMA_BUILD_BORINGSSL=ON - cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server + cmake -B build \ + -DLLAMA_BUILD_BORINGSSL=ON \ + -DGGML_SCHED_NO_REALLOC=ON \ + -DGGML_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \ + -DGGML_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \ + -DGGML_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }} \ + -DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \ + -DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \ + -DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }} + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - name: Python setup id: setup_python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt + pip-install: -r tools/server/tests/requirements.txt - name: Tests id: server_integration_tests - if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) && matrix.build_type == 'Release' }} + if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }} run: | cd tools/server/tests export ${{ matrix.extra_args }} pytest -v -x -m "not slow" + - name: Slow tests + id: server_integration_tests_slow + if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} + run: | + cd tools/server/tests + export ${{ matrix.extra_args }} + SLOW_TESTS=1 pytest -v -x + server-windows: runs-on: windows-2022 steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} @@ -108,19 +120,15 @@ jobs: - name: Build id: cmake_build run: | - cmake -B build -DLLAMA_BUILD_BORINGSSL=ON + cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server - name: Python setup id: setup_python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt + pip-install: -r tools/server/tests/requirements.txt - name: Tests id: server_integration_tests diff --git a/.github/workflows/update-ops-docs.yml b/.github/workflows/update-ops-docs.yml index d5e264b34f4..2ab06eb9811 100644 --- a/.github/workflows/update-ops-docs.yml +++ b/.github/workflows/update-ops-docs.yml @@ -14,14 +14,14 @@ on: jobs: update-ops-docs: - runs-on: ubuntu-latest + runs-on: ubuntu-slim steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.x' diff --git a/.github/workflows/winget.yml b/.github/workflows/winget.yml index d3d9be23ce5..2047c276f8d 100644 --- a/.github/workflows/winget.yml +++ b/.github/workflows/winget.yml @@ -21,23 +21,24 @@ jobs: - name: Find latest release id: find_latest_release - uses: actions/github-script@v6 + uses: actions/github-script@v8 with: script: | const { data: releases } = await github.rest.repos.listReleases({ owner: context.repo.owner, repo: context.repo.repo, }); - console.log("Latest release:", releases[0].tag_name); - return releases[0].tag_name; + const { tag_name: version, assets: assets } = releases.find(({assets}) => assets.find(asset => asset.name.includes('win-vulkan'))); + const { browser_download_url: asset_url } = assets.find(asset => asset.name.includes('win-vulkan')); + console.log("Latest release:", version); + core.setOutput('VERSION', version); + core.setOutput('ASSETURL', asset_url); - name: Update manifest - env: - VERSION: ${{ steps.find_latest_release.outputs.result }} run: | echo "Updating manifest..." - komac update --version ${{ env.VERSION }} \ - --urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \ + komac update --version ${{ steps.find_latest_release.outputs.VERSION }} \ + --urls "${{ steps.find_latest_release.outputs.ASSETURL }}" \ --token ${{ secrets.WINGET_GITHUB_TOKEN }} \ --submit \ ggml.llamacpp diff --git a/AUTHORS b/AUTHORS index 0af9f44ad4a..c297f3c2178 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,127 +1,228 @@ -# date: Sat Mar 8 18:23:52 EET 2025 +# date: Mon Feb 2 08:45:04 EET 2026 # this file is auto-generated by scripts/gen-authors.sh +Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com> +杨朱 · Kiki +エシュナヴァリシア <148695646+eternaphia@users.noreply.github.com> +吴小白 <296015668@qq.com> +源文雨 <41315874+fumiama@users.noreply.github.com> +蕭澧邦 <45505768+shou692199@users.noreply.github.com> +도로로도로또 <60079918+dororodoroddo@users.noreply.github.com> +손희준 +谢乃闻 0cc4m +0Marble <85058989+0Marble@users.noreply.github.com> 0xspringtime <110655352+0xspringtime@users.noreply.github.com> 20kdc +2114L3 <2114L3@users.noreply.github.com> 2f38b454 3ooabkhxtn <31479382+3ooabkhxtn@users.noreply.github.com> 44670 <44670@users.noreply.github.com> +4onen <11580688+4onen@users.noreply.github.com> 65a <10104049+65a@users.noreply.github.com> 708-145 <40387547+708-145@users.noreply.github.com> -AN Long -AT +a-n-n-a-l-e-e <150648636+a-n-n-a-l-e-e@users.noreply.github.com> +a3sh <38979186+A3shTnT@users.noreply.github.com> +aa956 +Aadeshveer Singh <24b0926@iitb.ac.in> +Aadeshveer Singh Aarni Koskela Aaron Miller Aaron Teo <57927438+taronaeo@users.noreply.github.com> +Aaron Teo Aaryaman Vasishta Abheek Gulati Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> Abhishek Gopinath K <31348521+overtunned@users.noreply.github.com> +Acly +Adam +adel boussaken Adithya Balaji AdithyanI Adrian Adrian Hesketh Adrian Kretz +Adrian Lundberg <47256989+alundb@users.noreply.github.com> Adrien Gallouët Adrien Gallouët +afrideva <95653597+afrideva@users.noreply.github.com> +ag2s20150909 <19373730+ag2s20150909@users.noreply.github.com> +agray3 Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com> Ahmet Zeer +ai-fonsi +Aidan <99101158+gSUz92nc@users.noreply.github.com> AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> AidanBeltonS Aisuko Akarshan Biswas Akarshan Biswas Akarshan Biswas +akawrykow <142945436+akawrykow@users.noreply.github.com> Al Mochkin <14274697+amochkin@users.noreply.github.com> +Alan Gray +Alawode Oluwandabira Albert Jin Alberto <57916483+albbus-stack@users.noreply.github.com> +Alberto Cabrera Pérez <1478977+Alcpz@users.noreply.github.com> Alberto Cabrera Pérez Alberto Cabrera Pérez +Alberto Cabrera Pérez +Aldehir Rojas +alek3y <44779186+alek3y@users.noreply.github.com> +Aleksander Grygier Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com> +Alessandro98-git <61804547+Alessandro98-git@users.noreply.github.com> Alex Alex Azarov Alex Azarov Alex Brooks +Alex Fanthome Alex Klinkhamer Alex Klinkhamer Alex Nguyen Alex O'Connell <35843486+acon96@users.noreply.github.com> Alex Petenchea Alex Renda +Alex Trotta <44127594+Ahajha@users.noreply.github.com> Alex Tuddenham <61622354+AlexsCode@users.noreply.github.com> Alex von Gluck IV +Alex Wu +alex-spacemit Alexey Parfenov +Alexis Williams +alexpinel <93524949+alexpinel@users.noreply.github.com> +Alfred Ali Chraghi <63465728+alichraghi@users.noreply.github.com> Ali Nehzat Ali Tariq +Ali Tariq Alon +alonfaraj AlpinDale <52078762+AlpinDale@users.noreply.github.com> +alwqx +Aman +Aman Gupta +amd-dwang +amd-lalithnc Amir +amirai21 <89905406+amirai21@users.noreply.github.com> AmirAli Mirian <37371367+amiralimi@users.noreply.github.com> +amritahs-ibm +AN Long Ananta Bastola Anas Ahouzi <112881240+aahouzi@users.noreply.github.com> +Anav Prasad +anavp-nvidia +Andika Wasisto András Salamon Andreas (Andi) Kunar Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Andrei +Andrew Aladjev Andrew Canis Andrew Downing Andrew Duffy Andrew Godfrey +Andrew Marshall Andrew Minh Nguyen <40281306+amqdn@users.noreply.github.com> +andrijdavid Andy Salerno Andy Tai +Ankur Verma <31362771+ankurvdev@users.noreply.github.com> +anon998 <131767832+anon998@users.noreply.github.com> +Anri Lombard +Anthony Umfer Anthony Van de Gejuchte +antichristHater <142441588+antichristHater@users.noreply.github.com> Antoine Viallon +Anton Mitkov +Anton Mitkov Antonis Makropoulos +Anudit Nagar +anzz1 +apaz +apcameron <37645737+apcameron@users.noreply.github.com> +arch-btw <57669023+arch-btw@users.noreply.github.com> +arcrank +ardfork <134447697+ardfork@users.noreply.github.com> Arik Poznanski +arlo-phoenix <140345165+arlo-phoenix@users.noreply.github.com> Armen Kaleshian Artem Artem Zinnatullin Artyom Lebedev +aryantandon01 <80969509+aryantandon01@users.noreply.github.com> Asbjørn Olling Ásgeir Bjarni Ingvarsson Asghar Ghorbani Ashish <1856117+ashishdatta@users.noreply.github.com> Ashok Gelal <401055+ashokgelal@users.noreply.github.com> Ashraful Islam +AT +at8u <129688334+at8u@users.noreply.github.com> +Atharva Dubey Atsushi Tatsuma +aubreyli Austin <77757836+teleprint-me@users.noreply.github.com> AustinMroz -BADR -BB-fat <45072480+BB-fat@users.noreply.github.com> +automaticcat +awatuna <23447591+awatuna@users.noreply.github.com> +b4b4o Bach Le +BADR +bagheera <59658056+bghira@users.noreply.github.com> Bailey Chittle <39804642+bachittle@users.noreply.github.com> +bandoti <141645996+bandoti@users.noreply.github.com> BarfingLemurs <128182951+BarfingLemurs@users.noreply.github.com> +Bart Louwers +Bartowski <3266127+bartowski1182@users.noreply.github.com> Bartowski +Bas Nijholt +bashayer hijji +BB-fat <45072480+BB-fat@users.noreply.github.com> Behnam M <58621210+ibehnam@users.noreply.github.com> +beiller +Beinsezii <39478211+Beinsezii@users.noreply.github.com> Ben Ashbaugh +Ben Chen Ben Garney Ben Siraphob Ben Williams Benjamin Findley <39356821+Kartoffelsaft@users.noreply.github.com> Benjamin Lecaillon <84293038+blecaillon@users.noreply.github.com> +Benni <73313922+BenjaminBruenau@users.noreply.github.com> Benson Wong Bernat Vadell Bernhard M. Wiedemann Bert Wagner +bhubbb <79117352+bhubbb@users.noreply.github.com> Billel Mokeddem Bingan <70050083+binganao@users.noreply.github.com> +Bizhao Shi <37729561+shibizhao@users.noreply.github.com> Bjarke Viksøe <164612031+bviksoe@users.noreply.github.com> +Björn Ganster +bmwl +Bo Zheng <368586905@qq.com> +bobqianic <129547291+bobqianic@users.noreply.github.com> Bodhi <3882561+BodhiHu@users.noreply.github.com> Bodo Graumann +Boian Berberov <7432115+bberberov@users.noreply.github.com> Bono Lv Borislav Stanimirov Borislav Stanimirov +Bowen Han Branden Butler Brandon Squizzato <35474886+bsquizz@users.noreply.github.com> Brian Brian Cunnie Bruce MacDonald +brucepro Bryan Honof -CJ Pais -CRD716 +bryanSwk <93190252+bryanSwk@users.noreply.github.com> +bsilvereagle +bssrdf +byte-6174 <88070277+byte-6174@users.noreply.github.com> Calvin Laurenson Cameron Cameron Kaiser @@ -132,20 +233,33 @@ CarterLi999 <664681047@qq.com> Casey Primozic Casey Primozic CausalLM <148736309+CausalLM@users.noreply.github.com> +ccbinn +cduk <19917266+cduk@users.noreply.github.com> +cebtenzzre Cebtenzzre CentricStorm Chad Brewbaker +Chad Voegele +chaihahaha Changyeon Kim +chansikpark Chao Jiang +characharm <123120856+characharm@users.noreply.github.com> Charles Duffy Charles Xu <63788048+chaxu01@users.noreply.github.com> Charles Xu +chen fan <350211548@qq.com> Chen Xi Chen Xi Cheng Shao +Chenguang Li <757486878@qq.com> Chenguang Li <87689256+noemotiovon@users.noreply.github.com> +chiranko <96988916+chiranko@users.noreply.github.com> Chris Elrod Chris Kuehl +Chris Peterson +Chris Rohlf +Chris Thompson Christian Demsar Christian Demsar Christian Falch <875252+chrfalch@users.noreply.github.com> @@ -155,260 +269,466 @@ Christian Kögler Christian Köhnenkamp Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com> Christopher Nielsen <62156882+mascguy@users.noreply.github.com> +City <125218114+city96@users.noreply.github.com> +CJ Pais Clark Saben <76020733+csaben@users.noreply.github.com> Clauszy +clibdev <52199778+clibdev@users.noreply.github.com> Clint Herron +clyang +cmdr2 +cmdr2 +cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com> +codezjx +coezbek +comex +compilade <113953597+compilade@users.noreply.github.com> +compilade +Congcong Cai Conrad Kramer +Copilot <198982749+Copilot@users.noreply.github.com> Corentin REGAL +cpumaxx <163466046+cpumaxx@users.noreply.github.com> +crasm +crasm +crat0z <11581854+crat0z@users.noreply.github.com> +CRD716 CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> Csaba Kecskemeti Cuong Trinh Manh -DAN™ +daboe01 +daghanerdonmez <44506702+daghanerdonmez@users.noreply.github.com> Damian Stewart +daminho <37615795+daminho@users.noreply.github.com> +DAN™ Dan Johansson <164997844+eddnjjn@users.noreply.github.com> Dan Johansson Dane Madsen DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> +Daniel Benjaminsson Daniel Bevenius Daniel Drake +Daniel Han Daniel Hiltgen Daniel Illescas Romero Daniel Kleine <53251018+d-kleine@users.noreply.github.com> +Daniel Tang Daniele <57776841+daniandtheweb@users.noreply.github.com> +Daniele +Daniele Pinna <72076821+pestopoppa@users.noreply.github.com> Danny Milosavljevic DannyDaemonic +Darius Lukas Dat Quoc Nguyen <2412555+datquocnguyen@users.noreply.github.com> Dave Dave Airlie Dave Airlie Dave Della Costa +David Chiu David Friehs David Huang <1969802+hjc4869@users.noreply.github.com> David Kennedy +David Lima David Pflug +david raistrick David Renshaw +David Ribeiro Alves David Sommers <12738+databyte@users.noreply.github.com> David Yang +David Zhao <90013954+Your-Cheese@users.noreply.github.com> +davidef DavidKorczynski Dawid Potocki Dawid Wysocki <62249621+TortillaZHawaii@users.noreply.github.com> +ddh0 +ddh0 +ddpasa <112642920+ddpasa@users.noreply.github.com> +DDXDB <38449595+DDXDB@users.noreply.github.com> Dean +deepdiffuser <112834445+deepdiffuser@users.noreply.github.com> +deepsek <166548550+deepsek@users.noreply.github.com> Deins Denis Spasyuk <34203011+dspasyuk@users.noreply.github.com> Derrick T. Woolworth Deven Mistry <31466137+deven367@users.noreply.github.com> +devojony <61173062+devojony@users.noreply.github.com> +diannao <55k@outlook.com> Dibakar Gope Didzis Gosko Diego Devesa +Diner Burger +Đinh Trọng Huy <77562200+huydt84@users.noreply.github.com> Diogo Teles Sant'Anna +ditsuke +divinity76 Djip007 <3705339+Djip007@users.noreply.github.com> Djip007 +dm4 +dm4 +Dmytro Minochkin +Dobri Danchev <12420863+danchev@users.noreply.github.com> +DocShotgun <126566557+DocShotgun@users.noreply.github.com> +Doctor Shotgun <126566557+DocShotgun@users.noreply.github.com> Don Mahurin -DooWoong Lee (David) +Dong Won Kim <63934649+ddwkim@users.noreply.github.com> +Donghyeon Jeong <54725479+djeong20@users.noreply.github.com> +Dongliang Wei <121270393+wdl339@users.noreply.github.com> Doomsdayrs <38189170+Doomsdayrs@users.noreply.github.com> +DooWoong Lee (David) +Dorin-Andrei Geman +dotpy314 <33351922+dotpy314@users.noreply.github.com> Dou Xinpeng <15529241576@163.com> Dou Xinpeng <81913537+Dou-Git@users.noreply.github.com> Douglas Hanley +Dowon Dr. Tom Murphy VII Ph.D <499244+tom7@users.noreply.github.com> +drbh +ds5t5 <145942675+ds5t5@users.noreply.github.com> +duduta +dylan +eastriver Ebey Abraham +ebraminio +ebraminio Echo Nolan +Ed Addario <29247825+EAddario@users.noreply.github.com> Ed Lee Ed Lepedus Eddie-Wang Edward Taylor +eiery <19350831+eiery@users.noreply.github.com> Elaine Elbios <141279586+Elbios@users.noreply.github.com> Elton Kola +Emmanuel Ferdman Emreerdog <34742675+Emreerdog@users.noreply.github.com> Engininja2 <139037756+Engininja2@users.noreply.github.com> Equim Eric Curtin +Eric Curtin Eric Curtin Eric Sommerlade Eric Zhang <34133756+EZForever@users.noreply.github.com> +eric8607242 Erik Garrison Erik Scholz +Ervin Áron Tasnádi Esko Toivonen Ettore Di Giacinto +EugeoSynthesisThirtyTwo Evan Jones Evan Miller Eve <139727413+netrunnereve@users.noreply.github.com> Evgeny Kurnevsky +Ewan Crawford +Ewan Crawford Ewout ter Hoeven ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> -FK Fabian Fabio R. Sluzala Faez Shakil +fairydreaming <166155368+fairydreaming@users.noreply.github.com> Faisal Zaghloul Faisal Zaghloul Fan Shupei FantasyGmm <16450052+FantasyGmm@users.noreply.github.com> +fanyang Farbod Bijary <110523279+farbodbj@users.noreply.github.com> Fattire <528174+fat-tire@users.noreply.github.com> Felix +fengerhu1 <2748250768@qq.com> +fidoriel <49869342+fidoriel@users.noreply.github.com> Finn Voorhees Firat FirstTimeEZ <179362031+FirstTimeEZ@users.noreply.github.com> +fj-y-saito <85871716+fj-y-saito@users.noreply.github.com> +FK Florent BENOIT +Florian Badie Folko-Ven <71110216+Folko-Ven@users.noreply.github.com> Foul-Tarnished <107711110+Foul-Tarnished@users.noreply.github.com> +Francisco Herrera Francisco Melo <43780565+francis2tm@users.noreply.github.com> Frank Mai FrankHB Frankie Robertson +fraxy-v <65565042+fraxy-v@users.noreply.github.com> Fred Douglas <43351173+fredlas@users.noreply.github.com> Frederik Vogel +Fredrik Hultin +frob +fxzjshm <11426482+fxzjshm@users.noreply.github.com> +g2mt <166577174+g2mt@users.noreply.github.com> Gabe Goodhart Gabe Goodhart +Gabriel Larson <55459720+gabriellarson@users.noreply.github.com> +Gadflyii <34758915+Gadflyii@users.noreply.github.com> Gaetan Bisson GainLee Galunid Gary Linscott Gary Mulder +gatbontonpc +Gaurav Garg <52341457+gaugarg-nv@users.noreply.github.com> +Gaurav Garg Gavin Zhao Genkagaku.GPT Georgi Gerganov Gian-Carlo Pascutto +GideonSerf Gilad S Gilad S. <7817232+giladgd@users.noreply.github.com> +github-actions[bot] +GittyBurstein Giuseppe Scrivano +Giuseppe Scrivano GiviMAD +gliptic +gn64 +goerch Govlzkoy +grahameth <96447521+grahameth@users.noreply.github.com> +Gregor Jasny +Grzegorz Grasza +gtygo Guillaume "Vermeille" Sanchez Guillaume Wenzek Guoliang Hua <32868157+nbcsm@users.noreply.github.com> Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Guspan Tanadi <36249910+guspan-tanadi@users.noreply.github.com> Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com> +Guus Waals <_@guusw.nl> +Guy Goldenberg +gwjr <502526+gwjr@users.noreply.github.com> +h-h-h-h <13482553+h-h-h-h@users.noreply.github.com> Haggai Nuchi +Haiyue Wang Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com> Hale Chan Hamdoud Hakem <90524568+hamdoudhakem@users.noreply.github.com> +Han Qingzhe <95479277+hNSBQZ@users.noreply.github.com> Han Yin HanishKVC +hankcs Haohui Mai +haopeng <657407891@qq.com> +Haowei Wu Haoxiang Fei Harald Fernengel Hatsune Miku <129688334+at8u@users.noreply.github.com> HatsuneMikuUwU33 <173229399+HatsuneMikuUwU33@users.noreply.github.com> Haus1 +Héctor Estrada Moreno +HelloKS +Helton Reis <47722840+HRKings@users.noreply.github.com> +Hendrik Erz Henk Poley Henri Vasserman Henrik Forstén Henry Linjamäki +Henry Linjamäki +Henry147147 <44851451+Henry147147@users.noreply.github.com> +Herman Semenoff Herman Semenov Hesen Peng +HighDoping HimariO +hipudding +hksdpc255 <43977088+hksdpc255@users.noreply.github.com> Hoang Nguyen +hoangmit +HonestQiao Hong Bo PENG +hongbo.mo <352280764@qq.com> Hongyu Ouyang <96765450+casavaca@users.noreply.github.com> +hopkins385 <98618192+hopkins385@users.noreply.github.com> Howard Su +howlger +howlger Hua Jiang Huang Qi Huawei Lin Hugo Roussel Huifeng Ou <79071290+ho2103@users.noreply.github.com> +hutli <6594598+hutli@users.noreply.github.com> +hutli +hutli +hxer7963 +hydai +iacore <74560659+iacore@users.noreply.github.com> Ian Bull Ian Bull Ian Scrivener +ibrahim khadraoui <132432132+ibrahimkhadraoui@users.noreply.github.com> Icecream95 +Icenowy Zheng +icppWorld <124377669+icppWorld@users.noreply.github.com> Ido S +igardev <49397134+igardev@users.noreply.github.com> +igarnier IgnacioFDM Igor Okulist +Igor Smirnov +Ihar Hrachyshka Ihar Hrachyshka Ikko Eltociear Ashimine +Ilia Ilmer Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com> +Imad Saddik <79410781+ImadSaddik@users.noreply.github.com> +intelmatt <61025942+intelmatt@users.noreply.github.com> +iohub Ionoclast Laboratories +iron Isaac McFadyen IsaacDynamo <61521674+IsaacDynamo@users.noreply.github.com> +Ishaan Gandhi +iSma +issixx <46835150+issixx@users.noreply.github.com> Ivan Ivan Filipov <159561759+vanaka11@users.noreply.github.com> Ivan Komarov Ivan Stepanov -JC <43374599+MrSMlT@users.noreply.github.com> -JFLFY2255 -JH23X <165871467+JH23X@users.noreply.github.com> +Ivy233 <952254420@qq.com> +ixgbe <1113177880@qq.com> +j-k +jacekpoplawski <67507230+jacekpoplawski@users.noreply.github.com> Jack Mousseau Jack Mousseau JackJollimore <130917767+JackJollimore@users.noreply.github.com> +jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com> Jaeden Amero Jaemin Son Jafar Uruç Jag Chadha +jaime-m-p <167997752+jaime-m-p@users.noreply.github.com> +Jake Karnes +Jakkala Mahesh <155058658+MaheshJakkala@users.noreply.github.com> Jakub N James A Capozzoli <157492257+jac-jim@users.noreply.github.com> James Reynolds +jameswu2014 <545426914@qq.com> Jan Boon Jan Boon Jan Ploski Jannis Schönleber +Jared Tweed Jared Van Bortel Jared Van Bortel +Jaromír Hradílek Jason C.H Jason McCartney +Jason Ni Jason Stillerman +jason_w +Jay +Jay Zenith <162098309+JayZenith@users.noreply.github.com> +JC <43374599+MrSMlT@users.noreply.github.com> +jdomke <28772296+jdomke@users.noreply.github.com> Jean-Christophe Hoelt Jean-Michaël Celerier Jed Fox Jeff Bolz Jeffrey Morgan Jeffrey Quesnelle +Jeremy Demeule +Jeremy Rand <244188+JeremyRand@users.noreply.github.com> Jeroen Mostert +Jesse +Jesse Gross +Jesse Ikonen Jesse Jojo Johnson Jett Janiak Jeximo +JFLFY2255 +JH23X <165871467+JH23X@users.noreply.github.com> Jhen-Jie Hong +Jiacheng (Jason) Chen <76919340+jiachengjason@users.noreply.github.com> Jiahao Li +jiahao su Jian Liao JidongZhang-THU <1119708529@qq.com> +Jie Fu (傅杰) +Jie Fu (傅杰) +jiez <373447296@qq.com> Jinwoo Jeong <33892306+williamjeong2@users.noreply.github.com> Jinyang He Jiří Podivín <66251151+jpodivin@users.noreply.github.com> Jiří Sejkora +JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> +jklincn <985765408@qq.com> +jklincn +jneem Joan Fontanals Joan Fontanals João Dinis Ferreira Joe Eli McIlvain Joe Todd +joecryptotoo <80373433+joecryptotoo@users.noreply.github.com> Johan Johannes Gäßler Johannes Rudolph John <78893154+cmp-nct@users.noreply.github.com> John Balis +John Bean <113509988+johnbean393@users.noreply.github.com> John Smith <67539080+kingsidelee@users.noreply.github.com> JohnnyB +johnson442 <56517414+johnson442@users.noreply.github.com> +jojorne +jon-chuang <9093549+jon-chuang@users.noreply.github.com> Jonas Wunderlich <32615971+jonas-w@users.noreply.github.com> +Jonathan Graehl <99024+graehl@users.noreply.github.com> Jorge A <161275481+jorgealias@users.noreply.github.com> Jose Maldonado <63384398+yukiteruamano@users.noreply.github.com> Joseph Stahl <1269177+josephst@users.noreply.github.com> Josh Ramer +Joshua Cogliati Joyce +jp-x-g Juan Calderon-Perez <835733+gaby@users.noreply.github.com> +Judd <4046440+foldl@users.noreply.github.com> Judd Juk Armstrong <69222624+jukofyork@users.noreply.github.com> +jukofyork <69222624+jukofyork@users.noreply.github.com> +Julien Denize <40604584+juliendenize@users.noreply.github.com> Julius Arkenberg +Julius Tischbein +Julius Tischbein Jun Hee Yoo Jun Jie <71215065+junnjiee16@users.noreply.github.com> +junchao-loongson <68935141+junchao-loongson@users.noreply.github.com> +junchao-zhao <68935141+junchao-loongson@users.noreply.github.com> Junil Kim +Junwon Hwang Junyang Lin Juraj Bednar Justin Parker +Justin Santa Barbara Justin Suess Justina Cho Justine Tunney Justine Tunney Juuso Alasuutari -KASR +Juyoung Suk +jwj7140 <32943891+jwj7140@users.noreply.github.com> +k.h.lai +Kai Pastor +kaizau +kallewoof +kallewoof +kalomaze <66376113+kalomaze@users.noreply.github.com> Kamil Tomšík +kang Kante Yin Karol Kontny <82021046+kkontny@users.noreply.github.com> Karsten Weiss Karthick Karthik Kumar Viswanathan <195178+guilt@users.noreply.github.com> Karthik Sethuraman +KASR Kasumi <90275229+kasumi-1@users.noreply.github.com> +katsu560 <118887472+katsu560@users.noreply.github.com> Kawrakow <48489457+ikawrakow@users.noreply.github.com> +kchro3 <62481661+kchro3@users.noreply.github.com> Keiichi Tabata Keke Han Kenvix ⭐ @@ -417,48 +737,109 @@ Kevin Gibbons Kevin Ji <1146876+kevinji@users.noreply.github.com> Kevin Kwok Kevin Lo +Kevin Pouget Kevin Wang +khimaros +kiltyj +Kim S. +kimminsu <80271594+kimminsu38oo@users.noreply.github.com> +kiwi <122582483+kiwi142857@users.noreply.github.com> +klosax <131523366+klosax@users.noreply.github.com> Kolen Cheung Konstantin Herud Konstantin Zhuravlyov +krystiancha +kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> +kunnis Kunshang Ji +kuronekosaiko +kustaaya <58045274+kustaaya@users.noreply.github.com> +kuvaus <22169537+kuvaus@users.noreply.github.com> +kwin1412 <42286931+kwin1412@users.noreply.github.com> Kyle Bruene Kyle Liang Kyle Mistele Kylin <56434533+KyL0N@users.noreply.github.com> +l-austenfeld <53152202+l-austenfeld@users.noreply.github.com> +l3utterfly +LaffeyNyaa <112215776+LaffeyNyaa@users.noreply.github.com> +laik Lars Grammel +Lars Sonchocky-Helldorf Laura +Law Po Ying <30721578+yingying0906@users.noreply.github.com> +lcy +ldwang +le.chang Lee <44310445+lx200916@users.noreply.github.com> Lee Drake +leejet Leng Yue +Lennart Austenfeld <53152202+l-austenfeld@users.noreply.github.com> +leo-pony Leon Knauer -LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> +Leonard Mosescu Leonardo Neumann +LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> +levkropp +lexasub +lgai-exaone +lhez +lhez +Li Pengzhan <151381994+Lpzhan931@users.noreply.github.com> Li Tan +limitedAtonement Linwei Wang Liu Jia <109258120+Septa2112@users.noreply.github.com> Liu Jia +liuwei-git <14815172+liuwei-git@users.noreply.github.com> +lixing-star <104126818+lixing-star@users.noreply.github.com> +lksj92hs <134250687+lksj92hs@users.noreply.github.com> LoganDark Loïc Carrère +lon <114724657+longregen@users.noreply.github.com> +loonerin <132926317+loonerin@users.noreply.github.com> LostRuins <39025047+LostRuins@users.noreply.github.com> LostRuins Concedo <39025047+LostRuins@users.noreply.github.com> +lovedheart <6277001+lovedheart@users.noreply.github.com> +ltoniazzi <61414566+ltoniazzi@users.noreply.github.com> +Luca Stefani Lucas Moura Belo Luciano +Lukas Straub +Łukasz Ślusarczyk <112692748+lslusarczyk@users.noreply.github.com> Luo Tian +luoyu-intel +luyhcsu <110711054+luyhcsu@users.noreply.github.com> Lyle Dean M-A +M. Mediouni M. Yusuf Sarıgöz +m3ndax Ma Mingfei Maarten ter Huurne Mack Straight +maddes8cht <55592906+maddes8cht@users.noreply.github.com> Maël Kerbiriou MaggotHATE +magicse +Mahekk Shaikh <118063190+Mahekk357@users.noreply.github.com> Mahesh Madhav <67384846+heshpdx@users.noreply.github.com> +mahorozte <41834471+mahorozte@users.noreply.github.com> +makomk +manikbhandari Manuel <44313466+makuche@users.noreply.github.com> +maor-ps <154728172+maor-ps@users.noreply.github.com> Marc Köhlbrugge +Marcello Seri Marco Matthies <71844+marcom@users.noreply.github.com> +Marcos Del Sol Vives +marcoStocchi Marcus Dunn <51931484+MarcusDunn@users.noreply.github.com> +Marek Hradil jr. Marian Cepok +Marius Gerdes <141485318+mglambda@users.noreply.github.com> +Mariusz Woloszyn Mark Fairbairn Mark Zhuang Marko Tasic @@ -467,7 +848,11 @@ Martin Delille Martin Krasser Martin Schwaighofer Marvin Gießing +Masashi Yoshimura +Masato Nakasaka +Masato Nakasaka Masaya, Kato <62578291+msy-kato@users.noreply.github.com> +mashdragon <122402293+mashdragon@users.noreply.github.com> MasterYi1024 <39848311+MasterYi1024@users.noreply.github.com> Mateusz Charytoniuk Matheus C. França @@ -475,57 +860,89 @@ Matheus Gabriel Alves Silva Mathieu Baudier Mathieu Geli Mathieu Nayrolles -Mathijs Henquet Mathijs de Bruin +Mathijs Henquet +matiaslin <45382001+matiaslin@users.noreply.github.com> Matt Clayton <156335168+mattjcly@users.noreply.github.com> Matt Pulver Matt Stephenson +matt23654 <193348153+matt23654@users.noreply.github.com> +matt23654 +matteo +matteo Matteo Boschini <12133566+mbosc@users.noreply.github.com> Matteo Mortari Mattheus Chediak +Matthew Michel Matthew Tejo +Matthieu Coudron <886074+teto@users.noreply.github.com> +Mattt Matvey Soloviev Max Krasnyansky +Max Krasnyansky Max Krasnyansky Maxim Evtush <154841002+maximevtush@users.noreply.github.com> Maxime <672982+maximegmd@users.noreply.github.com> Maximilian Winter +mdrokz +MeeMin <74113151+Meet91721@users.noreply.github.com> Meng Zhang Meng, Hengyu Mengqing Cao Merrick Christensen +mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> +Miaoqian Lin Michael Coppola +Michael de Gans +Michaël de Vries Michael Engel Michael Francis +Michael Giba Michael Hueschen Michael Kesper Michael Klimenko Michael Podvitskiy Michael Potter -Michael de Gans -Michaël de Vries +Michael Wand Michał Moskal Michał Tuszyński Michelle Tan <41475767+MichelleTanPY@users.noreply.github.com> +midnight Mihai Mike +Mike Abbott +Mike Abbott Mikko Juola +Min-Hua <136287195+Min-Hua@users.noreply.github.com> +minarchist Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Minsoo Cheong Mirko185 Mirror Azure <54669636+MirrorAzure@users.noreply.github.com> MistApproach <98988043+MistApproach@users.noreply.github.com> Miwa / Ensan <63481257+ensan-hcl@users.noreply.github.com> +mj-shifu <77107165+mj-shifu@users.noreply.github.com> +mmyjona +mnehete32 <33429707+mnehete32@users.noreply.github.com> Mohammadreza Hendiani Mohammadreza Hendiani Molly Sophia +momonga <115213907+mmnga@users.noreply.github.com> +momonga <146910567+mmngays@users.noreply.github.com> MoonRide303 <130458190+MoonRide303@users.noreply.github.com> MorganRO8 <47795945+MorganRO8@users.noreply.github.com> +moritzbrantner <31051084+moritzbrantner@users.noreply.github.com> +muggle-stack Murilo Santana Musab Gultekin +musoles <135031143+musoles@users.noreply.github.com> +mzcu +Naco Siren Nam D. Tran <42194884+namtranase@users.noreply.github.com> +nanahi <130121847+na-na-hi@users.noreply.github.com> Nathan Epstein Natsu +Nauful Shaikh NawafAlansari <72708095+NawafAlansari@users.noreply.github.com> Nebula Neo Zhang <14088817+arthw@users.noreply.github.com> @@ -533,73 +950,157 @@ Neo Zhang Neo Zhang Jianyu Neuman Vong NeverLucky <92274250+nvrxq@users.noreply.github.com> +Nexes the Elder <124105151+Nexesenex@users.noreply.github.com> Nexes the Old <124105151+Nexesenex@users.noreply.github.com> Nexesenex <124105151+Nexesenex@users.noreply.github.com> +ngc92 <7938269+ngc92@users.noreply.github.com> +nhamanasu <45545786+nhamanasu@users.noreply.github.com> Niall Coates <1349685+Niall-@users.noreply.github.com> +niansa/tuxifan +niansa/tuxifan Nicholai Tukanov +Nick <0x0b4ac@gmail.com> +nick huang +nickp27 Nico Bosshard Nicolai Weitkemper +Nicolas B. Pierron Nicolás Pérez Nicolò Scipione Nigel Bosch +Nikhil Jain Nikita Sarychev <42014488+sARY77@users.noreply.github.com> Niklas Korz NikolaiLyssogor <59844691+NikolaiLyssogor@users.noreply.github.com> Nikolaos Pothitos Nikolas <127742645+nneubacher@users.noreply.github.com> +Nikolay Popov <131475237+npopov-vst@users.noreply.github.com> Nindaleth +ningshanwutuobang +Noah <99681487+NoahOksuz@users.noreply.github.com> +nold +nopperl <54780682+nopperl@users.noreply.github.com> +nullname Nuno -OSecret <135510162+OLSecret@users.noreply.github.com> +nusu-github <29514220+nusu-github@users.noreply.github.com> +nwyin +o7si <32285332+o7si@users.noreply.github.com> Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com> Oleksandr Nikitin Oleksii Maryshchenko +Olexandr88 +olexiyb +Oliver Simons +Oliver Simons +Oliver Walsh Olivier Chafik +Olivier Chafik +omahs <73983677+omahs@users.noreply.github.com> Ondřej Čertík +oobabooga <112222186+oobabooga@users.noreply.github.com> +oobabooga +opparco +Oscar Barenys +OSecret <135510162+OLSecret@users.noreply.github.com> +ostix360 <55257054+ostix360@users.noreply.github.com> Ouadie EL FAROUKI PAB Pablo Duboue +Pádraic Slattery +Pascal Pascal Patry +pascal-lc <49066376+pascal-lc@users.noreply.github.com> Patrice Ferlet Patrick Peng +Patryk Kaminski Paul Tsochantaris Pavel Zloi +Pavels Zaicenkovs Pavol Rusnak Paweł Wodnicki <151604+32bitmicro@users.noreply.github.com> +pculliton Pedro Cuenca +peidaqi +Penglin Cai <1402538448@qq.com> +pengxin99 +Pepijn de Vos +Percy Piper +Perry Naseck <4472083+DaAwesomeP@users.noreply.github.com> +perserk Peter Peter Sugihara +Peter0x44 +petterreinholdtsen Phil H <5756783+phiharri@users.noreply.github.com> Philip Taron +philip-essential <169196560+philip-essential@users.noreply.github.com> Phillip Kravtsov +Phylliida Dev +piDack <104877312+piDack@users.noreply.github.com> Pierre Alexandre SCHEMBRI Pierrick Hymbert Pieter Ouwerkerk +Piotr +Piotr Jasiukajtis +Piotr Kubaj +Piotr Wilkin (ilintar) +pl752 Plamen Minev +pmysl +pockers21 <134406831+pockers21@users.noreply.github.com> +postmasters +Pouya +pqnet <119850+pqnet@users.noreply.github.com> +Prabod +Prajwal B Mehendarkar Prashant Vithule <119530321+Vithulep@users.noreply.github.com> Przemysław Pawełczyk +psocolovsky <50770545+psocolovsky@users.noreply.github.com> +pudepiedj PureJourney +QDelta <60222316+QDelta@users.noreply.github.com> +Qeeweew <68716978+Qeeweew@users.noreply.github.com> Qin Yue Chen <71813199+chenqiny@users.noreply.github.com> +qingfengfenga <41416092+qingfengfenga@users.noreply.github.com> +qingy1337 Qingyou Meng +qouoq Qu Zongfu <43257352+yancaoweidaode@users.noreply.github.com> +Quentin Bramas +qunash +R +R R0CKSTAR R0CKSTAR -RJ Adriaansen +rabidcopy +RachelMantel Radoslav Gerganov Radosław Gryta +Rafal Lewczuk +Rahul Sathe <150351592+rrsathe@users.noreply.github.com> Rahul Vivek Nair <68507071+RahulVivekNair@users.noreply.github.com> +rainred <107027757+gryffindor-rr@users.noreply.github.com> Raj Hammeer Singh Hada Ralph Soika Rand Xie Randall Fitzgerald Random Fly +rankaiyx +Raul Torres <138264735+rauletorresc@users.noreply.github.com> +redbeard +Reese Levine Reinforce-II Rémy O Rémy Oudompheng Ren Xuancheng +Renat Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com> Reza Kakhki Reza Rahemtola <49811529+RezaRahemtola@users.noreply.github.com> RhinoDevel +rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> +rhuddleston +Rhys-T <108157737+Rhys-T@users.noreply.github.com> Riccardo Orlando Riceball LEE Rich Dougherty @@ -611,14 +1112,22 @@ Rickard Edén Rickard Hallerbäck Rickey Bowers Jr Riley Stewart +rimoliga <53384203+rimoliga@users.noreply.github.com> Rinne Rinne +RJ Adriaansen +rmatif <66360289+rmatif@users.noreply.github.com> +rmatif +rmatif Robert Brisita <986796+rbrisita@users.noreply.github.com> Robert Collins Robert Ormandi <52251610+ormandi@users.noreply.github.com> Robert Sung-wook Shin Robey Holderith +Robin Davidsson <40024429+R-Dson@users.noreply.github.com> Robyn +Rőczey Barnabás <31726601+An0nie@users.noreply.github.com> +RodriMora Roger Meier Rohanjames1997 Roland <14355895+rbur0425@users.noreply.github.com> @@ -629,68 +1138,133 @@ Roman Parykin Ron Evans Ron Jailall Roni +Ronny Brendel Ronny Brendel Ronsor +Rotem Dan Rowan Hart +rspOverflow <217881046+rspOverflow@users.noreply.github.com> +rtaluyev Ruan <47767371+ruanych@users.noreply.github.com> +Ruben Ortlam +Ruben Ortlam Ruchira Hasaranga Rudi Servo +Ruikai Peng Ruixin Huang <18860020911@163.com> Rune <43761327+Rune-AI@users.noreply.github.com> +runfuture RunningLeon RunningLeon +Russyyds <161207317+Russyyds@users.noreply.github.com> Ryan Landay +Ryan Mangeno <160974989+ryan-mangeno@users.noreply.github.com> Ryder Wishart Ryuei -Rőczey Barnabás <31726601+An0nie@users.noreply.github.com> -SAMI -SRHMorris <69468379+SRHMorris@users.noreply.github.com> -SXX +s-goto-11 <206795233+s-goto-11@users.noreply.github.com> +s8322 +Saba Fallah <10401143+sfallah@users.noreply.github.com> +Sachin Desai +safranowith SakuraUmi Salvador E. Tropea Salvatore Mesoraca +Sam +Sam Malayek <12037535+SamMalayek@users.noreply.github.com> Sam Spilsbury +Sam/Samuel <57896620+cern1710@users.noreply.github.com> +SAMI Sami Farin <3876865+Safari77@users.noreply.github.com> Samuel Maynard +Sandro Hanea <40202887+sandrohanea@users.noreply.github.com> +sandyiscool Sang-Kil Park +Sascha Rogmann <59577610+srogmann@users.noreply.github.com> +sasha0552 +SavicStefan <50296686+SavicStefan@users.noreply.github.com> +Scott Fudally Seb C <47074056+Sebby37@users.noreply.github.com> Sebastián A SebastianApel <13675545+SebastianApel@users.noreply.github.com> +semidark Senemu <10880819+Senemu@users.noreply.github.com> +senhtry +Sergei Vorobyov +Sergey Alirzaev Sergey Alirzaev +Sergey Fedorov Sergio López Sergio López +serhii-nakon <57632032+serhii-nakon@users.noreply.github.com> Sertaç Özercan <852750+sozercan@users.noreply.github.com> SeungWon Jeong <65549245+redlion0929@users.noreply.github.com> ShadovvBeast +Shagun Bera <141054835+notV3NOM@users.noreply.github.com> Shakhar Dasgupta +Shakil Ahmed <44522075+ahmedshakill@users.noreply.github.com> +shalinib-ibm Shane A Shangning Xu <32517059+xushangning@users.noreply.github.com> +shani-f Shankar Shanshan Shen <467638484@qq.com> +shaofeiqi <109865877+shaofeiqi@users.noreply.github.com> +shaofeiqi +sharpHL <132747147+sharpHL@users.noreply.github.com> +Shawn Gu +Shawn yang <137684499+Yangxiaoz@users.noreply.github.com> Shelby Jenkins <47464908+ShelbyJenkins@users.noreply.github.com> Sheldon Robinson +shibe2 Shijie <821898965@qq.com> +Shin-myoung-serp Shintarou Okada +Shouyu <65317431+joeldushouyu@users.noreply.github.com> Shouzheng Liu <61452103+lshzh-ww@users.noreply.github.com> Shouzheng Liu +SHUAI YANG Shuichi Tsutsumi +shun095 <8069181+shun095@users.noreply.github.com> +Shunta Saito Shupei Fan +Si1w <139008732+Si1w@users.noreply.github.com> Sigbjørn Skjæret +simevo +Simon Redman Simon Willison +simon886212 <37953122+simon886212@users.noreply.github.com> +Simranjeet Singh <105192966+simrnsingh@users.noreply.github.com> +singularity <12184989+singularity-s0@users.noreply.github.com> +sirus20x6 Siwen Yu +sjinzh +sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com> +Sky Sky Yan +slaren <2141330+slaren@users.noreply.github.com> Slaren <2141330+slaren@users.noreply.github.com> +slaren Slava Primenko +Slobodan Josic <127323561+slojosic-amd@users.noreply.github.com> Small Grass Forest +SmartestWashingMachine +SnA1lGo <44647694+skrandy@users.noreply.github.com> +snadampal <87143774+snadampal@users.noreply.github.com> SoftwareRenderer <138734813+SoftwareRenderer@users.noreply.github.com> Someone Someone Serge +someone13574 <81528246+someone13574@users.noreply.github.com> Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Spencer Sutton +SRHMorris <69468379+SRHMorris@users.noreply.github.com> Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> Srinivas Billa +ssweens <1149151+ssweens@users.noreply.github.com> +standby24x7 +staviq +stduhpf Stefan Sydow +Ștefan-Gabriel Muscalu Steffen Röcker Stephan Walter Stephen Nichols @@ -698,46 +1272,100 @@ Steve Bonds Steve Grubb Steven Prichard Steven Roussey +stevenkuang Steward Garcia <57494570+FSSRepo@users.noreply.github.com> StrangeBytesDev <141275258+StrangeBytesDev@users.noreply.github.com> +strawberrymelonpanda <152940198+strawberrymelonpanda@users.noreply.github.com> Suaj Carrot <72162667+SuajCarrot@users.noreply.github.com> +sudhiarm Sukriti Sharma SuperUserNameMan Sutou Kouhei +Svetlozar Georgiev <55534064+sgeor255@users.noreply.github.com> +swittk +SXX Tai Duc Nguyen Taikono-Himazin +Taimur Ahmad +Tak-RS +takasurazeem +takov751 <40316768+takov751@users.noreply.github.com> +takuya kodama +takuya kodama +tamarPal Tameem <113388789+AhmadTameem@users.noreply.github.com> Tamotsu Takahashi +tarcey +Tarek Dakhran +Tarek Dakhran +tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> +Tatsuya Tanaka +Taylor +tc-mb <157115220+tc-mb@users.noreply.github.com> +TecJesh Tei Home +tempstudio <49735574+tempstudio@users.noreply.github.com> +teo +texmex76 <40733439+texmex76@users.noreply.github.com> Thái Hoàng Tâm <75922889+RoyalHeart@users.noreply.github.com> +Thammachart Chinvarapon <1731496+Thammachart@users.noreply.github.com> Thatcher Chamberlin Theia Vogel +thement <40525767+thement@users.noreply.github.com> +theo77186 +theraininsky <76763719+theraininsky@users.noreply.github.com> Thérence <13496987+Royalphax@users.noreply.github.com> +thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Thibault Terrasson +thom-dev-fr <161708450+thom-dev-fr@users.noreply.github.com> +Thomas Germer <99991@users.noreply.github.com> +Thomas Jarosch Thomas Klausner +Thore Koritzius Thorsten Sommer +TianHao324 <854531745@qq.com> +TianHao324 +Tianyue-Zhao Tim Miller +Tim Neumann Tim Wang Timmy Knight Timothy Cronin <40186632+4imothy@users.noreply.github.com> Ting Lou Ting Lou Ting Sun +tjohnman Tobias Lütke +Todor Boinovski Tom C Tom Jobbins <784313+TheBloke@users.noreply.github.com> Tomas Tomáš Pazdiora Tony Wasserka <4840017+neobrain@users.noreply.github.com> +toyer <2042519524@qq.com> +TrevorS +triplenom <79777178+triplenom@users.noreply.github.com> Tristan Druyen Tristan Ross Trivikram Kamat <16024985+trivikr@users.noreply.github.com> +tslmy +tt <291400568@qq.com> Tungsten842 <886724vf@anonaddy.me> Tungsten842 Tushar +tv1wnd <55383215+tv1wnd@users.noreply.github.com> +ubergarm +ubik2 UEXTM.com <84163508+uextm@users.noreply.github.com> +Uilian Ries +uint256_t +uint256_t Ujjawal Panchal <31011628+Ujjawal-K-Panchal@users.noreply.github.com> Ulrich Drepper +unbounded +uvos +uvos +uvos Uzo Nweke Vaibhav Srivastav Val Kharitonov @@ -745,10 +1373,22 @@ Valentin Konovalov Valentin Mamedov <45292985+Inf1delis@users.noreply.github.com> Valentyn Bezshapkin <61702053+valentynbez@users.noreply.github.com> Vali Malinoiu <0x4139@gmail.com> +valiray <133289098+valiray@users.noreply.github.com> +vb +Vedran Miletić +Victor <194116445+dodekapod@users.noreply.github.com> Victor Nogueira Victor Z. Peng Viet-Anh NGUYEN (Andrew) +vik +Ville Vesilehto +Vineel Abhinav <131174187+vineelabhinav@users.noreply.github.com> Vinesh Janarthanan <36610342+VJHack@users.noreply.github.com> +Vinkal +virajwad <84867530+virajwad@users.noreply.github.com> +viric +Vishal Agarwal +Vishal Singh Vitali Lovich Vivian Vlad @@ -756,351 +1396,124 @@ Vladimir Vladimir Malyutin Vladimir Vuksanovic <109677816+vvuksanovic@users.noreply.github.com> Vladimir Zorin +Vladislav Sayapin <70110788+v-sayapin@users.noreply.github.com> +vmobilis <75476228+vmobilis@users.noreply.github.com> +vodkaslime <646329483@qq.com> VoidIsVoid <343750470@qq.com> Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com> +vvhg1 <94630311+vvhg1@users.noreply.github.com> +vxiiduu <73044267+vxiiduu@users.noreply.github.com> Wagner Bruna Wang Qin <37098874+wangqin0@users.noreply.github.com> Wang Ran (汪然) +Wang Weixuan WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com> +wangshuai09 <391746016@qq.com> +wbpxre150 <100937007+wbpxre150@users.noreply.github.com> +wbtek <171302111+wbtek@users.noreply.github.com> Weird Constructor Weizhao Ouyang +Weizhao Ouyang Welby Seely +welix Wentai Zhang +whoreson <139810751+whoreson@users.noreply.github.com> Wilken Gottwalt <12194808+wgottwalt@users.noreply.github.com> WillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com> +william pan <61359596+wp4032@users.noreply.github.com> William Tambellini William Tambellini Willy Tarreau +woachk <24752637+woachk@users.noreply.github.com> +wonjun Jang +woodx <124784234+woodx9@users.noreply.github.com> Woof Dog <197125663+woof-dog@users.noreply.github.com> +wooksong Wouter <9594229+DifferentialityDevelopment@users.noreply.github.com> +Wroclaw +wsbagnsv1 Wu Jian Ping Wu Jian Ping +wwoodsTM <104587230+wwoodsTM@users.noreply.github.com> +wzy <32936898+Freed-Wu@users.noreply.github.com> +xaedes +xaedes +xctan +xctan Xiake Sun Xiang (Kevin) Li +Xiangyan Sun Xiao-Yong Jin +xiaobing318 <71554036+xiaobing318@users.noreply.github.com> +xiaofei XiaotaoChen Xiaoyi Chen Xie Yanbo Xingchen Song(宋星辰) +Xinpeng Dou <15529241576@163.com> Xinpeng Dou <81913537+Dou-Git@users.noreply.github.com> +xloem <0xloem@gmail.com> Xuan Son Nguyen +Xuan-Son Nguyen Xuan-Son Nguyen +yael-works <106673277+yael-works@users.noreply.github.com> +YaelGitAccount <38328157276@mby.co.il> +YaelLogic Yaiko +YangLe +yangli2 Yann Follet <131855179+YannFollet@users.noreply.github.com> Yaroslav +Yavor Ivanov Yazan Agha-Schrader +Ycros <18012+ycros@users.noreply.github.com> +YehuditE +Yibo Cai +Yibo Cai +yifant-code Yiming Cui Yishuo Wang +ymcki <84055651+ymcki@users.noreply.github.com> Yoshi Suhara Yoshi Suhara +Yoshi_likes_e4 <104140648+pt13762104@users.noreply.github.com> Younes Belkada <49240599+younesbelkada@users.noreply.github.com> +Yuanhao Ji +Yuannan Yueh-Po Peng <94939112+y10ab1@users.noreply.github.com> Yüg Yui +Yuichiro Utsumi <81412151+utsumi-fj@users.noreply.github.com> +yuiseki +yulo <77381088+zhang-hui-yulo@users.noreply.github.com> +yumeyao +yummy <57988893+jk3456a@users.noreply.github.com> Yun Dou Yuri Khrustalev +yuri@FreeBSD Yusuf Kağan Hanoğlu Yuval Peled <31162840+Yuval-Peled@users.noreply.github.com> -ZHAOKAI WANG +Yuxuan Zhang <2448370773@qq.com> +Z +Zagaj +zakkor Zane Shannon Zay <95888118+isaiahbjork@users.noreply.github.com> Zenix Zhang Peiyuan +zhangkaihuo +ZHAOKAI WANG Zheng.Deng <32841220+dengzheng-cloud@users.noreply.github.com> +zhentaoyu Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com> +Zheyuan Chen +Zhiyong Wang <85110830+ravenouse@users.noreply.github.com> Zhiyuan Li Zhiyuan Li +zhouwg <6889919+zhouwg@users.noreply.github.com> +zhouwg ZhouYuChen Ziad Ben Hadj-Alouane Ziang Wu <97337387+ZiangWu-77@users.noreply.github.com> -Zsapi -a-n-n-a-l-e-e <150648636+a-n-n-a-l-e-e@users.noreply.github.com> -a3sh <38979186+A3shTnT@users.noreply.github.com> -adel boussaken -afrideva <95653597+afrideva@users.noreply.github.com> -ag2s20150909 <19373730+ag2s20150909@users.noreply.github.com> -agray3 -akawrykow <142945436+akawrykow@users.noreply.github.com> -alek3y <44779186+alek3y@users.noreply.github.com> -alexpinel <93524949+alexpinel@users.noreply.github.com> -alonfaraj -alwqx -amd-dwang -amd-lalithnc -amritahs-ibm -andrijdavid -anon998 <131767832+anon998@users.noreply.github.com> -anzz1 -apaz -apcameron <37645737+apcameron@users.noreply.github.com> -arch-btw <57669023+arch-btw@users.noreply.github.com> -arcrank -ardfork <134447697+ardfork@users.noreply.github.com> -arlo-phoenix <140345165+arlo-phoenix@users.noreply.github.com> -aryantandon01 <80969509+aryantandon01@users.noreply.github.com> -at8u <129688334+at8u@users.noreply.github.com> -automaticcat -awatuna <23447591+awatuna@users.noreply.github.com> -b4b4o -bandoti <141645996+bandoti@users.noreply.github.com> -beiller -bhubbb <79117352+bhubbb@users.noreply.github.com> -bmwl -bobqianic <129547291+bobqianic@users.noreply.github.com> -brucepro -bryanSwk <93190252+bryanSwk@users.noreply.github.com> -bsilvereagle -bssrdf -byte-6174 <88070277+byte-6174@users.noreply.github.com> -cduk <19917266+cduk@users.noreply.github.com> -cebtenzzre -chaihahaha -chiranko <96988916+chiranko@users.noreply.github.com> -clibdev <52199778+clibdev@users.noreply.github.com> -clyang -cmdr2 -cmdr2 -cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com> -codezjx -coezbek -comex -compilade <113953597+compilade@users.noreply.github.com> -compilade -cpumaxx <163466046+cpumaxx@users.noreply.github.com> -crasm -crasm -daboe01 -daghanerdonmez <44506702+daghanerdonmez@users.noreply.github.com> -daminho <37615795+daminho@users.noreply.github.com> -david raistrick -ddh0 -ddpasa <112642920+ddpasa@users.noreply.github.com> -deepdiffuser <112834445+deepdiffuser@users.noreply.github.com> -devojony <61173062+devojony@users.noreply.github.com> -ditsuke -divinity76 -dm4 -dm4 -dotpy314 <33351922+dotpy314@users.noreply.github.com> -drbh -ds5t5 <145942675+ds5t5@users.noreply.github.com> -dylan -eastriver -ebraminio -ebraminio -eiery <19350831+eiery@users.noreply.github.com> -eric8607242 -fairydreaming <166155368+fairydreaming@users.noreply.github.com> -fengerhu1 <2748250768@qq.com> -fj-y-saito <85871716+fj-y-saito@users.noreply.github.com> -fraxy-v <65565042+fraxy-v@users.noreply.github.com> -fxzjshm <11426482+fxzjshm@users.noreply.github.com> -github-actions[bot] -gliptic -gn64 -goerch -grahameth <96447521+grahameth@users.noreply.github.com> -gtygo -gwjr <502526+gwjr@users.noreply.github.com> -h-h-h-h <13482553+h-h-h-h@users.noreply.github.com> -hankcs -haopeng <657407891@qq.com> -hipudding -hoangmit -hongbo.mo <352280764@qq.com> -hopkins385 <98618192+hopkins385@users.noreply.github.com> -howlger -howlger -hutli <6594598+hutli@users.noreply.github.com> -hutli -hutli -hxer7963 -hydai -iSma -iacore <74560659+iacore@users.noreply.github.com> -icppWorld <124377669+icppWorld@users.noreply.github.com> -igardev <49397134+igardev@users.noreply.github.com> -igarnier -intelmatt <61025942+intelmatt@users.noreply.github.com> -iohub -issixx <46835150+issixx@users.noreply.github.com> -jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com> -jaime-m-p <167997752+jaime-m-p@users.noreply.github.com> -jameswu2014 <545426914@qq.com> -jason_w -jdomke <28772296+jdomke@users.noreply.github.com> -jiahao su -jiez <373447296@qq.com> -jneem -joecryptotoo <80373433+joecryptotoo@users.noreply.github.com> -johnson442 <56517414+johnson442@users.noreply.github.com> -jojorne -jon-chuang <9093549+jon-chuang@users.noreply.github.com> -jp-x-g -jukofyork <69222624+jukofyork@users.noreply.github.com> -junchao-loongson <68935141+junchao-loongson@users.noreply.github.com> -junchao-zhao <68935141+junchao-loongson@users.noreply.github.com> -jwj7140 <32943891+jwj7140@users.noreply.github.com> -k.h.lai -kaizau -kallewoof -kalomaze <66376113+kalomaze@users.noreply.github.com> -kang -katsu560 <118887472+katsu560@users.noreply.github.com> -kchro3 <62481661+kchro3@users.noreply.github.com> -khimaros -kiltyj -klosax <131523366+klosax@users.noreply.github.com> -krystiancha -kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> -kunnis -kuronekosaiko -kustaaya <58045274+kustaaya@users.noreply.github.com> -kuvaus <22169537+kuvaus@users.noreply.github.com> -kwin1412 <42286931+kwin1412@users.noreply.github.com> -l3utterfly -laik -ldwang -le.chang -leejet -leo-pony -lexasub -lhez -limitedAtonement -liuwei-git <14815172+liuwei-git@users.noreply.github.com> -lon <114724657+longregen@users.noreply.github.com> -loonerin <132926317+loonerin@users.noreply.github.com> -ltoniazzi <61414566+ltoniazzi@users.noreply.github.com> -luoyu-intel -m3ndax -maddes8cht <55592906+maddes8cht@users.noreply.github.com> -magicse -mahorozte <41834471+mahorozte@users.noreply.github.com> -makomk -manikbhandari -maor-ps <154728172+maor-ps@users.noreply.github.com> -mashdragon <122402293+mashdragon@users.noreply.github.com> -matiaslin <45382001+matiaslin@users.noreply.github.com> -matt23654 -matteo -mdrokz -mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> -midnight -minarchist -mj-shifu <77107165+mj-shifu@users.noreply.github.com> -mmyjona -momonga <115213907+mmnga@users.noreply.github.com> -momonga <146910567+mmngays@users.noreply.github.com> -moritzbrantner <31051084+moritzbrantner@users.noreply.github.com> -musoles <135031143+musoles@users.noreply.github.com> -mzcu -nanahi <130121847+na-na-hi@users.noreply.github.com> -ngc92 <7938269+ngc92@users.noreply.github.com> -nhamanasu <45545786+nhamanasu@users.noreply.github.com> -niansa/tuxifan -niansa/tuxifan -nickp27 -ningshanwutuobang -nold -nopperl <54780682+nopperl@users.noreply.github.com> -nusu-github <29514220+nusu-github@users.noreply.github.com> -olexiyb -omahs <73983677+omahs@users.noreply.github.com> -oobabooga <112222186+oobabooga@users.noreply.github.com> -opparco -ostix360 <55257054+ostix360@users.noreply.github.com> -pascal-lc <49066376+pascal-lc@users.noreply.github.com> -pculliton -peidaqi -pengxin99 -perserk -petterreinholdtsen -piDack <104877312+piDack@users.noreply.github.com> -pmysl -postmasters -pudepiedj -qingfengfenga <41416092+qingfengfenga@users.noreply.github.com> -qingy1337 -qouoq -qunash -rabidcopy -rankaiyx -redbeard -rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> -rhuddleston -rimoliga <53384203+rimoliga@users.noreply.github.com> -runfuture -sandyiscool -sasha0552 -semidark -serhii-nakon <57632032+serhii-nakon@users.noreply.github.com> -sharpHL <132747147+sharpHL@users.noreply.github.com> -shibe2 -simon886212 <37953122+simon886212@users.noreply.github.com> -singularity <12184989+singularity-s0@users.noreply.github.com> -sjinzh -sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com> -slaren <2141330+slaren@users.noreply.github.com> -slaren -snadampal <87143774+snadampal@users.noreply.github.com> -someone13574 <81528246+someone13574@users.noreply.github.com> -standby24x7 -staviq -stduhpf -strawberrymelonpanda <152940198+strawberrymelonpanda@users.noreply.github.com> -swittk -takov751 <40316768+takov751@users.noreply.github.com> -tarcey -tc-mb <157115220+tc-mb@users.noreply.github.com> -texmex76 <40733439+texmex76@users.noreply.github.com> -thement <40525767+thement@users.noreply.github.com> -theraininsky <76763719+theraininsky@users.noreply.github.com> -thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> -tjohnman -toyer <2042519524@qq.com> -tslmy -tv1wnd <55383215+tv1wnd@users.noreply.github.com> -ubik2 -uint256_t -uint256_t -unbounded -uvos -uvos -valiray <133289098+valiray@users.noreply.github.com> -vb -vik -viric -vmobilis <75476228+vmobilis@users.noreply.github.com> -vodkaslime <646329483@qq.com> -vvhg1 <94630311+vvhg1@users.noreply.github.com> -vxiiduu <73044267+vxiiduu@users.noreply.github.com> -wangshuai09 <391746016@qq.com> -wbpxre150 <100937007+wbpxre150@users.noreply.github.com> -whoreson <139810751+whoreson@users.noreply.github.com> -woachk <24752637+woachk@users.noreply.github.com> -wonjun Jang -woodx <124784234+woodx9@users.noreply.github.com> -wwoodsTM <104587230+wwoodsTM@users.noreply.github.com> -wzy <32936898+Freed-Wu@users.noreply.github.com> -xaedes -xaedes -xctan -xiaobing318 <71554036+xiaobing318@users.noreply.github.com> -xiaofei -xloem <0xloem@gmail.com> -yangli2 -ymcki <84055651+ymcki@users.noreply.github.com> -yuiseki -yuri@FreeBSD -zakkor -zhangkaihuo -zhentaoyu -zhouwg <6889919+zhouwg@users.noreply.github.com> -zhouwg zrm -Ștefan-Gabriel Muscalu -杨朱 · Kiki -源文雨 <41315874+fumiama@users.noreply.github.com> -蕭澧邦 <45505768+shou692199@users.noreply.github.com> -谢乃闻 -Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com> +Zsapi diff --git a/CMakeLists.txt b/CMakeLists.txt index d24fa080ae2..6d4ed67020d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -164,29 +164,6 @@ llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL) llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16) llama_option_depr(WARNING LLAMA_CANN GGML_CANN) -if (NOT MSVC) - if (LLAMA_SANITIZE_THREAD) - message(STATUS "Using -fsanitize=thread") - - add_compile_options(-fsanitize=thread) - link_libraries (-fsanitize=thread) - endif() - - if (LLAMA_SANITIZE_ADDRESS) - message(STATUS "Using -fsanitize=address") - - add_compile_options(-fsanitize=address -fno-omit-frame-pointer) - link_libraries (-fsanitize=address) - endif() - - if (LLAMA_SANITIZE_UNDEFINED) - message(STATUS "Using -fsanitize=undefined") - - add_compile_options(-fsanitize=undefined) - link_libraries (-fsanitize=undefined) - endif() -endif() - include("cmake/license.cmake") license_add_file("llama.cpp" "LICENSE") diff --git a/CODEOWNERS b/CODEOWNERS index 55f5011dfa2..9d252c9b8dc 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -18,6 +18,7 @@ /common/jinja/ @ngxson @CISC @aldehir /common/llguidance.* @ggerganov /common/log.* @ggerganov +/common/ngram-map.* @srogmann /common/peg-parser.* @aldehir /common/sampling.* @ggerganov /common/speculative.* @ggerganov @@ -26,6 +27,7 @@ /examples/batched.swift/ @ggerganov /examples/batched/ @ggerganov /examples/convert-llama2c-to-ggml/ @ggerganov +/examples/debug/ @danbev @pwilkin /examples/deprecation-warning/ @ggerganov /examples/diffusion/ @am17an /examples/embedding/ @ggerganov @@ -67,6 +69,7 @@ /ggml/src/ggml-rpc/ @rgerganov /ggml/src/ggml-threading.* @ggerganov /ggml/src/ggml-vulkan/ @0cc4m +/ggml/src/ggml-virtgpu/ @kpouget /ggml/src/ggml-webgpu/ @reeselevine /ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM /ggml/src/ggml.c @ggerganov diff --git a/LICENSE b/LICENSE index acb96ce78e0..e7dca554bcb 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023-2024 The ggml authors +Copyright (c) 2023-2026 The ggml authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 91a8f25d1c9..dac020ad377 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) - [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat) - [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a) +- [x] [RWKV-7](https://huggingface.co/collections/shoumenchougou/rwkv7-gxx-gguf) - [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM) - [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1) - [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct) @@ -212,6 +213,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [llama.vim](https://github.com/ggml-org/llama.vim) (MIT) - [LARS](https://github.com/abgulati/LARS) (AGPL) - [Llama Assistant](https://github.com/vietanhdev/llama-assistant) (GPL) +- [LlamaLib](https://github.com/undreamai/LlamaLib) (Apache-2.0) - [LLMFarm](https://github.com/guinmoon/LLMFarm?tab=readme-ov-file) (MIT) - [LLMUnity](https://github.com/undreamai/LLMUnity) (MIT) - [LMStudio](https://lmstudio.ai/) (proprietary) diff --git a/benches/dgx-spark/dgx-spark.md b/benches/dgx-spark/dgx-spark.md index ec6c20d8a05..fd5c4e2c788 100644 --- a/benches/dgx-spark/dgx-spark.md +++ b/benches/dgx-spark/dgx-spark.md @@ -8,7 +8,7 @@ g++ --version g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0 nvidia-smi -Sun Nov 2 10:43:25 2025 +Thu Feb 5 13:49:40 2026 +-----------------------------------------------------------------------------------------+ | NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 | +-----------------------------------------+------------------------+----------------------+ @@ -17,7 +17,7 @@ Sun Nov 2 10:43:25 2025 | | | MIG M. | |=========================================+========================+======================| | 0 NVIDIA GB10 On | 0000000F:01:00.0 Off | N/A | -| N/A 35C P8 4W / N/A | Not Supported | 0% Default | +| N/A 47C P0 13W / N/A | Not Supported | 0% Default | | | | N/A | +-----------------------------------------+------------------------+----------------------+ ``` @@ -29,46 +29,46 @@ Model: https://huggingface.co/ggml-org/gpt-oss-20b-GGUF - `llama-batched-bench` -main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 | PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | |-------|--------|------|--------|----------|----------|----------|----------|----------|----------| -| 512 | 32 | 1 | 544 | 0.374 | 1369.01 | 0.383 | 83.64 | 0.757 | 719.01 | -| 512 | 32 | 2 | 1088 | 0.274 | 3741.35 | 0.659 | 97.14 | 0.933 | 1166.66 | -| 512 | 32 | 4 | 2176 | 0.526 | 3896.47 | 0.817 | 156.73 | 1.342 | 1621.08 | -| 512 | 32 | 8 | 4352 | 1.044 | 3925.10 | 0.987 | 259.44 | 2.030 | 2143.56 | -| 512 | 32 | 16 | 8704 | 2.076 | 3945.84 | 1.248 | 410.32 | 3.324 | 2618.60 | -| 512 | 32 | 32 | 17408 | 4.170 | 3929.28 | 1.630 | 628.40 | 5.799 | 3001.76 | -| 4096 | 32 | 1 | 4128 | 1.083 | 3782.66 | 0.394 | 81.21 | 1.477 | 2795.13 | -| 4096 | 32 | 2 | 8256 | 2.166 | 3782.72 | 0.725 | 88.28 | 2.891 | 2856.14 | -| 4096 | 32 | 4 | 16512 | 4.333 | 3780.88 | 0.896 | 142.82 | 5.230 | 3157.38 | -| 4096 | 32 | 8 | 33024 | 8.618 | 3802.14 | 1.155 | 221.69 | 9.773 | 3379.08 | -| 4096 | 32 | 16 | 66048 | 17.330 | 3781.73 | 1.598 | 320.34 | 18.928 | 3489.45 | -| 4096 | 32 | 32 | 132096 | 34.671 | 3780.48 | 2.336 | 438.35 | 37.007 | 3569.51 | -| 8192 | 32 | 1 | 8224 | 2.233 | 3668.56 | 0.438 | 72.98 | 2.671 | 3078.44 | -| 8192 | 32 | 2 | 16448 | 4.425 | 3702.95 | 0.756 | 84.66 | 5.181 | 3174.95 | -| 8192 | 32 | 4 | 32896 | 8.859 | 3698.64 | 0.967 | 132.38 | 9.826 | 3347.72 | -| 8192 | 32 | 8 | 65792 | 17.714 | 3699.57 | 1.277 | 200.52 | 18.991 | 3464.35 | -| 8192 | 32 | 16 | 131584 | 35.494 | 3692.84 | 1.841 | 278.12 | 37.335 | 3524.46 | -| 8192 | 32 | 32 | 263168 | 70.949 | 3694.82 | 2.798 | 365.99 | 73.747 | 3568.53 | +| 512 | 32 | 1 | 544 | 0.270 | 1895.57 | 0.399 | 80.13 | 0.669 | 812.60 | +| 512 | 32 | 2 | 1088 | 0.230 | 4451.23 | 0.583 | 109.71 | 0.813 | 1337.56 | +| 512 | 32 | 4 | 2176 | 0.437 | 4688.87 | 0.820 | 156.03 | 1.257 | 1730.91 | +| 512 | 32 | 8 | 4352 | 0.863 | 4744.23 | 0.942 | 271.79 | 1.805 | 2410.73 | +| 512 | 32 | 16 | 8704 | 1.725 | 4748.19 | 1.173 | 436.38 | 2.899 | 3002.85 | +| 512 | 32 | 32 | 17408 | 3.437 | 4767.38 | 1.503 | 681.49 | 4.939 | 3524.40 | +| 4096 | 32 | 1 | 4128 | 0.907 | 4513.91 | 0.407 | 78.54 | 1.315 | 3139.56 | +| 4096 | 32 | 2 | 8256 | 1.796 | 4560.42 | 0.625 | 102.37 | 2.422 | 3409.45 | +| 4096 | 32 | 4 | 16512 | 3.596 | 4555.66 | 0.888 | 144.11 | 4.485 | 3681.93 | +| 4096 | 32 | 8 | 33024 | 7.184 | 4561.44 | 1.098 | 233.11 | 8.282 | 3987.51 | +| 4096 | 32 | 16 | 66048 | 14.369 | 4560.82 | 1.503 | 340.74 | 15.872 | 4161.30 | +| 4096 | 32 | 32 | 132096 | 28.760 | 4557.52 | 2.162 | 473.59 | 30.922 | 4271.95 | +| 8192 | 32 | 1 | 8224 | 1.859 | 4405.59 | 0.430 | 74.36 | 2.290 | 3591.61 | +| 8192 | 32 | 2 | 16448 | 3.698 | 4430.02 | 0.656 | 97.59 | 4.354 | 3777.47 | +| 8192 | 32 | 4 | 32896 | 7.403 | 4426.10 | 0.957 | 133.82 | 8.360 | 3934.97 | +| 8192 | 32 | 8 | 65792 | 14.802 | 4427.63 | 1.222 | 209.44 | 16.024 | 4105.87 | +| 8192 | 32 | 16 | 131584 | 29.596 | 4428.67 | 1.741 | 294.13 | 31.337 | 4199.00 | +| 8192 | 32 | 32 | 263168 | 59.169 | 4430.42 | 2.619 | 390.92 | 61.789 | 4259.17 | - `llama-bench` -| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s | -| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 3714.25 ± 20.36 | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 86.58 ± 0.43 | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 3445.17 ± 17.85 | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 81.72 ± 0.53 | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 3218.78 ± 11.34 | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 74.86 ± 0.64 | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 2732.83 ± 7.17 | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 71.57 ± 0.51 | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 2119.75 ± 12.81 | -| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 62.33 ± 0.24 | - -build: eeee367de (6989) +| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 | 4505.82 ± 12.90 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 | 83.43 ± 0.59 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d4096 | 4158.34 ± 18.84 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d4096 | 79.22 ± 0.60 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d8192 | 3993.81 ± 17.55 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d8192 | 75.22 ± 1.05 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d16384 | 3449.98 ± 12.13 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d16384 | 70.36 ± 0.37 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d32768 | 2689.42 ± 18.89 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d32768 | 61.65 ± 0.30 | + +build: 11fb327bf (7941) ## ggml-org/gpt-oss-120b-GGUF @@ -77,46 +77,46 @@ Model: https://huggingface.co/ggml-org/gpt-oss-120b-GGUF - `llama-batched-bench` -main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 | PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | |-------|--------|------|--------|----------|----------|----------|----------|----------|----------| -| 512 | 32 | 1 | 544 | 0.571 | 897.18 | 0.543 | 58.96 | 1.113 | 488.60 | -| 512 | 32 | 2 | 1088 | 0.593 | 1725.37 | 1.041 | 61.45 | 1.635 | 665.48 | -| 512 | 32 | 4 | 2176 | 1.043 | 1963.15 | 1.334 | 95.95 | 2.377 | 915.36 | -| 512 | 32 | 8 | 4352 | 2.099 | 1951.63 | 1.717 | 149.07 | 3.816 | 1140.45 | -| 512 | 32 | 16 | 8704 | 4.207 | 1947.12 | 2.311 | 221.56 | 6.518 | 1335.35 | -| 512 | 32 | 32 | 17408 | 8.422 | 1945.36 | 3.298 | 310.46 | 11.720 | 1485.27 | -| 4096 | 32 | 1 | 4128 | 2.138 | 1915.88 | 0.571 | 56.09 | 2.708 | 1524.12 | -| 4096 | 32 | 2 | 8256 | 4.266 | 1920.25 | 1.137 | 56.27 | 5.404 | 1527.90 | -| 4096 | 32 | 4 | 16512 | 8.564 | 1913.02 | 1.471 | 86.99 | 10.036 | 1645.29 | -| 4096 | 32 | 8 | 33024 | 17.092 | 1917.19 | 1.979 | 129.33 | 19.071 | 1731.63 | -| 4096 | 32 | 16 | 66048 | 34.211 | 1915.65 | 2.850 | 179.66 | 37.061 | 1782.15 | -| 4096 | 32 | 32 | 132096 | 68.394 | 1916.44 | 4.381 | 233.72 | 72.775 | 1815.13 | -| 8192 | 32 | 1 | 8224 | 4.349 | 1883.45 | 0.620 | 51.65 | 4.969 | 1655.04 | -| 8192 | 32 | 2 | 16448 | 8.674 | 1888.83 | 1.178 | 54.33 | 9.852 | 1669.48 | -| 8192 | 32 | 4 | 32896 | 17.351 | 1888.55 | 1.580 | 81.01 | 18.931 | 1737.68 | -| 8192 | 32 | 8 | 65792 | 34.743 | 1886.31 | 2.173 | 117.80 | 36.916 | 1782.20 | -| 8192 | 32 | 16 | 131584 | 69.413 | 1888.29 | 3.297 | 155.28 | 72.710 | 1809.70 | -| 8192 | 32 | 32 | 263168 | 138.903 | 1887.24 | 5.004 | 204.63 | 143.907 | 1828.73 | +| 512 | 32 | 1 | 544 | 0.445 | 1151.80 | 0.560 | 57.14 | 1.005 | 541.53 | +| 512 | 32 | 2 | 1088 | 0.472 | 2169.85 | 0.874 | 73.27 | 1.345 | 808.65 | +| 512 | 32 | 4 | 2176 | 0.826 | 2480.33 | 1.299 | 98.51 | 2.125 | 1023.94 | +| 512 | 32 | 8 | 4352 | 1.644 | 2491.67 | 1.608 | 159.18 | 3.252 | 1338.20 | +| 512 | 32 | 16 | 8704 | 3.292 | 2488.35 | 2.117 | 241.85 | 5.409 | 1609.13 | +| 512 | 32 | 32 | 17408 | 6.604 | 2481.07 | 2.898 | 353.31 | 9.502 | 1832.04 | +| 4096 | 32 | 1 | 4128 | 1.698 | 2412.65 | 0.580 | 55.21 | 2.277 | 1812.66 | +| 4096 | 32 | 2 | 8256 | 3.399 | 2409.88 | 0.934 | 68.53 | 4.333 | 1905.27 | +| 4096 | 32 | 4 | 16512 | 6.823 | 2401.21 | 1.411 | 90.72 | 8.234 | 2005.30 | +| 4096 | 32 | 8 | 33024 | 13.574 | 2413.97 | 1.841 | 139.07 | 15.415 | 2142.31 | +| 4096 | 32 | 16 | 66048 | 27.176 | 2411.52 | 2.609 | 196.26 | 29.785 | 2217.49 | +| 4096 | 32 | 32 | 132096 | 54.359 | 2411.23 | 3.905 | 262.20 | 58.264 | 2267.19 | +| 8192 | 32 | 1 | 8224 | 3.491 | 2346.81 | 0.613 | 52.23 | 4.103 | 2004.21 | +| 8192 | 32 | 2 | 16448 | 6.939 | 2361.03 | 0.981 | 65.21 | 7.921 | 2076.56 | +| 8192 | 32 | 4 | 32896 | 13.888 | 2359.40 | 1.511 | 84.71 | 15.399 | 2136.21 | +| 8192 | 32 | 8 | 65792 | 27.756 | 2361.18 | 2.034 | 125.86 | 29.790 | 2208.56 | +| 8192 | 32 | 16 | 131584 | 55.554 | 2359.34 | 3.021 | 169.49 | 58.575 | 2246.41 | +| 8192 | 32 | 32 | 263168 | 111.036 | 2360.89 | 4.537 | 225.72 | 115.573 | 2277.08 | - `llama-bench` -| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s | -| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 1919.36 ± 5.01 | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 60.40 ± 0.30 | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 1825.30 ± 6.37 | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 56.94 ± 0.29 | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 1739.19 ± 6.00 | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 52.51 ± 0.42 | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1536.75 ± 4.27 | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 49.33 ± 0.27 | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1255.85 ± 3.26 | -| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 42.99 ± 0.18 | - -build: eeee367de (6989) +| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 | 2443.91 ± 7.47 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 | 58.72 ± 0.20 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d4096 | 2309.84 ± 3.63 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d4096 | 55.67 ± 0.35 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d8192 | 2216.68 ± 10.16 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d8192 | 52.87 ± 0.43 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d16384 | 1956.31 ± 6.39 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d16384 | 49.45 ± 0.20 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d32768 | 1567.08 ± 11.79 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d32768 | 42.76 ± 0.14 | + +build: 11fb327bf (7941) ## ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF @@ -125,46 +125,46 @@ Model: https://huggingface.co/ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF - `llama-batched-bench` -main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 | PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | |-------|--------|------|--------|----------|----------|----------|----------|----------|----------| -| 512 | 32 | 1 | 544 | 0.398 | 1285.90 | 0.530 | 60.41 | 0.928 | 586.27 | -| 512 | 32 | 2 | 1088 | 0.386 | 2651.65 | 0.948 | 67.50 | 1.334 | 815.38 | -| 512 | 32 | 4 | 2176 | 0.666 | 3076.37 | 1.209 | 105.87 | 1.875 | 1160.71 | -| 512 | 32 | 8 | 4352 | 1.325 | 3091.39 | 1.610 | 158.98 | 2.935 | 1482.65 | -| 512 | 32 | 16 | 8704 | 2.664 | 3075.58 | 2.150 | 238.19 | 4.813 | 1808.39 | -| 512 | 32 | 32 | 17408 | 5.336 | 3070.31 | 2.904 | 352.59 | 8.240 | 2112.50 | -| 4096 | 32 | 1 | 4128 | 1.444 | 2836.81 | 0.581 | 55.09 | 2.025 | 2038.81 | -| 4096 | 32 | 2 | 8256 | 2.872 | 2852.14 | 1.084 | 59.06 | 3.956 | 2086.99 | -| 4096 | 32 | 4 | 16512 | 5.744 | 2852.32 | 1.440 | 88.90 | 7.184 | 2298.47 | -| 4096 | 32 | 8 | 33024 | 11.463 | 2858.68 | 2.068 | 123.78 | 13.531 | 2440.65 | -| 4096 | 32 | 16 | 66048 | 22.915 | 2859.95 | 3.018 | 169.67 | 25.933 | 2546.90 | -| 4096 | 32 | 32 | 132096 | 45.956 | 2852.10 | 4.609 | 222.18 | 50.565 | 2612.39 | -| 8192 | 32 | 1 | 8224 | 3.063 | 2674.72 | 0.693 | 46.20 | 3.755 | 2189.92 | -| 8192 | 32 | 2 | 16448 | 6.109 | 2681.87 | 1.214 | 52.71 | 7.323 | 2245.98 | -| 8192 | 32 | 4 | 32896 | 12.197 | 2686.63 | 1.682 | 76.11 | 13.878 | 2370.30 | -| 8192 | 32 | 8 | 65792 | 24.409 | 2684.94 | 2.556 | 100.17 | 26.965 | 2439.95 | -| 8192 | 32 | 16 | 131584 | 48.753 | 2688.50 | 3.994 | 128.20 | 52.747 | 2494.64 | -| 8192 | 32 | 32 | 263168 | 97.508 | 2688.42 | 6.528 | 156.86 | 104.037 | 2529.57 | +| 512 | 32 | 1 | 544 | 0.393 | 1303.73 | 0.548 | 58.36 | 0.941 | 578.10 | +| 512 | 32 | 2 | 1088 | 0.387 | 2648.68 | 0.910 | 70.35 | 1.296 | 839.27 | +| 512 | 32 | 4 | 2176 | 0.659 | 3107.63 | 1.302 | 98.33 | 1.961 | 1109.77 | +| 512 | 32 | 8 | 4352 | 1.322 | 3099.35 | 1.669 | 153.42 | 2.990 | 1455.43 | +| 512 | 32 | 16 | 8704 | 2.639 | 3104.63 | 2.212 | 231.44 | 4.851 | 1794.32 | +| 512 | 32 | 32 | 17408 | 5.284 | 3100.80 | 2.955 | 346.53 | 8.239 | 2112.93 | +| 4096 | 32 | 1 | 4128 | 1.417 | 2890.36 | 0.598 | 53.51 | 2.015 | 2048.45 | +| 4096 | 32 | 2 | 8256 | 2.829 | 2895.62 | 1.019 | 62.82 | 3.848 | 2145.60 | +| 4096 | 32 | 4 | 16512 | 5.656 | 2896.96 | 1.528 | 83.79 | 7.183 | 2298.71 | +| 4096 | 32 | 8 | 33024 | 11.338 | 2890.02 | 2.127 | 120.36 | 13.465 | 2452.53 | +| 4096 | 32 | 16 | 66048 | 22.709 | 2885.96 | 3.104 | 164.97 | 25.812 | 2558.79 | +| 4096 | 32 | 32 | 132096 | 45.301 | 2893.35 | 4.723 | 216.80 | 50.024 | 2640.63 | +| 8192 | 32 | 1 | 8224 | 3.022 | 2711.09 | 0.678 | 47.20 | 3.700 | 2222.89 | +| 8192 | 32 | 2 | 16448 | 6.039 | 2713.01 | 1.149 | 55.70 | 7.188 | 2288.21 | +| 8192 | 32 | 4 | 32896 | 12.050 | 2719.35 | 1.785 | 71.69 | 13.835 | 2377.67 | +| 8192 | 32 | 8 | 65792 | 24.113 | 2717.90 | 2.629 | 97.39 | 26.741 | 2460.31 | +| 8192 | 32 | 16 | 131584 | 48.178 | 2720.58 | 4.099 | 124.91 | 52.277 | 2517.06 | +| 8192 | 32 | 32 | 263168 | 96.401 | 2719.31 | 6.696 | 152.93 | 103.097 | 2552.63 | - `llama-bench` -| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s | -| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 2925.55 ± 4.25 | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 62.80 ± 0.27 | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 2531.01 ± 6.79 | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 55.86 ± 0.33 | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 2244.39 ± 5.33 | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 45.95 ± 0.33 | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1783.17 ± 3.68 | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 39.07 ± 0.10 | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1241.90 ± 3.13 | -| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 29.92 ± 0.06 | - -build: eeee367de (6989) +| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 | 2986.97 ± 18.87 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 | 61.06 ± 0.23 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d4096 | 2633.45 ± 6.26 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d4096 | 54.77 ± 0.28 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d8192 | 2354.14 ± 3.84 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d8192 | 48.02 ± 0.40 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d16384 | 1908.86 ± 4.25 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d16384 | 40.23 ± 0.10 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d32768 | 1348.17 ± 2.00 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d32768 | 30.21 ± 0.04 | + +build: 11fb327bf (7941) ## ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF @@ -173,46 +173,46 @@ Model: https://huggingface.co/ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF - `llama-batched-bench` -main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 | PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | |-------|--------|------|--------|----------|----------|----------|----------|----------|----------| -| 512 | 32 | 1 | 544 | 0.211 | 2421.57 | 1.055 | 30.33 | 1.266 | 429.57 | -| 512 | 32 | 2 | 1088 | 0.419 | 2441.34 | 1.130 | 56.65 | 1.549 | 702.32 | -| 512 | 32 | 4 | 2176 | 0.873 | 2345.54 | 1.174 | 108.99 | 2.048 | 1062.74 | -| 512 | 32 | 8 | 4352 | 1.727 | 2371.85 | 1.254 | 204.22 | 2.980 | 1460.19 | -| 512 | 32 | 16 | 8704 | 3.452 | 2373.22 | 1.492 | 343.16 | 4.944 | 1760.56 | -| 512 | 32 | 32 | 17408 | 6.916 | 2368.93 | 1.675 | 611.51 | 8.591 | 2026.36 | -| 4096 | 32 | 1 | 4128 | 1.799 | 2277.26 | 1.084 | 29.51 | 2.883 | 1431.91 | -| 4096 | 32 | 2 | 8256 | 3.577 | 2290.01 | 1.196 | 53.50 | 4.774 | 1729.51 | -| 4096 | 32 | 4 | 16512 | 7.172 | 2284.36 | 1.313 | 97.50 | 8.485 | 1946.00 | -| 4096 | 32 | 8 | 33024 | 14.341 | 2284.96 | 1.520 | 168.46 | 15.860 | 2082.18 | -| 4096 | 32 | 16 | 66048 | 28.675 | 2285.44 | 1.983 | 258.21 | 30.658 | 2154.33 | -| 4096 | 32 | 32 | 132096 | 57.354 | 2285.32 | 2.640 | 387.87 | 59.994 | 2201.82 | -| 8192 | 32 | 1 | 8224 | 3.701 | 2213.75 | 1.119 | 28.59 | 4.820 | 1706.34 | -| 8192 | 32 | 2 | 16448 | 7.410 | 2211.19 | 1.272 | 50.31 | 8.682 | 1894.56 | -| 8192 | 32 | 4 | 32896 | 14.802 | 2213.83 | 1.460 | 87.68 | 16.261 | 2022.96 | -| 8192 | 32 | 8 | 65792 | 29.609 | 2213.35 | 1.781 | 143.74 | 31.390 | 2095.93 | -| 8192 | 32 | 16 | 131584 | 59.229 | 2212.96 | 2.495 | 205.17 | 61.725 | 2131.79 | -| 8192 | 32 | 32 | 263168 | 118.449 | 2213.15 | 3.714 | 275.75 | 122.162 | 2154.25 | +| 512 | 32 | 1 | 544 | 0.212 | 2420.12 | 1.100 | 29.10 | 1.311 | 414.85 | +| 512 | 32 | 2 | 1088 | 0.428 | 2393.89 | 1.185 | 54.00 | 1.613 | 674.56 | +| 512 | 32 | 4 | 2176 | 0.894 | 2290.41 | 1.229 | 104.17 | 2.123 | 1025.02 | +| 512 | 32 | 8 | 4352 | 1.758 | 2330.36 | 1.319 | 194.15 | 3.076 | 1414.70 | +| 512 | 32 | 16 | 8704 | 3.508 | 2335.21 | 1.543 | 331.90 | 5.051 | 1723.33 | +| 512 | 32 | 32 | 17408 | 7.035 | 2328.93 | 1.738 | 589.21 | 8.773 | 1984.29 | +| 4096 | 32 | 1 | 4128 | 1.831 | 2237.25 | 1.125 | 28.44 | 2.956 | 1396.42 | +| 4096 | 32 | 2 | 8256 | 3.642 | 2249.48 | 1.253 | 51.07 | 4.895 | 1686.64 | +| 4096 | 32 | 4 | 16512 | 7.274 | 2252.26 | 1.380 | 92.72 | 8.655 | 1907.81 | +| 4096 | 32 | 8 | 33024 | 14.576 | 2248.09 | 1.617 | 158.29 | 16.193 | 2039.37 | +| 4096 | 32 | 16 | 66048 | 29.138 | 2249.17 | 2.081 | 246.01 | 31.219 | 2115.63 | +| 4096 | 32 | 32 | 132096 | 58.275 | 2249.19 | 2.814 | 363.87 | 61.089 | 2162.34 | +| 8192 | 32 | 1 | 8224 | 3.757 | 2180.26 | 1.184 | 27.03 | 4.941 | 1664.37 | +| 8192 | 32 | 2 | 16448 | 7.522 | 2178.05 | 1.341 | 47.73 | 8.863 | 1855.77 | +| 8192 | 32 | 4 | 32896 | 15.043 | 2178.25 | 1.548 | 82.69 | 16.591 | 1982.74 | +| 8192 | 32 | 8 | 65792 | 30.111 | 2176.49 | 1.937 | 132.13 | 32.048 | 2052.90 | +| 8192 | 32 | 16 | 131584 | 60.405 | 2169.90 | 2.706 | 189.21 | 63.111 | 2084.97 | +| 8192 | 32 | 32 | 263168 | 120.439 | 2176.58 | 3.993 | 256.46 | 124.432 | 2114.96 | - `llama-bench` -| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s | -| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 2272.74 ± 4.68 | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 30.66 ± 0.02 | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 2107.80 ± 9.55 | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 29.71 ± 0.05 | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 1937.80 ± 6.75 | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 28.86 ± 0.04 | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 1641.12 ± 1.78 | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 27.24 ± 0.04 | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 1296.02 ± 2.67 | -| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 23.78 ± 0.03 | - -build: eeee367de (6989) +| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 | 2250.28 ± 6.41 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 | 29.43 ± 0.02 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d4096 | 2100.19 ± 8.96 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d4096 | 28.61 ± 0.02 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d8192 | 2007.56 ± 4.16 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d8192 | 27.38 ± 0.09 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d16384 | 1779.11 ± 6.42 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d16384 | 25.72 ± 0.03 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d32768 | 1471.23 ± 1.71 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d32768 | 22.51 ± 0.02 | + +build: 11fb327bf (7941) ## ggml-org/gemma-3-4b-it-qat-GGUF @@ -221,44 +221,91 @@ Model: https://huggingface.co/ggml-org/gemma-3-4b-it-qat-GGUF - `llama-batched-bench` -main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 | PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | |-------|--------|------|--------|----------|----------|----------|----------|----------|----------| -| 512 | 32 | 1 | 544 | 0.094 | 5434.73 | 0.394 | 81.21 | 0.488 | 1114.15 | -| 512 | 32 | 2 | 1088 | 0.168 | 6091.68 | 0.498 | 128.52 | 0.666 | 1633.41 | -| 512 | 32 | 4 | 2176 | 0.341 | 6010.68 | 0.542 | 236.37 | 0.882 | 2466.43 | -| 512 | 32 | 8 | 4352 | 0.665 | 6161.46 | 0.678 | 377.74 | 1.342 | 3241.72 | -| 512 | 32 | 16 | 8704 | 1.323 | 6193.19 | 0.902 | 567.41 | 2.225 | 3911.74 | -| 512 | 32 | 32 | 17408 | 2.642 | 6202.03 | 1.231 | 832.03 | 3.872 | 4495.36 | -| 4096 | 32 | 1 | 4128 | 0.701 | 5840.49 | 0.439 | 72.95 | 1.140 | 3621.23 | -| 4096 | 32 | 2 | 8256 | 1.387 | 5906.82 | 0.574 | 111.48 | 1.961 | 4210.12 | -| 4096 | 32 | 4 | 16512 | 2.758 | 5940.33 | 0.651 | 196.58 | 3.409 | 4843.33 | -| 4096 | 32 | 8 | 33024 | 5.491 | 5967.56 | 0.876 | 292.40 | 6.367 | 5187.12 | -| 4096 | 32 | 16 | 66048 | 10.978 | 5969.58 | 1.275 | 401.69 | 12.253 | 5390.38 | -| 4096 | 32 | 32 | 132096 | 21.944 | 5972.93 | 1.992 | 514.16 | 23.936 | 5518.73 | -| 8192 | 32 | 1 | 8224 | 1.402 | 5841.91 | 0.452 | 70.73 | 1.855 | 4434.12 | -| 8192 | 32 | 2 | 16448 | 2.793 | 5865.34 | 0.637 | 100.55 | 3.430 | 4795.51 | -| 8192 | 32 | 4 | 32896 | 5.564 | 5889.64 | 0.770 | 166.26 | 6.334 | 5193.95 | -| 8192 | 32 | 8 | 65792 | 11.114 | 5896.44 | 1.122 | 228.07 | 12.237 | 5376.51 | -| 8192 | 32 | 16 | 131584 | 22.210 | 5901.38 | 1.789 | 286.15 | 24.000 | 5482.74 | -| 8192 | 32 | 32 | 263168 | 44.382 | 5906.56 | 3.044 | 336.38 | 47.426 | 5549.02 | +| 512 | 32 | 1 | 544 | 0.092 | 5566.97 | 0.412 | 77.63 | 0.504 | 1078.95 | +| 512 | 32 | 2 | 1088 | 0.161 | 6345.67 | 0.522 | 122.70 | 0.683 | 1593.06 | +| 512 | 32 | 4 | 2176 | 0.325 | 6309.87 | 0.562 | 227.68 | 0.887 | 2453.87 | +| 512 | 32 | 8 | 4352 | 0.643 | 6374.42 | 0.685 | 373.67 | 1.328 | 3277.94 | +| 512 | 32 | 16 | 8704 | 1.277 | 6413.64 | 0.915 | 559.47 | 2.192 | 3970.01 | +| 512 | 32 | 32 | 17408 | 2.518 | 6506.57 | 1.249 | 819.61 | 3.767 | 4620.64 | +| 4096 | 32 | 1 | 4128 | 0.674 | 6079.68 | 0.453 | 70.60 | 1.127 | 3662.88 | +| 4096 | 32 | 2 | 8256 | 1.335 | 6137.82 | 0.627 | 102.03 | 1.962 | 4208.11 | +| 4096 | 32 | 4 | 16512 | 2.657 | 6167.35 | 0.749 | 170.92 | 3.405 | 4848.71 | +| 4096 | 32 | 8 | 33024 | 5.307 | 6173.91 | 0.974 | 262.89 | 6.281 | 5257.53 | +| 4096 | 32 | 16 | 66048 | 10.610 | 6176.96 | 1.379 | 371.42 | 11.988 | 5509.40 | +| 4096 | 32 | 32 | 132096 | 21.213 | 6178.89 | 2.122 | 482.50 | 23.335 | 5660.82 | +| 8192 | 32 | 1 | 8224 | 1.359 | 6027.34 | 0.467 | 68.52 | 1.826 | 4503.48 | +| 8192 | 32 | 2 | 16448 | 2.699 | 6069.68 | 0.653 | 98.03 | 3.352 | 4906.68 | +| 8192 | 32 | 4 | 32896 | 5.366 | 6106.74 | 0.818 | 156.55 | 6.184 | 5319.96 | +| 8192 | 32 | 8 | 65792 | 10.755 | 6093.50 | 1.174 | 218.04 | 11.929 | 5515.22 | +| 8192 | 32 | 16 | 131584 | 21.484 | 6100.82 | 1.829 | 279.90 | 23.314 | 5644.11 | +| 8192 | 32 | 32 | 263168 | 42.950 | 6103.40 | 3.058 | 334.91 | 46.008 | 5720.05 | - `llama-bench` -| model | size | params | backend | ngl | n_ubatch | fa | mmap | test | t/s | -| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 | 5810.04 ± 21.71 | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 | 84.54 ± 0.18 | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d4096 | 5288.04 ± 3.54 | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d4096 | 78.82 ± 1.37 | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d8192 | 4960.43 ± 16.64 | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d8192 | 74.13 ± 0.30 | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d16384 | 4495.92 ± 31.11 | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d16384 | 72.37 ± 0.29 | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | pp2048 @ d32768 | 3746.90 ± 40.01 | -| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | tg32 @ d32768 | 63.02 ± 0.20 | - -build: eeee367de (6989) +| model | size | params | backend | ngl | n_ubatch | fa | mmap | dio | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --: | --------------: | -------------------: | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 | 5948.74 ± 10.61 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 | 81.05 ± 0.20 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d4096 | 5652.69 ± 34.29 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d4096 | 76.37 ± 0.58 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d8192 | 5509.57 ± 40.69 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d8192 | 71.61 ± 0.80 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d16384 | 5340.86 ± 36.92 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d16384 | 70.89 ± 0.34 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | pp2048 @ d32768 | 5023.30 ± 13.52 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | CUDA | 99 | 2048 | 1 | 0 | 1 | tg32 @ d32768 | 62.28 ± 0.30 | +build: 11fb327bf (7941) + +## ggml-org/GLM-4.7-Flash-GGUF + +Model: https://huggingface.co/ggml-org/GLM-4.7-Flash-GGUF + +- `llama-batched-bench` + + +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 20, n_threads_batch = 20 + +| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | +|-------|--------|------|--------|----------|----------|----------|----------|----------|----------| +| 512 | 32 | 1 | 544 | 0.433 | 1181.83 | 0.693 | 46.16 | 1.126 | 482.94 | +| 512 | 32 | 2 | 1088 | 0.439 | 2334.46 | 1.034 | 61.89 | 1.473 | 738.75 | +| 512 | 32 | 4 | 2176 | 0.772 | 2654.46 | 1.459 | 87.76 | 2.230 | 975.77 | +| 512 | 32 | 8 | 4352 | 1.541 | 2658.78 | 2.043 | 125.31 | 3.583 | 1214.47 | +| 512 | 32 | 16 | 8704 | 3.083 | 2656.91 | 2.675 | 191.42 | 5.758 | 1511.62 | +| 512 | 32 | 32 | 17408 | 6.159 | 2660.12 | 3.615 | 283.24 | 9.774 | 1780.98 | +| 4096 | 32 | 1 | 4128 | 1.915 | 2139.30 | 0.725 | 44.14 | 2.640 | 1563.83 | +| 4096 | 32 | 2 | 8256 | 3.834 | 2136.40 | 1.119 | 57.21 | 4.953 | 1666.81 | +| 4096 | 32 | 4 | 16512 | 7.636 | 2145.72 | 1.631 | 78.49 | 9.266 | 1781.93 | +| 4096 | 32 | 8 | 33024 | 15.295 | 2142.40 | 2.344 | 109.21 | 17.639 | 1872.20 | +| 4096 | 32 | 16 | 66048 | 30.573 | 2143.62 | 3.773 | 135.70 | 34.346 | 1923.04 | +| 4096 | 32 | 32 | 132096 | 61.282 | 2138.82 | 5.795 | 176.71 | 67.077 | 1969.31 | +| 8192 | 32 | 1 | 8224 | 4.510 | 1816.24 | 0.760 | 42.11 | 5.270 | 1560.44 | +| 8192 | 32 | 2 | 16448 | 9.036 | 1813.19 | 1.206 | 53.06 | 10.242 | 1605.91 | +| 8192 | 32 | 4 | 32896 | 18.070 | 1813.43 | 1.783 | 71.80 | 19.852 | 1657.03 | +| 8192 | 32 | 8 | 65792 | 36.125 | 1814.15 | 2.635 | 97.14 | 38.760 | 1697.41 | +| 8192 | 32 | 16 | 131584 | 72.367 | 1811.20 | 4.954 | 103.34 | 77.322 | 1701.77 | +| 8192 | 32 | 32 | 263168 | 144.501 | 1814.13 | 8.103 | 126.37 | 152.604 | 1724.51 | + + +- `llama-bench` + +| model | size | params | backend | ngl | n_ubatch | fa | dio | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | --: | --------------: | -------------------: | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | pp2048 | 2364.18 ± 11.43 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | tg32 | 48.68 ± 0.12 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | pp2048 @ d4096 | 1684.13 ± 1.24 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | tg32 @ d4096 | 44.62 ± 0.22 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | pp2048 @ d8192 | 1314.68 ± 1.41 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | tg32 @ d8192 | 42.59 ± 0.11 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | pp2048 @ d16384 | 914.05 ± 3.32 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | tg32 @ d16384 | 38.72 ± 0.13 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | pp2048 @ d32768 | 567.20 ± 0.90 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | CUDA | 99 | 2048 | 1 | 1 | tg32 @ d32768 | 32.65 ± 0.09 | + +build: 11fb327bf (7941) diff --git a/benches/mac-m2-ultra/mac-m2-ultra.md b/benches/mac-m2-ultra/mac-m2-ultra.md new file mode 100644 index 00000000000..cf8a953388d --- /dev/null +++ b/benches/mac-m2-ultra/mac-m2-ultra.md @@ -0,0 +1,298 @@ +## System info + +```bash +uname -a +Darwin gg-studio 25.2.0 Darwin Kernel Version 25.2.0: Tue Nov 18 21:07:05 PST 2025; root:xnu-12377.61.12~1/RELEASE_ARM64_T6020 arm64 + +g++ --version +Apple clang version 17.0.0 (clang-1700.3.19.1) +Target: arm64-apple-darwin25.2.0 +``` + +## ggml-org/gpt-oss-20b-GGUF + +Model: https://huggingface.co/ggml-org/gpt-oss-20b-GGUF + +- `llama-batched-bench` + + +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16 + +| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | +|-------|--------|------|--------|----------|----------|----------|----------|----------|----------| +| 512 | 32 | 1 | 544 | 0.215 | 2381.35 | 0.245 | 130.45 | 0.460 | 1181.81 | +| 512 | 32 | 2 | 1088 | 0.379 | 2701.43 | 0.382 | 167.56 | 0.761 | 1429.67 | +| 512 | 32 | 4 | 2176 | 0.721 | 2839.27 | 0.604 | 211.76 | 1.326 | 1641.32 | +| 512 | 32 | 8 | 4352 | 1.433 | 2858.30 | 1.033 | 247.75 | 2.466 | 1764.57 | +| 512 | 32 | 16 | 8704 | 2.853 | 2871.12 | 1.570 | 326.11 | 4.423 | 1967.77 | +| 512 | 32 | 32 | 17408 | 5.699 | 2874.95 | 1.910 | 536.15 | 7.609 | 2287.88 | +| 4096 | 32 | 1 | 4128 | 1.552 | 2638.56 | 0.334 | 95.72 | 1.887 | 2188.00 | +| 4096 | 32 | 2 | 8256 | 3.084 | 2655.88 | 0.404 | 158.54 | 3.488 | 2366.86 | +| 4096 | 32 | 4 | 16512 | 6.151 | 2663.78 | 0.652 | 196.39 | 6.802 | 2427.37 | +| 4096 | 32 | 8 | 33024 | 12.288 | 2666.77 | 1.135 | 225.47 | 13.423 | 2460.27 | +| 4096 | 32 | 16 | 66048 | 24.563 | 2668.12 | 1.762 | 290.55 | 26.325 | 2508.97 | +| 4096 | 32 | 32 | 132096 | 49.114 | 2668.73 | 2.398 | 426.94 | 51.512 | 2564.35 | +| 8192 | 32 | 1 | 8224 | 3.345 | 2448.78 | 0.275 | 116.46 | 3.620 | 2271.76 | +| 8192 | 32 | 2 | 16448 | 6.665 | 2458.11 | 0.425 | 150.71 | 7.090 | 2319.91 | +| 8192 | 32 | 4 | 32896 | 13.315 | 2460.92 | 0.691 | 185.21 | 14.006 | 2348.63 | +| 8192 | 32 | 8 | 65792 | 26.611 | 2462.73 | 1.212 | 211.16 | 27.823 | 2364.62 | +| 8192 | 32 | 16 | 131584 | 53.232 | 2462.27 | 1.919 | 266.83 | 55.151 | 2385.88 | +| 8192 | 32 | 32 | 263168 | 110.455 | 2373.30 | 2.752 | 372.03 | 113.208 | 2324.64 | + + +- `llama-bench` + +| model | size | params | backend | threads | n_ubatch | fa | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 | 2713.40 ± 3.56 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | tg32 | 129.97 ± 3.90 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d4096 | 2324.59 ± 3.01 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d4096 | 123.38 ± 0.17 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d8192 | 1989.82 ± 30.11 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d8192 | 117.39 ± 0.33 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d16384 | 1556.54 ± 6.22 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d16384 | 109.75 ± 0.42 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d32768 | 1122.63 ± 1.45 | +| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d32768 | 98.25 ± 0.08 | + +build: b828e18c7 (7948) + +## ggml-org/gpt-oss-120b-GGUF + +Model: https://huggingface.co/ggml-org/gpt-oss-120b-GGUF + +- `llama-batched-bench` + + +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16 + +| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | +|-------|--------|------|--------|----------|----------|----------|----------|----------|----------| +| 512 | 32 | 1 | 544 | 0.426 | 1200.92 | 0.361 | 88.56 | 0.788 | 690.64 | +| 512 | 32 | 2 | 1088 | 0.683 | 1500.14 | 0.545 | 117.35 | 1.228 | 886.02 | +| 512 | 32 | 4 | 2176 | 1.204 | 1701.56 | 0.847 | 151.19 | 2.050 | 1061.34 | +| 512 | 32 | 8 | 4352 | 2.402 | 1705.20 | 1.455 | 176.00 | 3.857 | 1128.45 | +| 512 | 32 | 16 | 8704 | 4.802 | 1705.90 | 2.349 | 217.93 | 7.152 | 1217.08 | +| 512 | 32 | 32 | 17408 | 9.593 | 1707.85 | 3.665 | 279.42 | 13.258 | 1313.01 | +| 4096 | 32 | 1 | 4128 | 2.581 | 1587.08 | 0.390 | 82.12 | 2.970 | 1389.67 | +| 4096 | 32 | 2 | 8256 | 5.124 | 1598.79 | 0.589 | 108.62 | 5.713 | 1445.10 | +| 4096 | 32 | 4 | 16512 | 10.231 | 1601.47 | 0.928 | 137.98 | 11.158 | 1479.80 | +| 4096 | 32 | 8 | 33024 | 20.468 | 1600.94 | 1.606 | 159.38 | 22.074 | 1496.04 | +| 4096 | 32 | 16 | 66048 | 40.924 | 1601.42 | 2.639 | 193.99 | 43.563 | 1516.15 | +| 4096 | 32 | 32 | 132096 | 81.819 | 1601.98 | 4.466 | 229.29 | 86.284 | 1530.94 | +| 8192 | 32 | 1 | 8224 | 5.517 | 1484.74 | 0.409 | 78.16 | 5.927 | 1387.58 | +| 8192 | 32 | 2 | 16448 | 11.008 | 1488.43 | 0.622 | 102.92 | 11.629 | 1414.34 | +| 8192 | 32 | 4 | 32896 | 22.002 | 1489.29 | 0.987 | 129.66 | 22.990 | 1430.90 | +| 8192 | 32 | 8 | 65792 | 46.051 | 1423.11 | 1.858 | 137.79 | 47.909 | 1373.27 | +| 8192 | 32 | 16 | 131584 | 97.680 | 1341.85 | 2.872 | 178.28 | 100.552 | 1308.62 | +| 8192 | 32 | 32 | 263168 | 176.407 | 1486.02 | 5.048 | 202.85 | 181.455 | 1450.32 | + + +- `llama-bench` + +| model | size | params | backend | threads | n_ubatch | fa | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 | 1648.69 ± 1.80 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | tg32 | 85.60 ± 0.52 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d4096 | 1429.86 ± 1.01 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d4096 | 82.03 ± 0.12 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d8192 | 1257.90 ± 1.81 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d8192 | 78.23 ± 0.33 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d16384 | 1013.49 ± 0.70 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d16384 | 73.20 ± 0.28 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d32768 | 721.11 ± 0.58 | +| gpt-oss 120B MXFP4 MoE | 59.02 GiB | 116.83 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d32768 | 65.52 ± 0.10 | + +build: b828e18c7 (7948) + +## ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF + +Model: https://huggingface.co/ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF + +- `llama-batched-bench` + + +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16 + +| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | +|-------|--------|------|--------|----------|----------|----------|----------|----------|----------| +| 512 | 32 | 1 | 544 | 0.243 | 2109.23 | 0.419 | 76.34 | 0.662 | 821.84 | +| 512 | 32 | 2 | 1088 | 0.406 | 2521.40 | 0.575 | 111.36 | 0.981 | 1109.27 | +| 512 | 32 | 4 | 2176 | 0.744 | 2751.65 | 0.841 | 152.22 | 1.585 | 1372.71 | +| 512 | 32 | 8 | 4352 | 1.479 | 2770.20 | 1.330 | 192.48 | 2.809 | 1549.53 | +| 512 | 32 | 16 | 8704 | 2.951 | 2776.20 | 2.572 | 199.05 | 5.523 | 1575.93 | +| 512 | 32 | 32 | 17408 | 5.899 | 2777.64 | 2.603 | 393.34 | 8.502 | 2047.54 | +| 4096 | 32 | 1 | 4128 | 1.901 | 2154.15 | 0.474 | 67.58 | 2.375 | 1738.14 | +| 4096 | 32 | 2 | 8256 | 3.788 | 2162.89 | 0.652 | 98.17 | 4.439 | 1859.69 | +| 4096 | 32 | 4 | 16512 | 7.564 | 2166.18 | 0.990 | 129.24 | 8.554 | 1930.34 | +| 4096 | 32 | 8 | 33024 | 15.121 | 2166.98 | 1.632 | 156.82 | 16.754 | 1971.12 | +| 4096 | 32 | 16 | 66048 | 30.241 | 2167.09 | 3.166 | 161.72 | 33.407 | 1977.04 | +| 4096 | 32 | 32 | 132096 | 60.474 | 2167.42 | 3.780 | 270.93 | 64.254 | 2055.86 | +| 8192 | 32 | 1 | 8224 | 4.733 | 1730.92 | 0.483 | 66.29 | 5.215 | 1576.85 | +| 8192 | 32 | 2 | 16448 | 9.459 | 1732.09 | 0.722 | 88.58 | 10.182 | 1615.46 | +| 8192 | 32 | 4 | 32896 | 18.912 | 1732.65 | 1.120 | 114.26 | 20.032 | 1642.14 | +| 8192 | 32 | 8 | 65792 | 37.797 | 1733.91 | 1.873 | 136.67 | 39.670 | 1658.49 | +| 8192 | 32 | 16 | 131584 | 84.133 | 1557.92 | 3.718 | 137.72 | 87.850 | 1497.82 | +| 8192 | 32 | 32 | 263168 | 157.550 | 1663.88 | 4.854 | 210.98 | 162.403 | 1620.46 | + + +- `llama-bench` + +| model | size | params | backend | threads | n_ubatch | fa | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 | 2453.11 ± 1.70 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | tg32 | 78.97 ± 0.46 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d4096 | 1569.46 ± 1.97 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d4096 | 71.18 ± 0.37 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d8192 | 1145.51 ± 1.16 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d8192 | 65.11 ± 0.36 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d16384 | 741.04 ± 0.74 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d16384 | 56.87 ± 0.14 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d32768 | 431.31 ± 0.31 | +| qwen3moe 30B.A3B Q8_0 | 30.25 GiB | 30.53 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d32768 | 45.26 ± 0.11 | + +build: b828e18c7 (7948) + +## ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF + +Model: https://huggingface.co/ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF + +- `llama-batched-bench` + + +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16 + +| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | +|-------|--------|------|--------|----------|----------|----------|----------|----------|----------| +| 512 | 32 | 1 | 544 | 0.339 | 1509.22 | 0.409 | 78.17 | 0.749 | 726.67 | +| 512 | 32 | 2 | 1088 | 0.646 | 1584.93 | 0.483 | 132.45 | 1.129 | 963.45 | +| 512 | 32 | 4 | 2176 | 1.258 | 1627.50 | 0.585 | 218.67 | 1.844 | 1180.21 | +| 512 | 32 | 8 | 4352 | 2.506 | 1634.41 | 1.005 | 254.83 | 3.511 | 1239.64 | +| 512 | 32 | 16 | 8704 | 5.007 | 1635.99 | 1.595 | 321.07 | 6.602 | 1318.38 | +| 512 | 32 | 32 | 17408 | 10.007 | 1637.19 | 1.676 | 611.12 | 11.683 | 1490.03 | +| 4096 | 32 | 1 | 4128 | 2.730 | 1500.46 | 0.431 | 74.31 | 3.160 | 1306.12 | +| 4096 | 32 | 2 | 8256 | 5.446 | 1504.33 | 0.524 | 122.04 | 5.970 | 1382.91 | +| 4096 | 32 | 4 | 16512 | 10.875 | 1506.59 | 0.662 | 193.45 | 11.537 | 1431.28 | +| 4096 | 32 | 8 | 33024 | 21.749 | 1506.61 | 1.158 | 221.11 | 22.907 | 1441.64 | +| 4096 | 32 | 16 | 66048 | 43.477 | 1507.36 | 1.901 | 269.32 | 45.378 | 1455.49 | +| 4096 | 32 | 32 | 132096 | 86.954 | 1507.37 | 2.325 | 440.42 | 89.279 | 1479.59 | +| 8192 | 32 | 1 | 8224 | 5.940 | 1379.21 | 0.449 | 71.20 | 6.389 | 1287.20 | +| 8192 | 32 | 2 | 16448 | 11.865 | 1380.84 | 0.559 | 114.59 | 12.424 | 1323.92 | +| 8192 | 32 | 4 | 32896 | 23.723 | 1381.25 | 0.728 | 175.80 | 24.452 | 1345.35 | +| 8192 | 32 | 8 | 65792 | 47.434 | 1381.63 | 1.279 | 200.09 | 48.713 | 1350.60 | +| 8192 | 32 | 16 | 131584 | 94.864 | 1381.69 | 2.198 | 232.97 | 97.061 | 1355.68 | +| 8192 | 32 | 32 | 263168 | 189.743 | 1381.57 | 3.052 | 335.50 | 192.795 | 1365.01 | + + +- `llama-bench` + +| model | size | params | backend | threads | n_ubatch | fa | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 | 1565.91 ± 0.86 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | tg32 | 79.68 ± 0.39 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d4096 | 1317.41 ± 1.02 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d4096 | 74.70 ± 0.04 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d8192 | 1134.65 ± 0.76 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d8192 | 71.31 ± 0.12 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d16384 | 886.46 ± 0.78 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d16384 | 65.93 ± 0.06 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d32768 | 612.21 ± 0.30 | +| qwen2 7B Q8_0 | 7.54 GiB | 7.62 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d32768 | 56.83 ± 0.02 | + +build: b828e18c7 (7948) + +## ggml-org/gemma-3-4b-it-qat-GGUF + +Model: https://huggingface.co/ggml-org/gemma-3-4b-it-qat-GGUF + +- `llama-batched-bench` + + +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16 + +| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | +|-------|--------|------|--------|----------|----------|----------|----------|----------|----------| +| 512 | 32 | 1 | 544 | 0.186 | 2748.06 | 0.235 | 136.28 | 0.421 | 1291.78 | +| 512 | 32 | 2 | 1088 | 0.342 | 2990.95 | 0.312 | 204.99 | 0.655 | 1662.15 | +| 512 | 32 | 4 | 2176 | 0.662 | 3092.69 | 0.404 | 316.97 | 1.066 | 2041.21 | +| 512 | 32 | 8 | 4352 | 1.317 | 3110.41 | 0.579 | 441.80 | 1.896 | 2294.97 | +| 512 | 32 | 16 | 8704 | 2.625 | 3120.23 | 1.207 | 424.08 | 3.833 | 2270.93 | +| 512 | 32 | 32 | 17408 | 5.242 | 3125.34 | 1.299 | 788.23 | 6.541 | 2661.19 | +| 4096 | 32 | 1 | 4128 | 1.408 | 2909.90 | 0.296 | 108.07 | 1.704 | 2422.95 | +| 4096 | 32 | 2 | 8256 | 2.793 | 2933.40 | 0.325 | 197.00 | 3.118 | 2648.25 | +| 4096 | 32 | 4 | 16512 | 5.567 | 2943.22 | 0.440 | 291.07 | 6.006 | 2749.05 | +| 4096 | 32 | 8 | 33024 | 11.114 | 2948.23 | 0.640 | 400.26 | 11.754 | 2809.59 | +| 4096 | 32 | 16 | 66048 | 22.217 | 2949.76 | 1.327 | 385.83 | 23.544 | 2805.26 | +| 4096 | 32 | 32 | 132096 | 44.420 | 2950.77 | 1.553 | 659.30 | 45.973 | 2873.36 | +| 8192 | 32 | 1 | 8224 | 2.860 | 2864.58 | 0.250 | 127.90 | 3.110 | 2644.42 | +| 8192 | 32 | 2 | 16448 | 5.702 | 2873.63 | 0.335 | 191.07 | 6.036 | 2724.77 | +| 8192 | 32 | 4 | 32896 | 11.383 | 2878.69 | 0.456 | 280.72 | 11.839 | 2778.63 | +| 8192 | 32 | 8 | 65792 | 22.750 | 2880.75 | 0.671 | 381.48 | 23.421 | 2809.14 | +| 8192 | 32 | 16 | 131584 | 45.484 | 2881.74 | 1.406 | 364.04 | 46.890 | 2806.22 | +| 8192 | 32 | 32 | 263168 | 90.956 | 2882.10 | 1.793 | 570.98 | 92.749 | 2837.41 | + + +- `llama-bench` + +| model | size | params | backend | threads | n_ubatch | fa | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 | 2923.59 ± 3.10 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | tg32 | 134.28 ± 1.29 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d4096 | 2748.21 ± 3.05 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d4096 | 133.11 ± 0.08 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d8192 | 2641.45 ± 2.31 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d8192 | 125.85 ± 0.35 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d16384 | 2446.20 ± 2.94 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d16384 | 125.00 ± 0.12 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d32768 | 2129.18 ± 7.43 | +| gemma3 4B Q4_0 | 2.35 GiB | 3.88 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d32768 | 113.14 ± 0.10 | + +build: b828e18c7 (7948) + +## ggml-org/GLM-4.7-Flash-GGUF + +Model: https://huggingface.co/ggml-org/GLM-4.7-Flash-GGUF + +- `llama-batched-bench` + + +main: n_kv_max = 270336, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16 + +| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | +|-------|--------|------|--------|----------|----------|----------|----------|----------|----------| +| 512 | 32 | 1 | 544 | 0.326 | 1568.69 | 0.522 | 61.28 | 0.849 | 641.09 | +| 512 | 32 | 2 | 1088 | 0.528 | 1939.42 | 0.744 | 86.07 | 1.272 | 855.63 | +| 512 | 32 | 4 | 2176 | 0.968 | 2114.85 | 1.105 | 115.85 | 2.073 | 1049.56 | +| 512 | 32 | 8 | 4352 | 1.928 | 2124.62 | 1.684 | 151.99 | 3.612 | 1204.82 | +| 512 | 32 | 16 | 8704 | 3.844 | 2131.34 | 3.141 | 162.99 | 6.985 | 1246.11 | +| 512 | 32 | 32 | 17408 | 7.683 | 2132.38 | 3.924 | 260.95 | 11.608 | 1499.71 | +| 4096 | 32 | 1 | 4128 | 3.280 | 1248.75 | 0.723 | 44.29 | 4.003 | 1031.33 | +| 4096 | 32 | 2 | 8256 | 6.545 | 1251.63 | 0.930 | 68.85 | 7.475 | 1104.53 | +| 4096 | 32 | 4 | 16512 | 13.080 | 1252.64 | 1.454 | 88.03 | 14.534 | 1136.12 | +| 4096 | 32 | 8 | 33024 | 26.154 | 1252.90 | 2.388 | 107.20 | 28.542 | 1157.04 | +| 4096 | 32 | 16 | 66048 | 52.297 | 1253.14 | 4.724 | 108.37 | 57.022 | 1158.30 | +| 4096 | 32 | 32 | 132096 | 104.578 | 1253.34 | 7.266 | 140.93 | 111.844 | 1181.08 | +| 8192 | 32 | 1 | 8224 | 9.623 | 851.31 | 0.767 | 41.72 | 10.390 | 791.54 | +| 8192 | 32 | 2 | 16448 | 20.916 | 783.32 | 1.148 | 55.74 | 22.064 | 745.45 | +| 8192 | 32 | 4 | 32896 | 43.509 | 753.14 | 1.833 | 69.82 | 45.342 | 725.51 | +| 8192 | 32 | 8 | 65792 | 79.621 | 823.10 | 3.180 | 80.50 | 82.801 | 794.58 | +| 8192 | 32 | 16 | 131584 | 153.770 | 852.39 | 6.502 | 78.74 | 160.272 | 821.00 | +| 8192 | 32 | 32 | 263168 | 307.539 | 852.39 | 10.839 | 94.48 | 318.378 | 826.59 | + + +- `llama-bench` + +| model | size | params | backend | threads | n_ubatch | fa | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | ------: | -------: | -: | --------------: | -------------------: | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 | 1629.33 ± 0.27 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | tg32 | 59.58 ± 0.13 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d4096 | 732.67 ± 0.42 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d4096 | 47.44 ± 0.15 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d8192 | 474.33 ± 0.33 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d8192 | 40.20 ± 0.20 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d16384 | 277.46 ± 0.09 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d16384 | 31.50 ± 0.93 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | pp2048 @ d32768 | 151.44 ± 0.05 | +| deepseek2 30B.A3B Q8_0 | 29.65 GiB | 29.94 B | MTL,BLAS | 16 | 2048 | 1 | tg32 @ d32768 | 21.81 ± 0.01 | + +build: b828e18c7 (7948) diff --git a/ci/run.sh b/ci/run.sh index dfcf9596618..96755ea13e3 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -635,6 +635,29 @@ function gg_check_build_requirements { fi } +function gg_run_test_backend_ops_cpu { + cd ${SRC} + + cd build-ci-release + + set -e + + (time ./bin/test-backend-ops -b CPU ) 2>&1 | tee -a $OUT/${ci}-test-backend-ops-cpu.log + + set +e +} + +function gg_sum_test_backend_ops_cpu { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Runs test-backend-ops for CPU backend\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '```\n' + gg_printf '%s\n' "$(cat $OUT/${ci}-test-backend-ops-cpu.log)" + gg_printf '```\n' + gg_printf '\n' +} + ## main export LLAMA_LOG_PREFIX=1 @@ -663,6 +686,10 @@ ret=0 test $ret -eq 0 && gg_run ctest_debug test $ret -eq 0 && gg_run ctest_release +if [ ! -z ${GG_BUILD_HIGH_PERF} ]; then + test $ret -eq 0 && gg_run test_backend_ops_cpu +fi + if [ -z ${GG_BUILD_LOW_PERF} ]; then test $ret -eq 0 && gg_run embd_bge_small test $ret -eq 0 && gg_run rerank_tiny diff --git a/cmake/common.cmake b/cmake/common.cmake index a5bb787f151..bcf403e0ee3 100644 --- a/cmake/common.cmake +++ b/cmake/common.cmake @@ -32,4 +32,27 @@ function(llama_add_compile_flags) set(CXX_FLAGS "" PARENT_SCOPE) endif() endif() + + if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) + message(STATUS "Using -fsanitize=thread") + + add_compile_options(-fsanitize=thread) + link_libraries (-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + message(STATUS "Using -fsanitize=address") + + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries (-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + message(STATUS "Using -fsanitize=undefined") + + add_compile_options(-fsanitize=undefined) + link_libraries (-fsanitize=undefined) + endif() + endif() endfunction() diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index ae02c0bd77f..295ae9ea254 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -73,6 +73,10 @@ add_library(${TARGET} STATIC log.h ngram-cache.cpp ngram-cache.h + ngram-map.cpp + ngram-map.h + ngram-mod.cpp + ngram-mod.h peg-parser.cpp peg-parser.h preset.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 163c9b71b0e..9d7ad30bf4a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -6,6 +6,7 @@ #include "json-schema-to-grammar.h" #include "log.h" #include "sampling.h" +#include "speculative.h" #include "preset.h" // fix problem with std::min and std::max @@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context params.mmproj = res.mmproj; } // only download mmproj if the current example is using it - for (auto & ex : mmproj_examples) { + for (const auto & ex : mmproj_examples) { if (ctx_arg.ex == ex) { common_params_handle_model(params.mmproj, params.hf_token, params.offline); break; } } - common_params_handle_model(params.speculative.model, params.hf_token, params.offline); - common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); + common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline); + common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); } // model is required (except for server) @@ -1216,21 +1217,25 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-lcs", "--lookup-cache-static"}, "FNAME", "path to static lookup cache to use for lookup decoding (not updated by generation)", [](common_params & params, const std::string & value) { - params.lookup_cache_static = value; + params.speculative.lookup_cache_static = value; } - ).set_examples({LLAMA_EXAMPLE_LOOKUP})); + ).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-lcd", "--lookup-cache-dynamic"}, "FNAME", "path to dynamic lookup cache to use for lookup decoding (updated by generation)", [](common_params & params, const std::string & value) { - params.lookup_cache_dynamic = value; + params.speculative.lookup_cache_dynamic = value; } - ).set_examples({LLAMA_EXAMPLE_LOOKUP})); + ).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-c", "--ctx-size"}, "N", string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx), [](common_params & params, int value) { params.n_ctx = value; + if (value == 0) { + // disable context reduction in llama_params_fit if the user explicitly requests the full context size: + params.fit_params_min_ctx = UINT32_MAX; + } } ).set_env("LLAMA_ARG_CTX_SIZE")); add_opt(common_arg( @@ -1291,11 +1296,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( {"-kvu", "--kv-unified"}, + {"-no-kvu", "--no-kv-unified"}, "use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)", - [](common_params & params) { - params.kv_unified = true; + [](common_params & params, bool value) { + params.kv_unified = value; } - ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED})); + ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH})); add_opt(common_arg( {"--context-shift"}, {"--no-context-shift"}, @@ -1573,7 +1579,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--temp"}, "N", - string_format("temperature (default: %.1f)", (double)params.sampling.temp), + string_format("temperature (default: %.2f)", (double)params.sampling.temp), [](common_params & params, const std::string & value) { params.sampling.temp = std::stof(value); params.sampling.temp = std::max(params.sampling.temp, 0.0f); @@ -1590,7 +1596,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam().set_env("LLAMA_ARG_TOP_K")); add_opt(common_arg( {"--top-p"}, "N", - string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), + string_format("top-p sampling (default: %.2f, 1.0 = disabled)", (double)params.sampling.top_p), [](common_params & params, const std::string & value) { params.sampling.top_p = std::stof(value); params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P; @@ -1598,7 +1604,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--min-p"}, "N", - string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), + string_format("min-p sampling (default: %.2f, 0.0 = disabled)", (double)params.sampling.min_p), [](common_params & params, const std::string & value) { params.sampling.min_p = std::stof(value); params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P; @@ -1606,14 +1612,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--top-nsigma"}, "N", - string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma), + string_format("top-n-sigma sampling (default: %.2f, -1.0 = disabled)", params.sampling.top_n_sigma), [](common_params & params, const std::string & value) { params.sampling.top_n_sigma = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--xtc-probability"}, "N", - string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), + string_format("xtc probability (default: %.2f, 0.0 = disabled)", (double)params.sampling.xtc_probability), [](common_params & params, const std::string & value) { params.sampling.xtc_probability = std::stof(value); params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY; @@ -1621,7 +1627,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--xtc-threshold"}, "N", - string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), + string_format("xtc threshold (default: %.2f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), [](common_params & params, const std::string & value) { params.sampling.xtc_threshold = std::stof(value); params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD; @@ -1629,7 +1635,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--typical"}, "N", - string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p), + string_format("locally typical sampling, parameter p (default: %.2f, 1.0 = disabled)", (double)params.sampling.typ_p), [](common_params & params, const std::string & value) { params.sampling.typ_p = std::stof(value); } @@ -1648,7 +1654,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--repeat-penalty"}, "N", - string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), + string_format("penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), [](common_params & params, const std::string & value) { params.sampling.penalty_repeat = std::stof(value); params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT; @@ -1656,21 +1662,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--presence-penalty"}, "N", - string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present), + string_format("repeat alpha presence penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_present), [](common_params & params, const std::string & value) { params.sampling.penalty_present = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--frequency-penalty"}, "N", - string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq), + string_format("repeat alpha frequency penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_freq), [](common_params & params, const std::string & value) { params.sampling.penalty_freq = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--dry-multiplier"}, "N", - string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier), + string_format("set DRY sampling multiplier (default: %.2f, 0.0 = disabled)", (double)params.sampling.dry_multiplier), [](common_params & params, const std::string & value) { params.sampling.dry_multiplier = std::stof(value); } @@ -1751,14 +1757,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--dynatemp-range"}, "N", - string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range), + string_format("dynamic temperature range (default: %.2f, 0.0 = disabled)", (double)params.sampling.dynatemp_range), [](common_params & params, const std::string & value) { params.sampling.dynatemp_range = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--dynatemp-exp"}, "N", - string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent), + string_format("dynamic temperature exponent (default: %.2f)", (double)params.sampling.dynatemp_exponent), [](common_params & params, const std::string & value) { params.sampling.dynatemp_exponent = std::stof(value); } @@ -1774,7 +1780,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--mirostat-lr"}, "N", - string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), + string_format("Mirostat learning rate, parameter eta (default: %.2f)", (double)params.sampling.mirostat_eta), [](common_params & params, const std::string & value) { params.sampling.mirostat_eta = std::stof(value); params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA; @@ -1782,7 +1788,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--mirostat-ent"}, "N", - string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), + string_format("Mirostat target entropy, parameter tau (default: %.2f)", (double)params.sampling.mirostat_tau), [](common_params & params, const std::string & value) { params.sampling.mirostat_tau = std::stof(value); params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU; @@ -1916,28 +1922,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_env("LLAMA_ARG_YARN_ORIG_CTX")); add_opt(common_arg( {"--yarn-ext-factor"}, "N", - string_format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor), + string_format("YaRN: extrapolation mix factor (default: %.2f, 0.0 = full interpolation)", (double)params.yarn_ext_factor), [](common_params & params, const std::string & value) { params.yarn_ext_factor = std::stof(value); } ).set_env("LLAMA_ARG_YARN_EXT_FACTOR")); add_opt(common_arg( {"--yarn-attn-factor"}, "N", - string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor), + string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.2f)", (double)params.yarn_attn_factor), [](common_params & params, const std::string & value) { params.yarn_attn_factor = std::stof(value); } ).set_env("LLAMA_ARG_YARN_ATTN_FACTOR")); add_opt(common_arg( {"--yarn-beta-slow"}, "N", - string_format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow), + string_format("YaRN: high correction dim or alpha (default: %.2f)", (double)params.yarn_beta_slow), [](common_params & params, const std::string & value) { params.yarn_beta_slow = std::stof(value); } ).set_env("LLAMA_ARG_YARN_BETA_SLOW")); add_opt(common_arg( {"--yarn-beta-fast"}, "N", - string_format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast), + string_format("YaRN: low correction dim or beta (default: %.2f)", (double)params.yarn_beta_fast), [](common_params & params, const std::string & value) { params.yarn_beta_fast = std::stof(value); } @@ -2194,18 +2200,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex add_opt(common_arg( {"--mmap"}, {"--no-mmap"}, - string_format("whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"), + string_format("whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"), [](common_params & params, bool value) { params.use_mmap = value; - if (value) { - params.use_direct_io = false; // disable direct io when mmap is explicitly enabled - } } ).set_env("LLAMA_ARG_MMAP")); add_opt(common_arg( {"-dio", "--direct-io"}, {"-ndio", "--no-direct-io"}, - string_format("use DirectIO if available. Takes precedence over --mmap (default: %s)", params.use_direct_io ? "enabled" : "disabled"), + string_format("use DirectIO if available. (default: %s)", params.use_direct_io ? "enabled" : "disabled"), [](common_params & params, bool value) { params.use_direct_io = value; } @@ -2561,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-hfd", "-hfrd", "--hf-repo-draft"}, "/[:quant]", "Same as --hf-repo, but for the draft model (default: unused)", [](common_params & params, const std::string & value) { - params.speculative.model.hf_repo = value; + params.speculative.mparams_dft.hf_repo = value; } ).set_env("LLAMA_ARG_HFD_REPO")); add_opt(common_arg( @@ -2743,14 +2746,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.embd_normalize = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", [](common_params & params, const std::string & value) { params.embd_out = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD})); add_opt(common_arg( {"--embd-separator"}, "STRING", "separator of embeddings (default \\n) for example \"<#sep#>\"", @@ -3331,14 +3334,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MIN")); add_opt(common_arg( {"--draft-p-split"}, "P", - string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split), + string_format("speculative decoding split probability (default: %.2f)", (double)params.speculative.p_split), [](common_params & params, const std::string & value) { params.speculative.p_split = std::stof(value); } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT")); add_opt(common_arg( {"--draft-p-min"}, "P", - string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min), + string_format("minimum speculative decoding probability (greedy) (default: %.2f)", (double)params.speculative.p_min), [](common_params & params, const std::string & value) { params.speculative.p_min = std::stof(value); } @@ -3382,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-md", "--model-draft"}, "FNAME", "draft model for speculative decoding (default: unused)", [](common_params & params, const std::string & value) { - params.speculative.model.path = value; + params.speculative.mparams_dft.path = value; } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT")); add_opt(common_arg( @@ -3392,6 +3395,68 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.replacements.push_back({ tgt, dft }); } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + add_opt(common_arg( + {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]", + string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n", + common_speculative_type_to_str(params.speculative.type).c_str()), + [](common_params & params, const std::string & value) { + if (value == "none") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; + } else if (value == "ngram-cache") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE; + } else if (value == "ngram-simple") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE; + } else if (value == "ngram-map-k") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; + } else if (value == "ngram-map-k4v") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; + } else if (value == "ngram-mod") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + } else { + throw std::invalid_argument("unknown speculative decoding type without draft model"); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--spec-ngram-size-n"}, "N", + string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n), + [](common_params & params, int value) { + if (value < 1 || value > 1024) { + throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive"); + } + params.speculative.ngram_size_n = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--spec-ngram-size-m"}, "N", + string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m), + [](common_params & params, int value) { + if (value < 1 || value > 1024) { + throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive"); + } + params.speculative.ngram_size_m = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--spec-ngram-check-rate"}, "N", + string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate), + [](common_params & params, int value) { + if (value < 1) { + throw std::invalid_argument("ngram check rate must be at least 1"); + } + params.speculative.ngram_check_rate = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--spec-ngram-min-hits"}, "N", + string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits), + [](common_params & params, int value) { + if (value < 1) { + throw std::invalid_argument("ngram min hits must be at least 1"); + } + params.speculative.ngram_min_hits = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-ctkd", "--cache-type-k-draft"}, "TYPE", string_format( @@ -3618,8 +3683,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; - params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; - params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.port = 8012; params.n_ubatch = 1024; params.n_batch = 1024; @@ -3634,8 +3699,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF"; params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; - params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; - params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.port = 8012; params.n_ubatch = 1024; params.n_batch = 1024; diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index c2d1e30f35e..29819e48d3b 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -1630,7 +1630,7 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co } auto msg = builder.result(); if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); } return msg; } @@ -1663,7 +1663,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std mapper.from_ast(ctx.ast, result); } if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); } return msg; } diff --git a/common/chat.cpp b/common/chat.cpp index b29544dac01..2bf46326694 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -7,9 +7,6 @@ #include "log.h" #include "regex-partial.h" -// #include -// #include - #include "jinja/parser.h" #include "jinja/value.h" #include "jinja/runtime.h" @@ -56,39 +53,73 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) { return !msg.content.empty() || !msg.tool_calls.empty(); } -template <> -json common_chat_msg::to_json_oaicompat() const -{ - json message { - {"role", "assistant"}, +json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { + if (!content.empty() && !content_parts.empty()) { + throw std::runtime_error("Cannot specify both content and content_parts"); + } + json jmsg { + {"role", role}, }; + if (!content.empty()) { + jmsg["content"] = content; + } else if (!content_parts.empty()) { + if (concat_typed_text) { + std::string text; + for (const auto & part : content_parts) { + if (part.type != "text") { + LOG_WRN("Ignoring content part type: %s\n", part.type.c_str()); + continue; + } + if (!text.empty()) { + text += '\n'; + } + text += part.text; + } + jmsg["content"] = text; + } else { + auto & parts = jmsg["content"] = json::array(); + for (const auto & part : content_parts) { + parts.push_back({ + {"type", part.type}, + {"text", part.text}, + }); + } + } + } else { + jmsg["content"] = ""; + } if (!reasoning_content.empty()) { - message["reasoning_content"] = reasoning_content; + jmsg["reasoning_content"] = reasoning_content; } - if (content.empty() && !tool_calls.empty()) { - message["content"] = json(); - } else { - message["content"] = content; + if (!tool_name.empty()) { + jmsg["name"] = tool_name; + } + if (!tool_call_id.empty()) { + jmsg["tool_call_id"] = tool_call_id; } if (!tool_calls.empty()) { - auto arr = json::array(); - for (const auto & tc : tool_calls) { - arr.push_back({ + jmsg["tool_calls"] = json::array(); + auto & jtool_calls = jmsg["tool_calls"]; + for (const auto & tool_call : tool_calls) { + json tc { {"type", "function"}, {"function", { - {"name", tc.name}, - {"arguments", tc.arguments}, + {"name", tool_call.name}, + {"arguments", tool_call.arguments}, }}, - {"id", tc.id}, - // // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo). - // // We only generate a random id for the ones that don't generate one by themselves - // // (they also won't get to see it as their template likely doesn't use it, so it's all for the client) - // {"id", tc.id.empty() ? gen_tool_call_id() : tc.id}, - }); + }; + if (!tool_call.id.empty()) { + tc["id"] = tool_call.id; + } + // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo). + // We only generate a random id for the ones that don't generate one by themselves + // (they also won't get to see it as their template likely doesn't use it, so it's all for the client) + // {"id", tc.id.empty() ? gen_tool_call_id() : tc.id}, + jtool_calls.push_back(tc); } - message["tool_calls"] = arr; } - return message; + + return jmsg; } std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) { @@ -256,7 +287,6 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates * return rendered_no_thinking.prompt != rendered_with_thinking.prompt; } -template <> std::vector common_chat_msgs_parse_oaicompat(const json & messages) { std::vector msgs; @@ -350,80 +380,15 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa return msgs; } -template <> json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text) { json messages = json::array(); for (const auto & msg : msgs) { - if (!msg.content.empty() && !msg.content_parts.empty()) { - throw std::runtime_error("Cannot specify both content and content_parts"); - } - json jmsg { - {"role", msg.role}, - }; - if (!msg.content.empty()) { - jmsg["content"] = msg.content; - } else if (!msg.content_parts.empty()) { - if (concat_typed_text) { - std::string text; - for (const auto & part : msg.content_parts) { - if (part.type != "text") { - LOG_WRN("Ignoring content part type: %s\n", part.type.c_str()); - continue; - } - if (!text.empty()) { - text += '\n'; - } - text += part.text; - } - jmsg["content"] = text; - } else { - auto & parts = jmsg["content"] = json::array(); - for (const auto & part : msg.content_parts) { - parts.push_back({ - {"type", part.type}, - {"text", part.text}, - }); - } - } - } else { - jmsg["content"] = ""; - } - if (!msg.reasoning_content.empty()) { - jmsg["reasoning_content"] = msg.reasoning_content; - } - if (!msg.tool_name.empty()) { - jmsg["name"] = msg.tool_name; - } - if (!msg.tool_call_id.empty()) { - jmsg["tool_call_id"] = msg.tool_call_id; - } - if (!msg.tool_calls.empty()) { - auto & tool_calls = jmsg["tool_calls"] = json::array(); - for (const auto & tool_call : msg.tool_calls) { - json tc { - {"type", "function"}, - {"function", { - {"name", tool_call.name}, - {"arguments", tool_call.arguments}, - }}, - }; - if (!tool_call.id.empty()) { - tc["id"] = tool_call.id; - } - tool_calls.push_back(tc); - } - } + json jmsg = msg.to_json_oaicompat(concat_typed_text); messages.push_back(jmsg); } return messages; } -template <> -std::vector common_chat_msgs_parse_oaicompat(const std::string & messages) { - return common_chat_msgs_parse_oaicompat(json::parse(messages)); -} - -template <> std::vector common_chat_tools_parse_oaicompat(const json & tools) { std::vector result; @@ -459,12 +424,6 @@ std::vector common_chat_tools_parse_oaicompat(const json & too return result; } -template <> -std::vector common_chat_tools_parse_oaicompat(const std::string & tools) { - return common_chat_tools_parse_oaicompat(json::parse(tools)); -} - -template <> json common_chat_tools_to_json_oaicompat(const std::vector & tools) { if (tools.empty()) { return json(); @@ -484,7 +443,7 @@ json common_chat_tools_to_json_oaicompat(const std::vector & t return result; } -template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { +json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { json delta = json::object(); if (!diff.reasoning_content_delta.empty()) { delta["reasoning_content"] = diff.reasoning_content_delta; @@ -812,10 +771,12 @@ static std::string apply( nlohmann::ordered_json inp = nlohmann::ordered_json{ {"messages", messages_override.has_value() ? *messages_override : inputs.messages}, - {"tools", tools_override.has_value() ? *tools_override : inputs.tools}, {"bos_token", tmpl.bos_token()}, {"eos_token", tmpl.eos_token()}, }; + if (tools_override.has_value() || !inputs.tools.empty()) { + inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools; + } if (inputs.extra_context.is_object()) { // TODO: do we need to merge, or replacing is fine? for (const auto & [k, v] : inputs.extra_context.items()) { @@ -831,9 +792,6 @@ static std::string apply( if (inputs.add_generation_prompt) { inp["add_generation_prompt"] = true; } - if (inp["tools"].is_null()) { - inp["tools"] = json::array(); - } jinja::global_from_json(ctx, inp, inputs.mark_input); @@ -2260,12 +2218,11 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; - const std::optional tools_override = json(); const std::optional additional_context = json { {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")}, {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, }; - data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, tools_override, additional_context); + data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override =*/ std::nullopt, additional_context); if (inputs.tools.is_array() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { @@ -2614,20 +2571,165 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - // TODO: Reasoning effort - json additional_context = {}; + // Copy `reasoning_content` to `reasoning` + auto adjusted_messages = json::array(); + for (const auto & msg : inputs.messages) { + if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { + auto adjusted_message = msg; + adjusted_message["reasoning"] = msg.at("reasoning_content"); + adjusted_message.erase("reasoning_content"); + adjusted_messages.push_back(adjusted_message); + } else { + adjusted_messages.push_back(msg); + } + } - data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context); - data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN; + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto include_grammar = true; + + auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); + // Check if we need to replace the flush token with end token during inference and without generation prompt. + if (inputs.is_inference && !inputs.add_generation_prompt) { + static constexpr std::string_view return_token = "<|flush|>"; + static constexpr std::string_view end_token = "<|end|>"; + if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) { + prompt.replace(pos, return_token.length(), end_token); + } + } + + data.prompt = prompt; + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.preserved_tokens = { "<|think|>", "<|content|>", "<|begin|>", "<|end|>", + "<|tool_calls|>", + "<|tool_call:begin|>", + "<|tool_call:end|>", + "<|tool_call:name|>", + "<|tool_call:args|>", }; - // TODO: Tool calling + auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + auto lit_think = p.atomic(p.literal("<|think|>")); + auto lit_assistant_begin = p.atomic(p.literal("<|begin|>assistant")); + auto lit_content = p.atomic(p.literal("<|content|>")); + auto lit_end = p.atomic(p.literal("<|end|>")); + auto parser_until_end = p.until("<|end|>"); + + // reasoning <- "<|think|>" (!"<|end|>" .)* + auto parser_reasoning = p.rule("reasoning", lit_think + p.reasoning(parser_until_end)); + + // content <- "<|content|>" (!"<|end|>" .)* + auto parser_content = p.rule("content", lit_content + p.content(parser_until_end)); + + // wrap_choice(items) <- item-choice wrapped* + // item-choice <- items[0] / ... / items[n] + // wrapped <- "<|end|><|begin|>assistant" item-choice + auto wrap_choice = [&](const std::vector & items) { + auto choice = p.choice(items); + return choice + p.zero_or_more(lit_end + lit_assistant_begin + choice); + }; + + // wrap_seq(items) <- item[0] "<|end|><|begin|>assistant" item[1] ... + auto wrap_seq = [&](const std::vector & items) { + auto seq = p.sequence(); + for (auto i = 0u; i < items.size(); i++) { + if (i == 0) { + seq += items[i]; + continue; + } + seq += lit_end + lit_assistant_begin + items[i]; + } + return seq; + }; + + // Response format parser + if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { + auto parser_response_format = lit_content + p.content(p.schema(p.json(), "response-format", inputs.json_schema)); + return p.choice({ + wrap_seq({parser_reasoning, parser_response_format}), + wrap_seq({parser_response_format}) + }); + } + + auto lit_tool_call_begin = p.literal("<|tool_call:begin|>"); + auto lit_tool_call_name = p.literal("<|tool_call:name|>"); + auto lit_tool_call_args = p.literal("<|tool_call:args|>"); + auto lit_tool_call_end = p.literal("<|tool_call:end|>"); + + // Tool call parser + if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { + auto parser_tool_call = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + // tool(name, schema) <- name "<|tool_call:args|>" schema + parser_tool_call |= p.rule("tool-" + name, + p.atomic(p.tool_name(p.literal(name)) + lit_tool_call_args) + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))); + }); + + auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; + auto max_calls = inputs.parallel_tool_calls ? -1 : 1; + + // tool-calls <- "<|tool_calls|>" tool-call+ + // tool-call <- "<|tool_call:begin|> call-id "<|tool_call:name|>" &([^<]+ "<|tool_call:args|>") tool-choice "<|tool_call:end|>" + // call-id <- [a-zA-Z0-9_-]+ + // tool-choice <- tool(t[0].name, t[0].schema) / ... / tool(t[n].name, t[n].schema) + auto parser_tool_calls = p.trigger_rule("tool-calls", + p.atomic(p.literal("<|tool_calls|>")) + + p.repeat( + p.tool_open( + lit_tool_call_begin + + p.tool_id(p.chars("[a-zA-Z0-9_-]", 1, -1)) + + lit_tool_call_name + + p.peek(p.chars("[^<]", 1, -1) + lit_tool_call_args)) + + parser_tool_call + + p.tool_close(lit_tool_call_end), + /* min = */ 1, + /* max = */ max_calls)); + + if (min_calls == 1) { + // If required, then try any combination of the reasoning, content, and tool call + return p.choice({ + wrap_seq({parser_reasoning, parser_content, parser_tool_calls}), + wrap_seq({parser_reasoning, parser_tool_calls}), + wrap_seq({parser_content, parser_tool_calls}), + wrap_seq({parser_tool_calls}) + }); + } + + return wrap_choice({parser_reasoning, parser_content, parser_tool_calls}); + } + + // Content only parser + include_grammar = false; + return wrap_choice({parser_reasoning, parser_content}); + }); + + data.parser = parser.save(); + + if (include_grammar) { + data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + auto schema = function.at("parameters"); + builder.resolve_refs(schema); + }); + parser.build_grammar(builder, data.grammar_lazy); + }); + + data.grammar_triggers = { + {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls|>"} + }; + } return data; } @@ -2691,6 +2793,51 @@ static common_chat_params common_chat_params_init_exaone_moe(const common_chat_t return data; } +static common_chat_params common_chat_params_init_translate_gemma(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // This template does not support tools or reasoning + // we just need to transform the messages into the correct schema + + templates_params inputs_new = inputs; + json & messages = inputs_new.messages; + + // default to chat_template_kwargs, or en-GB if not specified + std::string default_src_lang = inputs.extra_context.value("source_lang_code", "en-GB"); + std::string default_tgt_lang = inputs.extra_context.value("target_lang_code", "en-GB"); + + GGML_ASSERT(messages.is_array()); + for (auto & message : messages) { + if (message.contains("role") && message["role"].get() != "user") { + continue; + } + if (!message.contains("content")) { + message["content"] = json::array(); + } + if (message.contains("content") && !message["content"].is_array()) { + auto content_str = message["content"].get(); + // default to en-GB if not specified (to make common_chat_format_example works) + auto src_lang = message.contains("source_lang_code") + ? message["source_lang_code"].get() : default_src_lang; + auto tgt_lang = message.contains("target_lang_code") + ? message["target_lang_code"].get() : default_tgt_lang; + message["content"] = json::array({ + json{ + {"type", "text"}, + {"text", content_str}, + {"source_lang_code", src_lang}, + {"target_lang_code", tgt_lang}, + } + }); + } + } + + data.prompt = apply(tmpl, inputs_new, std::nullopt, std::nullopt); + data.format = COMMON_CHAT_FORMAT_GENERIC; + + return data; +} + static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2867,13 +3014,13 @@ static common_chat_params common_chat_templates_apply_jinja( const struct common_chat_templates_inputs & inputs) { templates_params params; - params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); + params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default; const auto & src = tmpl.source(); const auto & caps = tmpl.original_caps(); - params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); + params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); params.add_generation_prompt = inputs.add_generation_prompt; params.tool_choice = inputs.tool_choice; params.reasoning_format = inputs.reasoning_format; @@ -2943,6 +3090,10 @@ static common_chat_params common_chat_templates_apply_jinja( src.find("") != std::string::npos && params.json_schema.is_null()) { workaround::func_args_not_string(params.messages); + if (!params.extra_context.contains("clear_thinking")) { + // by default, do not clear reasoning_content (added since GLM-4.7) + params.extra_context["clear_thinking"] = false; + } return common_chat_params_init_glm_4_5(tmpl, params); } @@ -3035,6 +3186,13 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_apriel_1_5(tmpl, params); } + // Solar Open + if (src.find("<|tool_response:begin|>") != std::string::npos && + src.find("<|tool_response:name|>") != std::string::npos && + src.find("<|tool_response:result|>") != std::string::npos) { + return common_chat_params_init_solar_open(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -3082,6 +3240,12 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_solar_open(tmpl, params); } + // TranslateGemma + if (src.find("[source_lang_code]") != std::string::npos && + src.find("[target_lang_code]") != std::string::npos) { + return common_chat_params_init_translate_gemma(tmpl, params); + } + // Plain handler (no tools) if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { return common_chat_params_init_without_tools(tmpl, params); @@ -3174,3 +3338,9 @@ common_chat_params common_chat_templates_apply( ? common_chat_templates_apply_jinja(tmpls, inputs) : common_chat_templates_apply_legacy(tmpls, inputs); } + +std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates) { + GGML_ASSERT(chat_templates != nullptr); + GGML_ASSERT(chat_templates->template_default != nullptr); + return chat_templates->template_default->caps.to_map(); +} diff --git a/common/chat.h b/common/chat.h index ac19348ece7..24aa4aab5cd 100644 --- a/common/chat.h +++ b/common/chat.h @@ -10,6 +10,8 @@ #include #include +#include + struct common_chat_templates; struct common_chat_tool_call { @@ -26,6 +28,11 @@ struct common_chat_msg_content_part { std::string type; std::string text; + // TODO @ngxson : no known chat templates support reasoning_content in content parts yet + // this can be useful for models with interleaved thinking (like Kimi-K2) + // if you see any templates explicitly support this, please ping me + // std::string reasoning_content; + bool operator==(const common_chat_msg_content_part & other) const { return type == other.type && text == other.text; } @@ -40,7 +47,7 @@ struct common_chat_msg { std::string tool_name; std::string tool_call_id; - template T to_json_oaicompat() const; + nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const; bool empty() const { return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); @@ -232,13 +239,13 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates); // Parses a JSON array of messages in OpenAI's chat completion API format. -// T can be std::string containing JSON or nlohmann::ordered_json -template std::vector common_chat_msgs_parse_oaicompat(const T & messages); -template T common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); +std::vector common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages); +nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); + +std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); +nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector & tools); -// Parses a JSON array of tools in OpenAI's chat completion tool call API format. -// T can be std::string containing JSON or nlohmann::ordered_json -template std::vector common_chat_tools_parse_oaicompat(const T & tools); -template T common_chat_tools_to_json_oaicompat(const std::vector & tools); +nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); -template T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); +// get template caps, useful for reporting to server /props endpoint +std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates); diff --git a/common/common.cpp b/common/common.cpp index 26250abb6c8..3aa396127ce 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) : if (params.fit_params) { LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__); llama_params_fit(params.model.path.c_str(), &mparams, &cparams, - params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, + params.tensor_split, + params.tensor_buft_overrides.data(), + params.fit_params_target.data(), + params.fit_params_min_ctx, params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); } @@ -1208,10 +1211,6 @@ std::vector & common_init_result::lora() { return pimpl->lora; } -void common_init_result::free_context() { - pimpl->context.reset(); -} - common_init_result_ptr common_init_from_params(common_params & params) { common_init_result_ptr res(new common_init_result(params)); diff --git a/common/common.h b/common/common.h index 96c990c05d8..398ebb09601 100644 --- a/common/common.h +++ b/common/common.h @@ -164,6 +164,17 @@ enum common_params_sampling_config : uint64_t { COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11, }; +enum common_speculative_type { + COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding + COMMON_SPECULATIVE_TYPE_DRAFT, // draft model + COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model + COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values + COMMON_SPECULATIVE_TYPE_NGRAM_MOD, + COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache + COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type +}; // sampling parameters struct common_params_sampling { @@ -242,17 +253,40 @@ struct common_params_model { std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; +struct common_ngram_mod; + struct common_params_speculative { - std::vector devices; // devices to use for offloading + common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding - int32_t n_ctx = 0; // draft context size - int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding - int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding - int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - float p_split = 0.1f; // speculative decoding split probability - float p_min = 0.75f; // minimum speculative decoding probability (greedy) - std::vector> replacements; // main to speculative model replacements - std::vector tensor_buft_overrides; + // general-purpose speculative decoding parameters + + int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding + float p_split = 0.1f; // speculative decoding split probability + float p_min = 0.75f; // minimum speculative decoding probability (greedy) + + // ngram-based speculative decoding + + uint16_t ngram_size_n = 12; // ngram size for lookup + uint16_t ngram_size_m = 48; // mgram size for speculative tokens + uint16_t ngram_check_rate = 1; // check rate for ngram lookup + uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed + + std::shared_ptr ngram_mod; + + std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT + std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT + + // draft-model speculative decoding + + struct common_params_model mparams_dft; + + llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts + + llama_context_params cparams_dft; // these are the parameters for the draft llama_context + + int32_t n_ctx = 0; // draft context size + int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V @@ -260,7 +294,14 @@ struct common_params_speculative { struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - struct common_params_model model; + std::vector devices; // devices to use for offloading + + std::vector> replacements; // main to speculative model replacements + std::vector tensor_buft_overrides; + + bool has_dft() const { + return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty(); + } }; struct common_params_vocoder { @@ -378,8 +419,6 @@ struct common_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT std::string input_prefix = ""; // string to prefix user inputs with // NOLINT std::string input_suffix = ""; // string to suffix user inputs with // NOLINT - std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT - std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT // llama-debug specific options @@ -438,7 +477,7 @@ struct common_params { bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool use_mmap = true; // enable mmap to use filesystem cache - bool use_direct_io = true; // read from disk without buffering for faster model loading + bool use_direct_io = false; // read from disk without buffering bool use_mlock = false; // use mlock to keep model in memory bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation @@ -575,10 +614,6 @@ struct common_params { // return false from callback to abort model loading or true to continue llama_progress_callback load_progress_callback = NULL; void * load_progress_callback_user_data = NULL; - - bool has_speculative() const { - return !speculative.model.path.empty() || !speculative.model.hf_repo.empty(); - } }; // call once at the start of a program if it uses libcommon @@ -714,8 +749,6 @@ struct common_init_result { std::vector & lora(); - void free_context(); - private: struct impl; std::unique_ptr pimpl; diff --git a/common/debug.cpp b/common/debug.cpp index fdaddb14436..0df409a79db 100644 --- a/common/debug.cpp +++ b/common/debug.cpp @@ -45,6 +45,8 @@ static float common_ggml_get_float_value(const uint8_t * data, return v; } +#define INDENT " " + template void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { GGML_ASSERT(n > 0); @@ -60,41 +62,41 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n } } for (int64_t i3 = 0; i3 < ne[3]; i3++) { - LOG_ERR(" [\n"); + LOG(INDENT "[\n"); for (int64_t i2 = 0; i2 < ne[2]; i2++) { if (i2 == n && ne[2] > 2 * n) { - LOG_ERR(" ..., \n"); + LOG(INDENT INDENT "..., \n"); i2 = ne[2] - n; } - LOG_ERR(" [\n"); + LOG(INDENT INDENT "[\n"); for (int64_t i1 = 0; i1 < ne[1]; i1++) { if (i1 == n && ne[1] > 2 * n) { - LOG_ERR(" ..., \n"); + LOG(INDENT INDENT INDENT "..., \n"); i1 = ne[1] - n; } - LOG_ERR(" ["); + LOG(INDENT INDENT INDENT "["); for (int64_t i0 = 0; i0 < ne[0]; i0++) { if (i0 == n && ne[0] > 2 * n) { - LOG_ERR("..., "); + LOG(" ..., "); i0 = ne[0] - n; } const float v = common_ggml_get_float_value(data, type, nb, i0, i1, i2, i3); - LOG_ERR("%12.4f", v); + LOG("%12.4f", v); if (i0 < ne[0] - 1) { - LOG_ERR(", "); + LOG(", "); } } - LOG_ERR("],\n"); + LOG(" ],\n"); } - LOG_ERR(" ],\n"); + LOG(INDENT INDENT "],\n"); } - LOG_ERR(" ]\n"); - LOG_ERR(" sum = %f\n", sum); + LOG(INDENT "]\n"); + LOG(INDENT "sum = %f\n", sum); } if constexpr (abort) { if (std::isnan(sum)) { - LOG_ERR("encountered NaN - aborting\n"); + LOG("encountered NaN - aborting\n"); exit(0); } } @@ -137,9 +139,9 @@ template bool common_debug_cb_eval(struct ggml_tensor * t, b } if (matches_filter) { - LOG_ERR("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, t->name, ggml_type_name(t->type), - ggml_op_desc(t), src0->name, common_ggml_ne_string(src0).c_str(), src1 ? src1_str : "", - common_ggml_ne_string(t).c_str()); + LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, t->name, ggml_type_name(t->type), + ggml_op_desc(t), src0->name, common_ggml_ne_string(src0).c_str(), src1 ? src1_str : "", + common_ggml_ne_string(t).c_str()); } const bool is_host = ggml_backend_buffer_is_host(t->buffer); diff --git a/common/http.h b/common/http.h index 8e29787dcc6..e8ed56f952b 100644 --- a/common/http.h +++ b/common/http.h @@ -57,6 +57,17 @@ static std::pair common_http_client(const std: throw std::runtime_error("error: invalid URL format"); } +#ifndef CPPHTTPLIB_OPENSSL_SUPPORT + if (parts.scheme == "https") { + throw std::runtime_error( + "HTTPS is not supported. Please rebuild with one of:\n" + " -DLLAMA_BUILD_BORINGSSL=ON\n" + " -DLLAMA_BUILD_LIBRESSL=ON\n" + " -DLLAMA_OPENSSL=ON (default, requires OpenSSL dev files installed)" + ); + } +#endif + httplib::Client cli(parts.scheme + "://" + parts.host); if (!parts.user.empty()) { diff --git a/common/jinja/caps.cpp b/common/jinja/caps.cpp index 61deccd1f5e..f27490f1fb7 100644 --- a/common/jinja/caps.cpp +++ b/common/jinja/caps.cpp @@ -61,14 +61,23 @@ static void caps_print_stats(value & v, const std::string & path) { ops.c_str()); } +std::map caps::to_map() const { + return { + {"requires_typed_content", requires_typed_content}, + {"supports_tools", supports_tools}, + {"supports_tool_calls", supports_tool_calls}, + {"supports_parallel_tool_calls", supports_parallel_tool_calls}, + {"supports_system_role", supports_system_role}, + {"supports_preserve_reasoning", supports_preserve_reasoning}, + }; +} + std::string caps::to_string() const { std::ostringstream ss; ss << "Caps(\n"; - ss << " requires_typed_content=" << requires_typed_content << "\n"; - ss << " supports_tools=" << supports_tools << "\n"; - ss << " supports_tool_calls=" << supports_tool_calls << "\n"; - ss << " supports_parallel_tool_calls=" << supports_parallel_tool_calls << "\n"; - ss << " supports_system_role=" << supports_system_role << "\n"; + for (const auto & [key, value] : to_map()) { + ss << " " << key << "=" << (value ? "true" : "false") << "\n"; + } ss << ")"; return ss.str(); } @@ -229,6 +238,40 @@ caps caps_get(jinja::program & prog) { } ); + // case: preserve reasoning content in chat history + caps_try_execute( + prog, + [&]() { + // messages + return json::array({ + { + {"role", "user"}, + {"content", "User message"} + }, + { + {"role", "assistant"}, + {"content", "Assistant message"}, + {"reasoning_content", "Reasoning content"} + }, + { + {"role", "user"}, + {"content", "User message"} + }, + }); + }, + [&]() { + // tools + return json::array(); + }, + [&](bool, value & messages, value &) { + auto & content = messages->at(1)->at("reasoning_content"); + caps_print_stats(content, "messages[1].reasoning_content"); + if (content->stats.used) { + result.supports_preserve_reasoning = true; + } + } + ); + JJ_DEBUG("%s\n", result.to_string().c_str()); return result; diff --git a/common/jinja/caps.h b/common/jinja/caps.h index deb2df180f0..77df117baa1 100644 --- a/common/jinja/caps.h +++ b/common/jinja/caps.h @@ -3,6 +3,7 @@ #include "runtime.h" #include +#include namespace jinja { @@ -11,14 +12,17 @@ struct caps { bool supports_tool_calls = true; bool supports_system_role = true; bool supports_parallel_tool_calls = true; + bool supports_preserve_reasoning = false; // support assistant message with reasoning_content bool requires_typed_content = false; // default: use string content + // for reporting on server + std::map to_map() const; + // for debugging std::string to_string() const; }; caps caps_get(jinja::program & prog); -void debug_print_caps(const caps & c); } // namespace jinja diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index e3e4ebf1ec2..4453d86e6d7 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -44,6 +44,12 @@ static std::string get_line_col(const std::string & source, size_t pos) { return "line " + std::to_string(line) + ", column " + std::to_string(col); } +static void ensure_key_type_allowed(const value & val) { + if (!val->is_hashable()) { + throw std::runtime_error("Type: " + val->type() + " is not allowed as object key"); + } +} + // execute with error handling value statement::execute(context & ctx) { try { @@ -95,20 +101,10 @@ value identifier::execute_impl(context & ctx) { value object_literal::execute_impl(context & ctx) { auto obj = mk_val(); for (const auto & pair : val) { - value key_val = pair.first->execute(ctx); - if (!is_val(key_val) && !is_val(key_val)) { - throw std::runtime_error("Object literal: keys must be string or int values, got " + key_val->type()); - } - std::string key = key_val->as_string().str(); + value key = pair.first->execute(ctx); value val = pair.second->execute(ctx); - JJ_DEBUG("Object literal: setting key '%s' with value type %s", key.c_str(), val->type().c_str()); + JJ_DEBUG("Object literal: setting key '%s' with value type %s", key->as_string().str().c_str(), val->type().c_str()); obj->insert(key, val); - - if (is_val(key_val)) { - obj->val_obj.is_key_numeric = true; - } else if (obj->val_obj.is_key_numeric) { - throw std::runtime_error("Object literal: cannot mix numeric and non-numeric keys"); - } } return obj; } @@ -127,9 +123,9 @@ value binary_expression::execute_impl(context & ctx) { value right_val = right->execute(ctx); JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str()); if (op.value == "==") { - return mk_val(value_compare(left_val, right_val, value_compare_op::eq)); + return mk_val(*left_val == *right_val); } else if (op.value == "!=") { - return mk_val(!value_compare(left_val, right_val, value_compare_op::eq)); + return mk_val(!(*left_val == *right_val)); } auto workaround_concat_null_with_str = [&](value & res) -> bool { @@ -148,6 +144,13 @@ value binary_expression::execute_impl(context & ctx) { return false; }; + auto test_is_in = [&]() -> bool { + func_args args(ctx); + args.push_back(left_val); + args.push_back(right_val); + return global_builtins().at("test_is_in")(args)->as_bool(); + }; + // Handle undefined and null values if (is_val(left_val) || is_val(right_val)) { if (is_val(right_val) && (op.value == "in" || op.value == "not in")) { @@ -227,19 +230,11 @@ value binary_expression::execute_impl(context & ctx) { return result; } } else if (is_val(right_val)) { - auto & arr = right_val->as_array(); - bool member = false; - for (const auto & item : arr) { - if (value_compare(left_val, item, value_compare_op::eq)) { - member = true; - break; - } - } + // case: 1 in [0, 1, 2] + bool member = test_is_in(); if (op.value == "in") { - JJ_DEBUG("Checking membership: %s in Array is %d", left_val->type().c_str(), member); return mk_val(member); } else if (op.value == "not in") { - JJ_DEBUG("Checking non-membership: %s not in Array is %d", left_val->type().c_str(), !member); return mk_val(!member); } } @@ -256,23 +251,23 @@ value binary_expression::execute_impl(context & ctx) { // String membership if (is_val(left_val) && is_val(right_val)) { - auto left_str = left_val->as_string().str(); - auto right_str = right_val->as_string().str(); + // case: "a" in "abc" + bool member = test_is_in(); if (op.value == "in") { - return mk_val(right_str.find(left_str) != std::string::npos); + return mk_val(member); } else if (op.value == "not in") { - return mk_val(right_str.find(left_str) == std::string::npos); + return mk_val(!member); } } - // String in object - if (is_val(left_val) && is_val(right_val)) { - auto key = left_val->as_string().str(); - bool has_key = right_val->has_key(key); + // Value key in object + if (is_val(right_val)) { + // case: key in {key: value} + bool member = test_is_in(); if (op.value == "in") { - return mk_val(has_key); + return mk_val(member); } else if (op.value == "not in") { - return mk_val(!has_key); + return mk_val(!member); } } @@ -465,14 +460,8 @@ value for_statement::execute_impl(context & ctx) { JJ_DEBUG("%s", "For loop over object keys"); auto & obj = iterable_val->as_ordered_object(); for (auto & p : obj) { - auto tuple = mk_val(); - if (iterable_val->val_obj.is_key_numeric) { - tuple->push_back(mk_val(std::stoll(p.first))); - } else { - tuple->push_back(mk_val(p.first)); - } - tuple->push_back(p.second); - items.push_back(tuple); + auto tuple = mk_val(p); + items.push_back(std::move(tuple)); } if (ctx.is_get_stats) { iterable_val->stats.used = true; @@ -602,11 +591,13 @@ value set_statement::execute_impl(context & ctx) { auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx); if (is_stmt(assignee)) { + // case: {% set my_var = value %} auto var_name = cast_stmt(assignee)->val; JJ_DEBUG("Setting global variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str()); ctx.set_val(var_name, rhs); } else if (is_stmt(assignee)) { + // case: {% set a, b = value %} auto tuple = cast_stmt(assignee); if (!is_val(rhs)) { throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type()); @@ -625,6 +616,7 @@ value set_statement::execute_impl(context & ctx) { } } else if (is_stmt(assignee)) { + // case: {% set ns.my_var = value %} auto member = cast_stmt(assignee); if (member->computed) { throw std::runtime_error("Cannot assign to computed member"); @@ -767,22 +759,22 @@ value member_expression::execute_impl(context & ctx) { } JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str()); + ensure_key_type_allowed(property); value val = mk_val("object_property"); if (is_val(object)) { JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined"); return val; + } else if (is_val(object)) { - if (!is_val(property)) { - throw std::runtime_error("Cannot access object with non-string: got " + property->type()); - } auto key = property->as_string().str(); - val = object->at(key, val); + val = object->at(property, val); if (is_val(val)) { val = try_builtin_func(ctx, key, object, true); } JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str()); + } else if (is_val(object) || is_val(object)) { if (is_val(property)) { int64_t index = property->as_int(); @@ -806,6 +798,7 @@ value member_expression::execute_impl(context & ctx) { auto key = property->as_string().str(); JJ_DEBUG("Accessing %s built-in '%s'", is_val(object) ? "array" : "string", key.c_str()); val = try_builtin_func(ctx, key, object, true); + } else { throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type()); } diff --git a/common/jinja/runtime.h b/common/jinja/runtime.h index dc7f4e471c1..17a6dff5aa2 100644 --- a/common/jinja/runtime.h +++ b/common/jinja/runtime.h @@ -79,18 +79,18 @@ struct context { } value get_val(const std::string & name) { - auto it = env->val_obj.unordered.find(name); - if (it != env->val_obj.unordered.end()) { - return it->second; - } else { - return mk_val(name); - } + value default_val = mk_val(name); + return env->at(name, default_val); } void set_val(const std::string & name, const value & val) { env->insert(name, val); } + void set_val(const value & name, const value & val) { + env->insert(name, val); + } + void print_vars() const { printf("Context Variables:\n%s\n", value_to_json(env, 2).c_str()); } @@ -344,9 +344,19 @@ struct array_literal : public expression { } }; -struct tuple_literal : public array_literal { - explicit tuple_literal(statements && val) : array_literal(std::move(val)) {} +struct tuple_literal : public expression { + statements val; + explicit tuple_literal(statements && val) : val(std::move(val)) { + for (const auto& item : this->val) chk_type(item); + } std::string type() const override { return "TupleLiteral"; } + value execute_impl(context & ctx) override { + auto arr = mk_val(); + for (const auto & item_stmt : val) { + arr->push_back(item_stmt->execute(ctx)); + } + return mk_val(std::move(arr->as_array())); + } }; struct object_literal : public expression { diff --git a/common/jinja/string.cpp b/common/jinja/string.cpp index 21ebde39e3e..8087e15b350 100644 --- a/common/jinja/string.cpp +++ b/common/jinja/string.cpp @@ -61,6 +61,12 @@ size_t string::length() const { return len; } +void string::hash_update(hasher & hash) const noexcept { + for (const auto & part : parts) { + hash.update(part.val.data(), part.val.length()); + } +} + bool string::all_parts_are_input() const { for (const auto & part : parts) { if (!part.is_input) { diff --git a/common/jinja/string.h b/common/jinja/string.h index 78457f9e413..c4963000adb 100644 --- a/common/jinja/string.h +++ b/common/jinja/string.h @@ -4,6 +4,8 @@ #include #include +#include "utils.h" + namespace jinja { // allow differentiate between user input strings and template strings @@ -37,6 +39,7 @@ struct string { std::string str() const; size_t length() const; + void hash_update(hasher & hash) const noexcept; bool all_parts_are_input() const; bool is_uppercase() const; bool is_lowercase() const; diff --git a/common/jinja/utils.h b/common/jinja/utils.h index 1e9f2a12a1a..de6947fc28f 100644 --- a/common/jinja/utils.h +++ b/common/jinja/utils.h @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace jinja { @@ -46,4 +48,102 @@ static std::string fmt_error_with_source(const std::string & tag, const std::str return oss.str(); } +// Note: this is a simple hasher, not cryptographically secure, just for hash table usage +struct hasher { + static constexpr auto size_t_digits = sizeof(size_t) * 8; + static constexpr size_t prime = size_t_digits == 64 ? 0x100000001b3 : 0x01000193; + static constexpr size_t seed = size_t_digits == 64 ? 0xcbf29ce484222325 : 0x811c9dc5; + static constexpr auto block_size = sizeof(size_t); // in bytes; allowing the compiler to vectorize the computation + + static_assert(size_t_digits == 64 || size_t_digits == 32); + static_assert(block_size == 8 || block_size == 4); + + uint8_t buffer[block_size]; + size_t idx = 0; // current index in buffer + size_t state = seed; + + hasher() = default; + hasher(const std::type_info & type_inf) noexcept { + const auto type_hash = type_inf.hash_code(); + update(&type_hash, sizeof(type_hash)); + } + + // Properties: + // - update is not associative: update(a).update(b) != update(b).update(a) + // - update(a ~ b) == update(a).update(b) with ~ as concatenation operator --> useful for streaming + // - update("", 0) --> state unchanged with empty input + hasher& update(void const * bytes, size_t len) noexcept { + const uint8_t * c = static_cast(bytes); + if (len == 0) { + return *this; + } + size_t processed = 0; + + // first, fill the existing buffer if it's partial + if (idx > 0) { + size_t to_fill = block_size - idx; + if (to_fill > len) { + to_fill = len; + } + std::memcpy(buffer + idx, c, to_fill); + idx += to_fill; + processed += to_fill; + if (idx == block_size) { + update_block(buffer); + idx = 0; + } + } + + // process full blocks from the remaining input + for (; processed + block_size <= len; processed += block_size) { + update_block(c + processed); + } + + // buffer any remaining bytes + size_t remaining = len - processed; + if (remaining > 0) { + std::memcpy(buffer, c + processed, remaining); + idx = remaining; + } + return *this; + } + + // convenience function for testing only + hasher& update(const std::string & s) noexcept { + return update(s.data(), s.size()); + } + + // finalize and get the hash value + // note: after calling digest, the hasher state is modified, do not call update() again + size_t digest() noexcept { + // if there are remaining bytes in buffer, fill the rest with zeros and process + if (idx > 0) { + for (size_t i = idx; i < block_size; ++i) { + buffer[i] = 0; + } + update_block(buffer); + idx = 0; + } + + return state; + } + +private: + // IMPORTANT: block must have at least block_size bytes + void update_block(const uint8_t * block) noexcept { + size_t blk = static_cast(block[0]) + | (static_cast(block[1]) << 8) + | (static_cast(block[2]) << 16) + | (static_cast(block[3]) << 24); + if constexpr (block_size == 8) { + blk = blk | (static_cast(block[4]) << 32) + | (static_cast(block[5]) << 40) + | (static_cast(block[6]) << 48) + | (static_cast(block[7]) << 56); + } + state ^= blk; + state *= prime; + } +}; + } // namespace jinja diff --git a/common/jinja/value.cpp b/common/jinja/value.cpp index e414aad444c..2aa156b1778 100644 --- a/common/jinja/value.cpp +++ b/common/jinja/value.cpp @@ -114,6 +114,18 @@ static T slice(const T & array, int64_t start, int64_t stop, int64_t step = 1) { return result; } +template +static value empty_value_fn(const func_args &) { + if constexpr (std::is_same_v) { + return mk_val(0); + } else if constexpr (std::is_same_v) { + return mk_val(0.0); + } else if constexpr (std::is_same_v) { + return mk_val(false); + } else { + return mk_val(); + } +} template static value test_type_fn(const func_args & args) { args.ensure_count(1); @@ -128,6 +140,13 @@ static value test_type_fn(const func_args & args) { JJ_DEBUG("test_type_fn: type=%s or %s result=%d", typeid(T).name(), typeid(U).name(), is_type ? 1 : 0); return mk_val(is_type); } +template +static value test_type_fn(const func_args & args) { + args.ensure_count(1); + bool is_type = is_val(args.get_pos(0)) || is_val(args.get_pos(0)) || is_val(args.get_pos(0)); + JJ_DEBUG("test_type_fn: type=%s, %s or %s result=%d", typeid(T).name(), typeid(U).name(), typeid(V).name(), is_type ? 1 : 0); + return mk_val(is_type); +} template static value test_compare_fn(const func_args & args) { args.ensure_count(2, 2); @@ -163,7 +182,7 @@ static value selectattr(const func_args & args) { args.ensure_vals(true, true, false, false); auto arr = args.get_pos(0)->as_array(); - auto attr_name = args.get_pos(1)->as_string().str(); + auto attribute = args.get_pos(1); auto out = mk_val(); value val_default = mk_val(); @@ -173,7 +192,7 @@ static value selectattr(const func_args & args) { if (!is_val(item)) { throw raised_exception("selectattr: item is not an object"); } - value attr_val = item->at(attr_name, val_default); + value attr_val = item->at(attribute, val_default); bool is_selected = attr_val->as_bool(); if constexpr (is_reject) is_selected = !is_selected; if (is_selected) out->push_back(item); @@ -217,7 +236,7 @@ static value selectattr(const func_args & args) { if (!is_val(item)) { throw raised_exception("selectattr: item is not an object"); } - value attr_val = item->at(attr_name, val_default); + value attr_val = item->at(attribute, val_default); func_args test_args(args.ctx); test_args.push_back(attr_val); // attribute value test_args.push_back(extra_arg); // extra argument @@ -347,8 +366,8 @@ const func_builtins & global_builtins() { {"test_is_integer", test_type_fn}, {"test_is_float", test_type_fn}, {"test_is_number", test_type_fn}, - {"test_is_iterable", test_type_fn}, - {"test_is_sequence", test_type_fn}, + {"test_is_iterable", test_type_fn}, + {"test_is_sequence", test_type_fn}, {"test_is_mapping", test_type_fn}, {"test_is_lower", [](const func_args & args) -> value { args.ensure_vals(); @@ -374,6 +393,33 @@ const func_builtins & global_builtins() { {"test_is_lt", test_compare_fn}, {"test_is_lessthan", test_compare_fn}, {"test_is_ne", test_compare_fn}, + {"test_is_in", [](const func_args & args) -> value { + args.ensure_count(2); + auto needle = args.get_pos(0); + auto haystack = args.get_pos(1); + if (is_val(haystack)) { + return mk_val(false); + } + if (is_val(haystack)) { + for (const auto & item : haystack->as_array()) { + if (*needle == *item) { + return mk_val(true); + } + } + return mk_val(false); + } + if (is_val(haystack)) { + if (!is_val(needle)) { + throw raised_exception("'in' test expects args[1] as string when args[0] is string, got args[1] as " + needle->type()); + } + return mk_val( + haystack->as_string().str().find(needle->as_string().str()) != std::string::npos); + } + if (is_val(haystack)) { + return mk_val(haystack->has_key(needle)); + } + throw raised_exception("'in' test expects iterable as first argument, got " + haystack->type()); + }}, {"test_is_test", [](const func_args & args) -> value { args.ensure_vals(); auto & builtins = global_builtins(); @@ -741,6 +787,7 @@ const func_builtins & value_array_t::get_builtins() const { args.ensure_count(1, 4); args.ensure_vals(true, true, false, false); + auto val = args.get_pos(0); auto arg0 = args.get_pos(1); auto arg1 = args.get_pos(2, mk_val()); auto arg2 = args.get_pos(3, mk_val()); @@ -762,10 +809,8 @@ const func_builtins & value_array_t::get_builtins() const { if (step == 0) { throw raised_exception("slice step cannot be zero"); } - auto arr = slice(args.get_pos(0)->as_array(), start, stop, step); - auto res = mk_val(); - res->val_arr = std::move(arr); - return res; + auto arr = slice(val->as_array(), start, stop, step); + return is_val(val) ? mk_val(std::move(arr)) : mk_val(std::move(arr)); }}, {"selectattr", selectattr}, {"select", selectattr}, @@ -785,15 +830,14 @@ const func_builtins & value_array_t::get_builtins() const { } const int64_t attr_int = attr_is_int ? attribute->as_int() : 0; const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str(); - const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str(); std::string result; for (size_t i = 0; i < arr.size(); ++i) { value val_arr = arr[i]; if (!attribute->is_undefined()) { if (attr_is_int && is_val(val_arr)) { val_arr = val_arr->at(attr_int); - } else if (!attr_is_int && !attr_name.empty() && is_val(val_arr)) { - val_arr = val_arr->at(attr_name); + } else if (!attr_is_int && is_val(val_arr)) { + val_arr = val_arr->at(attribute); } } if (!is_val(val_arr) && !is_val(val_arr) && !is_val(val_arr)) { @@ -808,9 +852,7 @@ const func_builtins & value_array_t::get_builtins() const { }}, {"string", [](const func_args & args) -> value { args.ensure_vals(); - auto str = mk_val(); - gather_string_parts_recursive(args.get_pos(0), str); - return str; + return mk_val(args.get_pos(0)->as_string()); }}, {"tojson", tojson}, {"map", [](const func_args & args) -> value { @@ -821,26 +863,26 @@ const func_builtins & value_array_t::get_builtins() const { if (!is_val(args.get_args().at(1))) { throw not_implemented_exception("map: filter-mapping not implemented"); } + value val = args.get_pos(0); value attribute = args.get_kwarg_or_pos("attribute", 1); const bool attr_is_int = is_val(attribute); if (!is_val(attribute) && !attr_is_int) { throw raised_exception("map: attribute must be string or integer"); } const int64_t attr_int = attr_is_int ? attribute->as_int() : 0; - const std::string attr_name = attribute->as_string().str(); value default_val = args.get_kwarg("default", mk_val()); auto out = mk_val(); - auto arr = args.get_pos(0)->as_array(); + auto arr = val->as_array(); for (const auto & item : arr) { value attr_val; if (attr_is_int) { attr_val = is_val(item) ? item->at(attr_int, default_val) : default_val; } else { - attr_val = is_val(item) ? item->at(attr_name, default_val) : default_val; + attr_val = is_val(item) ? item->at(attribute, default_val) : default_val; } out->push_back(attr_val); } - return out; + return is_val(val) ? mk_val(std::move(out->as_array())) : out; }}, {"append", [](const func_args & args) -> value { args.ensure_count(2); @@ -867,6 +909,7 @@ const func_builtins & value_array_t::get_builtins() const { if (!is_val(args.get_pos(0))) { throw raised_exception("sort: first argument must be an array"); } + value val = args.get_pos(0); value val_reverse = args.get_kwarg_or_pos("reverse", 1); value val_case = args.get_kwarg_or_pos("case_sensitive", 2); value attribute = args.get_kwarg_or_pos("attribute", 3); @@ -875,8 +918,7 @@ const func_builtins & value_array_t::get_builtins() const { const bool reverse = val_reverse->as_bool(); // undefined == false const bool attr_is_int = is_val(attribute); const int64_t attr_int = attr_is_int ? attribute->as_int() : 0; - const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str(); - std::vector arr = cast_val(args.get_pos(0))->as_array(); // copy + std::vector arr = val->as_array(); // copy std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) { value val_a = a; value val_b = b; @@ -884,22 +926,23 @@ const func_builtins & value_array_t::get_builtins() const { if (attr_is_int && is_val(a) && is_val(b)) { val_a = a->at(attr_int); val_b = b->at(attr_int); - } else if (!attr_is_int && !attr_name.empty() && is_val(a) && is_val(b)) { - val_a = a->at(attr_name); - val_b = b->at(attr_name); + } else if (!attr_is_int && is_val(a) && is_val(b)) { + val_a = a->at(attribute); + val_b = b->at(attribute); } else { - throw raised_exception("sort: unsupported object attribute comparison"); + throw raised_exception("sort: unsupported object attribute comparison between " + a->type() + " and " + b->type()); } } return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt); }); - return mk_val(arr); + return is_val(val) ? mk_val(std::move(arr)) : mk_val(std::move(arr)); }}, {"reverse", [](const func_args & args) -> value { args.ensure_vals(); - std::vector arr = cast_val(args.get_pos(0))->as_array(); // copy + value val = args.get_pos(0); + std::vector arr = val->as_array(); // copy std::reverse(arr.begin(), arr.end()); - return mk_val(arr); + return is_val(val) ? mk_val(std::move(arr)) : mk_val(std::move(arr)); }}, {"unique", [](const func_args &) -> value { throw not_implemented_exception("Array unique builtin not implemented"); @@ -930,7 +973,7 @@ const func_builtins & value_object_t::get_builtins() const { default_val = args.get_pos(2); } const value obj = args.get_pos(0); - std::string key = args.get_pos(1)->as_string().str(); + const value key = args.get_pos(1); return obj->at(key, default_val); }}, {"keys", [](const func_args & args) -> value { @@ -938,7 +981,7 @@ const func_builtins & value_object_t::get_builtins() const { const auto & obj = args.get_pos(0)->as_ordered_object(); auto result = mk_val(); for (const auto & pair : obj) { - result->push_back(mk_val(pair.first)); + result->push_back(pair.first); } return result; }}, @@ -956,15 +999,16 @@ const func_builtins & value_object_t::get_builtins() const { const auto & obj = args.get_pos(0)->as_ordered_object(); auto result = mk_val(); for (const auto & pair : obj) { - auto item = mk_val(); - item->push_back(mk_val(pair.first)); - item->push_back(pair.second); + auto item = mk_val(pair); result->push_back(std::move(item)); } return result; }}, {"tojson", tojson}, - {"string", tojson}, + {"string", [](const func_args & args) -> value { + args.ensure_vals(); + return mk_val(args.get_pos(0)->as_string()); + }}, {"length", [](const func_args & args) -> value { args.ensure_vals(); const auto & obj = args.get_pos(0)->as_ordered_object(); @@ -985,11 +1029,11 @@ const func_builtins & value_object_t::get_builtins() const { const bool reverse = val_reverse->as_bool(); // undefined == false const bool by_value = is_val(val_by) && val_by->as_string().str() == "value" ? true : false; auto result = mk_val(val_input); // copy - std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) { + std::sort(result->val_obj.begin(), result->val_obj.end(), [&](const auto & a, const auto & b) { if (by_value) { return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt); } else { - return reverse ? a.first > b.first : a.first < b.first; + return value_compare(a.first, b.first, reverse ? value_compare_op::gt : value_compare_op::lt); } }); return result; @@ -1005,6 +1049,22 @@ const func_builtins & value_none_t::get_builtins() const { static const func_builtins builtins = { {"default", default_value}, {"tojson", tojson}, + {"string", [](const func_args &) -> value { + return mk_val("None"); + }}, + {"safe", [](const func_args &) -> value { + return mk_val("None"); + }}, + {"strip", [](const func_args &) -> value { + return mk_val("None"); + }}, + {"items", empty_value_fn}, + {"map", empty_value_fn}, + {"reject", empty_value_fn}, + {"rejectattr", empty_value_fn}, + {"select", empty_value_fn}, + {"selectattr", empty_value_fn}, + {"unique", empty_value_fn}, }; return builtins; } @@ -1013,10 +1073,33 @@ const func_builtins & value_none_t::get_builtins() const { const func_builtins & value_undefined_t::get_builtins() const { static const func_builtins builtins = { {"default", default_value}, - {"tojson", [](const func_args & args) -> value { - args.ensure_vals(); - return mk_val("null"); - }}, + {"capitalize", empty_value_fn}, + {"first", empty_value_fn}, + {"items", empty_value_fn}, + {"join", empty_value_fn}, + {"last", empty_value_fn}, + {"length", empty_value_fn}, + {"list", empty_value_fn}, + {"lower", empty_value_fn}, + {"map", empty_value_fn}, + {"max", empty_value_fn}, + {"min", empty_value_fn}, + {"reject", empty_value_fn}, + {"rejectattr", empty_value_fn}, + {"replace", empty_value_fn}, + {"reverse", empty_value_fn}, + {"safe", empty_value_fn}, + {"select", empty_value_fn}, + {"selectattr", empty_value_fn}, + {"sort", empty_value_fn}, + {"string", empty_value_fn}, + {"strip", empty_value_fn}, + {"sum", empty_value_fn}, + {"title", empty_value_fn}, + {"truncate", empty_value_fn}, + {"unique", empty_value_fn}, + {"upper", empty_value_fn}, + {"wordcount", empty_value_fn}, }; return builtins; } @@ -1133,6 +1216,8 @@ void global_from_json(context & ctx, const nlohmann::ordered_json & json_obj, bo } } +// recursively convert value to JSON string +// TODO: avoid circular references static void value_to_json_internal(std::ostringstream & oss, const value & val, int curr_lvl, int indent, const std::string_view item_sep, const std::string_view key_sep) { auto indent_str = [indent, curr_lvl]() -> std::string { return (indent > 0) ? std::string(curr_lvl * indent, ' ') : ""; @@ -1195,7 +1280,8 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val, size_t i = 0; for (const auto & pair : obj) { oss << indent_str() << (indent > 0 ? std::string(indent, ' ') : ""); - oss << "\"" << pair.first << "\"" << key_sep; + value_to_json_internal(oss, mk_val(pair.first->as_string().str()), curr_lvl + 1, indent, item_sep, key_sep); + oss << key_sep; value_to_json_internal(oss, pair.second, curr_lvl + 1, indent, item_sep, key_sep); if (i < obj.size() - 1) { oss << item_sep; @@ -1218,4 +1304,19 @@ std::string value_to_json(const value & val, int indent, const std::string_view return oss.str(); } +// TODO: avoid circular references +std::string value_to_string_repr(const value & val) { + if (is_val(val)) { + const std::string val_str = val->as_string().str(); + + if (val_str.find('\'') != std::string::npos) { + return value_to_json(val); + } else { + return "'" + val_str + "'"; + } + } else { + return val->as_repr(); + } +} + } // namespace jinja diff --git a/common/jinja/value.h b/common/jinja/value.h index 7bd0202cea8..1c04760a08c 100644 --- a/common/jinja/value.h +++ b/common/jinja/value.h @@ -1,8 +1,10 @@ #pragma once #include "string.h" +#include "utils.h" #include +#include #include #include #include @@ -10,6 +12,7 @@ #include #include #include +#include #include namespace jinja { @@ -93,7 +96,8 @@ void global_from_json(context & ctx, const T_JSON & json_obj, bool mark_input); struct func_args; // function argument values -using func_handler = std::function; +using func_hptr = value(const func_args &); +using func_handler = std::function; using func_builtins = std::map; enum value_compare_op { eq, ge, gt, lt, ne }; @@ -103,28 +107,9 @@ struct value_t { int64_t val_int; double val_flt; string val_str; - bool val_bool; std::vector val_arr; - - struct map { - // once set to true, all keys must be numeric - // caveat: we only allow either all numeric keys or all non-numeric keys - // for now, this only applied to for_statement in case of iterating over object keys/items - bool is_key_numeric = false; - std::map unordered; - std::vector> ordered; - void insert(const std::string & key, const value & val) { - if (unordered.find(key) != unordered.end()) { - // if key exists, remove from ordered list - ordered.erase(std::remove_if(ordered.begin(), ordered.end(), - [&](const std::pair & p) { return p.first == key; }), - ordered.end()); - } - unordered[key] = val; - ordered.push_back({key, val}); - } - } val_obj; + std::vector> val_obj; func_handler val_func; @@ -139,6 +124,7 @@ struct value_t { value_t(const value_t &) = default; virtual ~value_t() = default; + // Note: only for debugging and error reporting purposes virtual std::string type() const { return ""; } virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); } @@ -146,7 +132,7 @@ struct value_t { virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); } virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); } virtual const std::vector & as_array() const { throw std::runtime_error(type() + " is not an array value"); } - virtual const std::vector> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); } + virtual const std::vector> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); } virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); } virtual bool is_none() const { return false; } virtual bool is_undefined() const { return false; } @@ -154,43 +140,66 @@ struct value_t { throw std::runtime_error("No builtins available for type " + type()); } - virtual bool has_key(const std::string & key) { - return val_obj.unordered.find(key) != val_obj.unordered.end(); - } - virtual value & at(const std::string & key, value & default_val) { - auto it = val_obj.unordered.find(key); - if (it == val_obj.unordered.end()) { - return default_val; - } - return val_obj.unordered.at(key); - } - virtual value & at(const std::string & key) { - auto it = val_obj.unordered.find(key); - if (it == val_obj.unordered.end()) { - throw std::runtime_error("Key '" + key + "' not found in value of type " + type()); - } - return val_obj.unordered.at(key); + virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); } + virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); } + virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); } + virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); } + + virtual bool is_numeric() const { return false; } + virtual bool is_hashable() const { return false; } + virtual bool is_immutable() const { return true; } + virtual hasher unique_hash() const noexcept = 0; + // TODO: C++20 <=> operator + // NOTE: We are treating == as equivalent (for normal comparisons) and != as strict nonequal (for strict (is) comparisons) + virtual bool operator==(const value_t & other) const { return equivalent(other); } + virtual bool operator!=(const value_t & other) const { return nonequal(other); } + + // Note: only for debugging purposes + virtual std::string as_repr() const { return as_string().str(); } + +protected: + virtual bool equivalent(const value_t &) const = 0; + virtual bool nonequal(const value_t & other) const { return !equivalent(other); } +}; + +// +// utils +// + +const func_builtins & global_builtins(); + +std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": "); + +// Note: only used for debugging purposes +std::string value_to_string_repr(const value & val); + +struct not_implemented_exception : public std::runtime_error { + not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {} +}; + +struct value_hasher { + size_t operator()(const value & val) const noexcept { + return val->unique_hash().digest(); } - virtual value & at(int64_t index, value & default_val) { - if (index < 0) { - index += val_arr.size(); - } - if (index < 0 || static_cast(index) >= val_arr.size()) { - return default_val; - } - return val_arr[index]; +}; + +struct value_equivalence { + bool operator()(const value & lhs, const value & rhs) const { + return *lhs == *rhs; } - virtual value & at(int64_t index) { - if (index < 0) { - index += val_arr.size(); - } - if (index < 0 || static_cast(index) >= val_arr.size()) { - throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size())); - } - return val_arr[index]; + bool operator()(const std::pair & lhs, const std::pair & rhs) const { + return *(lhs.first) == *(rhs.first) && *(lhs.second) == *(rhs.second); } +}; - virtual std::string as_repr() const { return as_string().str(); } +struct value_equality { + bool operator()(const value & lhs, const value & rhs) const { + return !(*lhs != *rhs); + } }; // @@ -198,24 +207,49 @@ struct value_t { // struct value_int_t : public value_t { - value_int_t(int64_t v) { val_int = v; } + value_int_t(int64_t v) { + val_int = v; + val_flt = static_cast(v); + if (static_cast(val_flt) != v) { + val_flt = v < 0 ? -INFINITY : INFINITY; + } + } virtual std::string type() const override { return "Integer"; } virtual int64_t as_int() const override { return val_int; } - virtual double as_float() const override { return static_cast(val_int); } + virtual double as_float() const override { return val_flt; } virtual string as_string() const override { return std::to_string(val_int); } virtual bool as_bool() const override { return val_int != 0; } virtual const func_builtins & get_builtins() const override; + virtual bool is_numeric() const override { return true; } + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + return hasher(typeid(*this)) + .update(&val_int, sizeof(val_int)) + .update(&val_flt, sizeof(val_flt)); + } +protected: + virtual bool equivalent(const value_t & other) const override { + return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt; + } + virtual bool nonequal(const value_t & other) const override { + return !(typeid(*this) == typeid(other) && val_int == other.val_int); + } }; using value_int = std::shared_ptr; struct value_float_t : public value_t { - value_float_t(double v) { val_flt = v; } + value val; + value_float_t(double v) { + val_flt = v; + val_int = std::isfinite(v) ? static_cast(v) : 0; + val = mk_val(val_int); + } virtual std::string type() const override { return "Float"; } virtual double as_float() const override { return val_flt; } - virtual int64_t as_int() const override { return static_cast(val_flt); } + virtual int64_t as_int() const override { return val_int; } virtual string as_string() const override { std::string out = std::to_string(val_flt); out.erase(out.find_last_not_of('0') + 1, std::string::npos); // remove trailing zeros @@ -226,6 +260,24 @@ struct value_float_t : public value_t { return val_flt != 0.0; } virtual const func_builtins & get_builtins() const override; + virtual bool is_numeric() const override { return true; } + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + if (static_cast(val_int) == val_flt) { + return val->unique_hash(); + } else { + return hasher(typeid(*this)) + .update(&val_int, sizeof(val_int)) + .update(&val_flt, sizeof(val_flt)); + } + } +protected: + virtual bool equivalent(const value_t & other) const override { + return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt; + } + virtual bool nonequal(const value_t & other) const override { + return !(typeid(*this) == typeid(other) && val_flt == other.val_flt); + } }; using value_float = std::shared_ptr; @@ -247,19 +299,49 @@ struct value_string_t : public value_t { return val_str.length() > 0; } virtual const func_builtins & get_builtins() const override; + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + const auto type_hash = typeid(*this).hash_code(); + auto hash = hasher(); + hash.update(&type_hash, sizeof(type_hash)); + val_str.hash_update(hash); + return hash; + } void mark_input() { val_str.mark_input(); } +protected: + virtual bool equivalent(const value_t & other) const override { + return typeid(*this) == typeid(other) && val_str.str() == other.val_str.str(); + } }; using value_string = std::shared_ptr; struct value_bool_t : public value_t { - value_bool_t(bool v) { val_bool = v; } + value val; + value_bool_t(bool v) { + val_int = static_cast(v); + val_flt = static_cast(v); + val = mk_val(val_int); + } virtual std::string type() const override { return "Boolean"; } - virtual bool as_bool() const override { return val_bool; } - virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); } + virtual int64_t as_int() const override { return val_int; } + virtual bool as_bool() const override { return val_int; } + virtual string as_string() const override { return std::string(val_int ? "True" : "False"); } virtual const func_builtins & get_builtins() const override; + virtual bool is_numeric() const override { return true; } + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + return val->unique_hash(); + } +protected: + virtual bool equivalent(const value_t & other) const override { + return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt; + } + virtual bool nonequal(const value_t & other) const override { + return !(typeid(*this) == typeid(other) && val_int == other.val_int); + } }; using value_bool = std::shared_ptr; @@ -269,13 +351,34 @@ struct value_array_t : public value_t { value_array_t(value & v) { val_arr = v->val_arr; } + value_array_t(std::vector && arr) { + val_arr = arr; + } value_array_t(const std::vector & arr) { val_arr = arr; } - void reverse() { std::reverse(val_arr.begin(), val_arr.end()); } - void push_back(const value & val) { val_arr.push_back(val); } - void push_back(value && val) { val_arr.push_back(std::move(val)); } + void reverse() { + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } + std::reverse(val_arr.begin(), val_arr.end()); + } + void push_back(const value & val) { + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } + val_arr.push_back(val); + } + void push_back(value && val) { + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } + val_arr.push_back(std::move(val)); + } value pop_at(int64_t index) { + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } if (index < 0) { index = static_cast(val_arr.size()) + index; } @@ -287,67 +390,228 @@ struct value_array_t : public value_t { return val; } virtual std::string type() const override { return "Array"; } + virtual bool is_immutable() const override { return false; } virtual const std::vector & as_array() const override { return val_arr; } virtual string as_string() const override { + const bool immutable = is_immutable(); std::ostringstream ss; - ss << "["; + ss << (immutable ? "(" : "["); for (size_t i = 0; i < val_arr.size(); i++) { if (i > 0) ss << ", "; - ss << val_arr.at(i)->as_repr(); + value val = val_arr.at(i); + ss << value_to_string_repr(val); } - ss << "]"; + if (immutable && val_arr.size() == 1) { + ss << ","; + } + ss << (immutable ? ")" : "]"); return ss.str(); } virtual bool as_bool() const override { return !val_arr.empty(); } + virtual value & at(int64_t index, value & default_val) override { + if (index < 0) { + index += val_arr.size(); + } + if (index < 0 || static_cast(index) >= val_arr.size()) { + return default_val; + } + return val_arr[index]; + } + virtual value & at(int64_t index) override { + if (index < 0) { + index += val_arr.size(); + } + if (index < 0 || static_cast(index) >= val_arr.size()) { + throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size())); + } + return val_arr[index]; + } virtual const func_builtins & get_builtins() const override; + virtual bool is_hashable() const override { + if (std::all_of(val_arr.begin(), val_arr.end(), [&](auto & val) -> bool { + return val->is_immutable() && val->is_hashable(); + })) { + return true; + } + return false; + } + virtual hasher unique_hash() const noexcept override { + auto hash = hasher(typeid(*this)); + for (const auto & val : val_arr) { + // must use digest to prevent problems from "concatenation" property of hasher + // for ex. hash of [ "ab", "c" ] should be different from [ "a", "bc" ] + const size_t val_hash = val->unique_hash().digest(); + hash.update(&val_hash, sizeof(size_t)); + } + return hash; + } +protected: + virtual bool equivalent(const value_t & other) const override { + return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence()); + } }; using value_array = std::shared_ptr; +struct value_tuple_t : public value_array_t { + value_tuple_t(value & v) { + val_arr = v->val_arr; + } + value_tuple_t(std::vector && arr) { + val_arr = arr; + } + value_tuple_t(const std::vector & arr) { + val_arr = arr; + } + value_tuple_t(const std::pair & pair) { + val_arr.push_back(pair.first); + val_arr.push_back(pair.second); + } + virtual std::string type() const override { return "Tuple"; } + virtual bool is_immutable() const override { return true; } +}; +using value_tuple = std::shared_ptr; + + struct value_object_t : public value_t { + std::unordered_map unordered; bool has_builtins = true; // context and loop objects do not have builtins value_object_t() = default; value_object_t(value & v) { val_obj = v->val_obj; + for (const auto & pair : val_obj) { + unordered[pair.first] = pair.second; + } } - value_object_t(const std::map & obj) { + value_object_t(const std::map & obj) { for (const auto & pair : obj) { - val_obj.insert(pair.first, pair.second); + insert(pair.first, pair.second); } } - value_object_t(const std::vector> & obj) { + value_object_t(const std::vector> & obj) { for (const auto & pair : obj) { - val_obj.insert(pair.first, pair.second); + insert(pair.first, pair.second); } } void insert(const std::string & key, const value & val) { - val_obj.insert(key, val); + insert(mk_val(key), val); } virtual std::string type() const override { return "Object"; } - virtual const std::vector> & as_ordered_object() const override { return val_obj.ordered; } + virtual bool is_immutable() const override { return false; } + virtual const std::vector> & as_ordered_object() const override { return val_obj; } + virtual string as_string() const override { + std::ostringstream ss; + ss << "{"; + for (size_t i = 0; i < val_obj.size(); i++) { + if (i > 0) ss << ", "; + auto & [key, val] = val_obj.at(i); + ss << value_to_string_repr(key) << ": " << value_to_string_repr(val); + } + ss << "}"; + return ss.str(); + } virtual bool as_bool() const override { - return !val_obj.unordered.empty(); + return !unordered.empty(); + } + virtual bool has_key(const value & key) override { + if (!key->is_immutable() || !key->is_hashable()) { + throw std::runtime_error("Object key of unhashable type: " + key->type()); + } + return unordered.find(key) != unordered.end(); + } + virtual void insert(const value & key, const value & val) override { + bool replaced = false; + if (is_immutable()) { + throw std::runtime_error("Attempting to modify immutable type"); + } + if (has_key(key)) { + // if key exists, replace value in ordered list instead of appending + for (auto & pair : val_obj) { + if (*(pair.first) == *key) { + pair.second = val; + replaced = true; + break; + } + } + } + unordered[key] = val; + if (!replaced) { + val_obj.push_back({key, val}); + } + } + virtual value & at(const value & key, value & default_val) override { + if (!has_key(key)) { + return default_val; + } + return unordered.at(key); + } + virtual value & at(const value & key) override { + if (!has_key(key)) { + throw std::runtime_error("Key '" + key->as_string().str() + "' not found in value of type " + type()); + } + return unordered.at(key); + } + virtual value & at(const std::string & key, value & default_val) override { + value key_val = mk_val(key); + return at(key_val, default_val); + } + virtual value & at(const std::string & key) override { + value key_val = mk_val(key); + return at(key_val); } virtual const func_builtins & get_builtins() const override; + virtual bool is_hashable() const override { + if (std::all_of(val_obj.begin(), val_obj.end(), [&](auto & pair) -> bool { + const auto & val = pair.second; + return val->is_immutable() && val->is_hashable(); + })) { + return true; + } + return false; + } + virtual hasher unique_hash() const noexcept override { + auto hash = hasher(typeid(*this)); + for (const auto & [key, val] : val_obj) { + // must use digest to prevent problems from "concatenation" property of hasher + // for ex. hash of key="ab", value="c" should be different from key="a", value="bc" + const size_t key_hash = key->unique_hash().digest(); + const size_t val_hash = val->unique_hash().digest(); + hash.update(&key_hash, sizeof(key_hash)); + hash.update(&val_hash, sizeof(val_hash)); + } + return hash; + } +protected: + virtual bool equivalent(const value_t & other) const override { + return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence()); + } }; using value_object = std::shared_ptr; // -// null and undefined types +// none and undefined types // struct value_none_t : public value_t { virtual std::string type() const override { return "None"; } virtual bool is_none() const override { return true; } virtual bool as_bool() const override { return false; } + virtual string as_string() const override { return string(type()); } virtual std::string as_repr() const override { return type(); } virtual const func_builtins & get_builtins() const override; + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + return hasher(typeid(*this)); + } +protected: + virtual bool equivalent(const value_t & other) const override { + return typeid(*this) == typeid(other); + } }; using value_none = std::shared_ptr; - struct value_undefined_t : public value_t { std::string hint; // for debugging, to indicate where undefined came from value_undefined_t(const std::string & h = "") : hint(h) {} @@ -356,6 +620,13 @@ struct value_undefined_t : public value_t { virtual bool as_bool() const override { return false; } virtual std::string as_repr() const override { return type(); } virtual const func_builtins & get_builtins() const override; + virtual hasher unique_hash() const noexcept override { + return hasher(typeid(*this)); + } +protected: + virtual bool equivalent(const value_t & other) const override { + return is_undefined() == other.is_undefined(); + } }; using value_undefined = std::shared_ptr; @@ -436,7 +707,23 @@ struct value_func_t : public value_t { return val_func(new_args); } virtual std::string type() const override { return "Function"; } - virtual std::string as_repr() const override { return type(); } + virtual std::string as_repr() const override { return type() + "<" + name + ">(" + (arg0 ? arg0->as_repr() : "") + ")"; } + virtual bool is_hashable() const override { return false; } + virtual hasher unique_hash() const noexcept override { + // Note: this is unused for now, we don't support function as object keys + // use function pointer as unique identifier + const auto target = val_func.target(); + return hasher(typeid(*this)).update(&target, sizeof(target)); + } +protected: + virtual bool equivalent(const value_t & other) const override { + // Note: this is unused for now, we don't support function as object keys + // compare function pointers + // (val_func == other.val_func does not work as std::function::operator== is only used for nullptr check) + const auto target_this = this->val_func.target(); + const auto target_other = other.val_func.target(); + return typeid(*this) == typeid(other) && target_this == target_other; + } }; using value_func = std::shared_ptr; @@ -447,18 +734,21 @@ struct value_kwarg_t : public value_t { value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {} virtual std::string type() const override { return "KwArg"; } virtual std::string as_repr() const override { return type(); } + virtual bool is_hashable() const override { return true; } + virtual hasher unique_hash() const noexcept override { + const auto type_hash = typeid(*this).hash_code(); + auto hash = val->unique_hash(); + hash.update(&type_hash, sizeof(type_hash)) + .update(key.data(), key.size()); + return hash; + } +protected: + virtual bool equivalent(const value_t & other) const override { + const value_kwarg_t & other_val = static_cast(other); + return typeid(*this) == typeid(other) && key == other_val.key && val == other_val.val; + } }; using value_kwarg = std::shared_ptr; -// utils - -const func_builtins & global_builtins(); -std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": "); - -struct not_implemented_exception : public std::runtime_error { - not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {} -}; - - } // namespace jinja diff --git a/common/ngram-cache.cpp b/common/ngram-cache.cpp index d1a4d84c40f..dce54b36474 100644 --- a/common/ngram-cache.cpp +++ b/common/ngram-cache.cpp @@ -192,12 +192,12 @@ void common_ngram_cache_draft( break; } - LOG(" - draft candidate: token=%d\n", drafted_token); + LOG_DBG(" - draft candidate: token=%d\n", drafted_token); draft.push_back(drafted_token); } } -void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) { +void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) { std::ofstream file_out(filename, std::ios::binary); for (std::pair item : ngram_cache) { const common_ngram ngram = item.first; @@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil file_out.write(reinterpret_cast(&count), sizeof(int32_t)); } } - } -common_ngram_cache common_ngram_cache_load(std::string & filename) { +common_ngram_cache common_ngram_cache_load(const std::string & filename) { std::ifstream hashmap_file(filename, std::ios::binary); if (!hashmap_file) { throw std::ifstream::failure("Unable to open file " + filename); diff --git a/common/ngram-cache.h b/common/ngram-cache.h index dfe012abe49..6e7cfea966d 100644 --- a/common/ngram-cache.h +++ b/common/ngram-cache.h @@ -88,12 +88,12 @@ void common_ngram_cache_draft( // Save an ngram cache to a file. // ngram_cache: the ngram cache to save. // filename: the path under which to save the ngram cache. -void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename); +void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename); // Load an ngram cache saved with common_ngram_cache_save. // filename: the path from which to load the ngram cache. // returns: an ngram cache containing the information saved to filename. -common_ngram_cache common_ngram_cache_load(std::string & filename); +common_ngram_cache common_ngram_cache_load(const std::string & filename); // Merge two ngram caches. // ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add. diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp new file mode 100644 index 00000000000..c5b8fc75ed8 --- /dev/null +++ b/common/ngram-map.cpp @@ -0,0 +1,531 @@ +#include "common.h" +#include "log.h" +#include "ngram-map.h" + +#include +#include +#include +#include + +// prime number used for LCG hash function (32 bit), it is near (sqrt(5) - 1)/2 * 2^32. +#define LCG_FACTOR 2654435761UL + +// Compute the LCG hash of a n-gram of size len at offset start. +static uint32_t common_ngram_map_hash(const llama_tokens & tokens, size_t start, size_t len) { + uint32_t hash = 0; + for (size_t i = 0; i < len; ++i) { + hash = hash * LCG_FACTOR + tokens[start + i]; + } + return hash; +} + +// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...]. +static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) { + std::ostringstream oss; + oss << '['; + for (size_t i = 0; i < length; ++i) { + if (i > 0) { + oss << ", "; + } + oss << inp[start + i]; + } + oss << ']'; + return oss.str(); +} + + +// n-gram simple +// + +/** + * Perform speculative generation using the model's own token history. + * Searches for a matching pattern in the token history and returns draft tokens. + * + * @param state Current state of this implementation + * @param tokens Token history to search in + * @param sampled Last sampled token + * @return Vector of draft tokens, empty if no matching pattern is found + */ +llama_tokens common_ngram_simple_draft( + const common_ngram_simple_config & config, + const llama_tokens & tokens, llama_token sampled) { + + // Simple implementation of self-speculative decoding without a draft model. + // + const size_t cur_len = tokens.size(); + + const size_t n_draft_min = config.size_ngram; // size of n-gram to lookup in token history + const size_t n_draft_max = config.size_mgram; // the m-gram following the found n-gram is used for draft + + // vector for tokens we want to verify. + // return empty vector if there is no match. + llama_tokens draft_tokens; + + // We need at least n_draft_min + n_draft_max + 1 tokens. + if (cur_len <= static_cast(n_draft_min + n_draft_max + 1)) { + return draft_tokens; + } + + // pattern search + llama_tokens pattern; + pattern.reserve(n_draft_min); + for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) { + pattern.push_back(tokens[j]); + } + pattern.push_back(sampled); // add the last token to the pattern + + size_t match_pos = 0; // we ignore position 0, position 0 == no match + // search backwards, but skip the current match (we are currently there) + for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) { + bool match = true; + for (size_t k = 0; k < pattern.size(); ++k) { + if (tokens[j + k] != pattern[k]) { + match = false; + break; + } + } + if (match) { + match_pos = j; + break; + } + } + if (match_pos == 0) { + return draft_tokens; + } + + const size_t copy_max = std::min( + n_draft_max, + cur_len - (match_pos + n_draft_min) + ); + if (copy_max < n_draft_min) { + return draft_tokens; + } + LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n", + __func__, cur_len, + match_pos, pattern.size(), copy_max); + + draft_tokens.reserve(copy_max); + for (size_t j = 0; j < copy_max; ++j) { + draft_tokens.push_back(tokens[match_pos + n_draft_min + j]); + } + return draft_tokens; +} + + +// n-gram map +// + +// maximum number of counted values of a ngram map value. +#define COMMON_NGRAM_MAX_VALUE_COUNT 16380 + +void common_ngram_map_begin( + common_ngram_map & map, const llama_tokens & tokens) { + size_t size_begin = tokens.size(); + + LOG_DBG("%s: begin, idx_last_draft=%zu, new begin=%zu, #keys=%zu\n", __func__, + map.idx_last_check, size_begin, map.keys.size()); + + size_t count_map_entries_upd = 0; + if (!map.key_map.empty() && size_begin < map.idx_last_check) { + if (map.show_key_map_stats) { + // Print statistics of hash map map_key. + size_t count_nonzero = 0; + uint32_t min_idx = UINT32_MAX; + uint32_t max_idx = 0; + for (size_t i = 0; i < map.key_map.size(); ++i) { + uint32_t key_idx = map.key_map[i]; + if (key_idx != 0) { + ++count_nonzero; + if (key_idx < min_idx) min_idx = key_idx; + if (key_idx > max_idx) max_idx = key_idx; + } + } + if (count_nonzero == 0) { + min_idx = 0; + } + LOG_INF("%s: key_map stats: entries=%zu, min_idx=%u, max_idx=%u, key_map_last_idx=%u\n", + __func__, count_nonzero, min_idx, max_idx, map.key_map_last_idx); + } + + // Update the map from hash to key index (clear outdated entries). + for (size_t i = 0; i < map.key_map.size(); ++i) { + uint32_t key_idx = map.key_map[i]; + if (key_idx >= map.size_last_begin) { + map.key_map[i] = 0; + count_map_entries_upd++; + } + } + map.key_map_last_idx = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0; + } + + if (size_begin < map.idx_last_check && !map.keys.empty()) { + // The next token generation will start at index size_begin. + // The tokens between map.size_last_begin and size_begin are no longer valid. + // + // Refresh map: Remove all entries with index >= map.size_last_begin. + size_t count_keys = map.keys.size(); + size_t count_keys_del = 0; + size_t count_values_del = 0; + for (int32_t i = map.keys.size() - 1; i >= 0; --i) { + common_ngram_map_key & key = map.keys[i]; + if (key.key_idx >= map.size_last_begin) { + // Delete the key. + LOG_DBG("%s: delete key %d at index %zu (>= size_last_begin=%zu)\n", __func__, i, key.key_idx, map.size_last_begin); + map.keys.erase(map.keys.begin() + i); + count_keys_del++; + continue; + } + if (map.key_only) { + continue; + } + + // Check the indices of the values. + for (int16_t j = COMMON_NGRAM_MAX_VALUES - 1; j >= 0; --j) { + common_ngram_map_value & value = key.values[j]; + if (value.value_idx >= map.size_last_begin) { + // Delete the value. + count_values_del++; + + // Move all values after this value to the left. + for (uint16_t k = j; k < COMMON_NGRAM_MAX_VALUES - 1; ++k) { + key.values[k] = key.values[k + 1]; + } + // Clear the last value. + key.values[COMMON_NGRAM_MAX_VALUES - 1].value_idx = 0; + key.values[COMMON_NGRAM_MAX_VALUES - 1].value_num = 0; + } + } + if (key.values[0].value_idx == 0) { + // No values left, delete the key. + LOG_DBG("%s: delete key %d at index %zu (no values left)\n", __func__, i, key.key_idx); + map.keys.erase(map.keys.begin() + i); + count_keys_del++; + } + } + + LOG_INF("%s: refresh map: idx_last_draft=%zu, new begin=%zu, #keys_checked=%zu, #keys_del=%zu, #values_del=%zu, #hashes_upd=%zu\n", __func__, + map.idx_last_check, size_begin, + count_keys, count_keys_del, count_values_del, count_map_entries_upd); + } + + map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0; + map.size_last_begin = size_begin; +} + +void common_ngram_map_draft(common_ngram_map & map, + const llama_tokens & inp, llama_token sampled, + llama_tokens & draft) { + // reset last key and value. + map.last_draft_created = false; + map.last_draft_key_idx = 0; + map.last_draft_value_idx = 0; + + const size_t cur_len = inp.size(); + const uint16_t n = map.size_key; + const uint16_t m = map.size_value; + if (cur_len < static_cast(2 * n + m)) { + return; + } + if (cur_len >= static_cast(UINT32_MAX)) { + // key_map uses uint32_t instead of size_t. + GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len); + } + + // Only check every check_rate tokens to save compute + // i.e., perform check if (cur_len - idx_last_check) >= check_rate + if (map.idx_last_check + map.check_rate > cur_len) { + return; + } + map.idx_last_check = cur_len; + + // search pattern, the key n-gram + std::vector key_tokens; + key_tokens.reserve(n); + for (size_t j = cur_len - n + 1; j < cur_len; ++j) { + key_tokens.push_back(inp[j]); + } + key_tokens.push_back(sampled); + + // search for the key in the map + size_t match_pos = 0; + if (map.size_last_begin > cur_len) { + GGML_ABORT("%s: map.size_last_begin > cur_len: %zu > %zu", __func__, map.size_last_begin, cur_len); + } + if (!map.key_map.empty()) { + // Search for the key in the map key_map from hash of ngrams to index of ngram. + uint32_t idx_hash = (common_ngram_map_hash(key_tokens, 0, n) % map.key_map.size()); + uint32_t idx_key = map.key_map[idx_hash]; + if (idx_key != 0 && idx_key < cur_len - n - m - 1) { + // Check if the key matches the key at idx_key (because of possible collisions). + bool match = true; + for (size_t k = 0; k < n; ++k) { + if (inp[idx_key + k] != key_tokens[k]) { + match = false; + break; + } + } + LOG_DBG("%s: key hash %x -> idx_key %d: match %d\n", __func__, idx_hash, idx_key, match ? 1 : 0); + if (match) { + match_pos = idx_key; + } + } + } + if (match_pos == 0 && map.size_last_begin > (size_t) (n + m + 1)) { + // Search for the key in [1, map.size_last_begin - n - m -1], descending. + for (size_t j = map.size_last_begin - n - m - 1; j > map.key_map_last_idx; --j) { + // Check if the key matches the key. + bool match = true; + for (size_t k = 0; k < n; ++k) { + if (inp[j + k] != key_tokens[k]) { + match = false; + break; + } + } + if (match) { + match_pos = j; + break; + } + } + } + if (match_pos == 0) { + // In case of a reasoning chat, the part after size_last_begin may be deleted/reordered later. + // + // Search in [size_last_begin, cur_len - n - m - 1], descending. + for (size_t j = cur_len - n - m - 1; j > map.size_last_begin && j > map.key_map_last_idx; --j) { + bool match = true; + for (size_t k = 0; k < n; ++k) { + if (inp[j + k] != key_tokens[k]) { + match = false; + break; + } + } + if (match) { + match_pos = j; + break; + } + } + } + if (match_pos > 0) { + LOG_DBG("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__, + cur_len, n, m, key_tokens.size(), sampled, match_pos); + } + + if (!map.key_map.empty()) { + // Add hashes of new ngrams in key_map. + // + // Use the same order as above. + if (map.size_last_begin > (size_t) (n + m + 1)) { + for (size_t j = map.size_last_begin - n - m - 1; j > map.key_map_last_idx; --j) { + // compute hash and store index of ngram at idx j in the map. + uint32_t idx_hash = (common_ngram_map_hash(inp, j, n) % map.key_map.size()); + if (map.key_map[idx_hash] == 0) { + map.key_map[idx_hash] = j; // collisions may occur + } + } + } + + for (size_t j = cur_len - n - m - 1; j > map.size_last_begin && j > map.key_map_last_idx; --j) { + // compute hash and store index of ngram at idx j in the map. + uint32_t idx_hash = (common_ngram_map_hash(inp, j, n) % map.key_map.size()); + if (map.key_map[idx_hash] == 0) { + map.key_map[idx_hash] = j; + } + } + map.key_map_last_idx = std::max(static_cast(cur_len - n - m - 1), map.key_map_last_idx); + } + + if (match_pos == 0) { + return; + } + + // We have a match, now we look for the statistics of the key. + size_t key_offset = map.keys.size(); // offset in the map + // We iterate through the std::vector map->keys. + for (size_t i = 0; i < map.keys.size(); ++i) { + bool match = true; + for (size_t j = 0; j < n; ++j) { + if (inp[map.keys[i].key_idx + j] != key_tokens[j]) { + match = false; + break; + } + } + if (match) { + key_offset = i; + break; + } + } + if (key_offset == map.keys.size()) { + // We create a new key-entry, it will get offset key_offset. + common_ngram_map_key new_key; + new_key.key_idx = match_pos; + new_key.stat_idx = 0; + new_key.key_num = 0; + for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) { + new_key.values[i].value_num = 0; + new_key.values[i].n_accepted = m; + } + map.keys.push_back(new_key); + } + + // our key n-gram: + common_ngram_map_key & curr_key = map.keys[key_offset]; + + // update number of key hits + curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1, + (int) COMMON_NGRAM_MAX_VALUE_COUNT); + + if (map.key_only) { + // simple mode: + // Fill in the draft with the m tokens following the key. + // We work with value values[0] only. + int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted); + + for (int i = 0; i < n_draft_tokens; ++i) { + draft.push_back(inp[match_pos + n + i]); + } + + LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__, + curr_key.key_idx, key_offset, curr_key.key_num, draft.size()); + + map.last_draft_created = false; + map.last_draft_key_idx = key_offset; + map.last_draft_value_idx = 0; // value 0 is used for simple mode + return; + } + + if (curr_key.key_num < map.min_hits) { + // not enough hits to consider this a good draft + LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__, + key_offset, curr_key.key_num, map.min_hits); + return; + } + + // complex mode: examine the different m-grams after this key n-gram. + // + + // determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram. + for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) { + // begins the key n-gram at index i? + bool match_key = true; + for (size_t k = 0; k < n; ++k) { + if (inp[i + k] != key_tokens[k]) { + match_key = false; + break; + } + } + if (!match_key) { + continue; + } + + // Do we haven a existing value m-gram or a new one after the key at index i? + size_t idx_begin_value_key = i + n; + int idx_value = -1; + for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { + size_t idx_begin_value_v = curr_key.values[v].value_idx; + if (idx_begin_value_v == 0) { + // We found an empty value slot => we found a new value m-gram after the key n-gram. + curr_key.values[v].value_idx = idx_begin_value_key; + curr_key.values[v].value_num = 0; + curr_key.values[v].n_accepted = m; + idx_value = v; + break; + } + bool match = true; + for (size_t j = 0; j < m; ++j) { + if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) { + match = false; + break; + } + } + if (match) { + // We found an existing value m-gram after the key n-gram. + idx_value = v; + break; + } + } + if (idx_value >= 0) { + // We found a value m-gram of the key n-gram. + curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1, + (int) COMMON_NGRAM_MAX_VALUE_COUNT); + } + } + // the statistics are updated up to match_pos. + curr_key.stat_idx = match_pos; + + // Do we have a value we could use for the draft? + uint16_t max_occur = 0; + int slot_max = 0; + for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { + uint16_t curr_occur = curr_key.values[v].value_num; + if (curr_occur > max_occur) { + max_occur = curr_occur; + slot_max = v; + } + } + // What is sum of the other occurences? + uint32_t sum_occur = 0; + for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { + if (v == slot_max) { + continue; + } + uint16_t curr_occur = curr_key.values[v].value_num; + sum_occur += curr_occur; + } + + LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__, + key_offset, + max_occur, sum_occur, slot_max, + curr_key.values[0].value_idx, curr_key.values[0].value_num, + curr_key.values[1].value_idx, curr_key.values[1].value_num, + curr_key.values[2].value_idx, curr_key.values[2].value_num, + curr_key.values[3].value_idx, curr_key.values[3].value_num + ); + // Print the tokens of the four values (if idx != 0), use LOG_INF + for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { + if (curr_key.values[v].value_idx != 0) { + LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str()); + } + } + + if (sum_occur > 0 && max_occur < 2 * sum_occur) { + // The most frequent value is not much more frequent than the other values. + // We do not use the draft. + return; + } + + // We use the most frequent value values[slot_max] for the draft. + // Fill in the draft with the m tokens following the key. + int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted); + + for (int i = 0; i < n_draft_tokens; ++i) { + draft.push_back(inp[match_pos + n + i]); + } + + LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__, + key_offset, slot_max, + curr_key.key_num, draft.size()); + + map.last_draft_created = true; + map.last_draft_key_idx = key_offset; + map.last_draft_value_idx = slot_max; // value used for draft generation. +} + +void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) { + if (!map.last_draft_created) { + return; + } + + // find the key and its chosen value. + const size_t key_idx = map.last_draft_key_idx; + const size_t val_idx = map.last_draft_value_idx; + + // find key corresponding to key_idx. + common_ngram_map_key & curr_key = map.keys[key_idx]; + // find value corresponding to val_idx. + struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation. + + // update the value statistics + LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", + n_accepted, curr_value.n_accepted); + curr_value.n_accepted = n_accepted; +} diff --git a/common/ngram-map.h b/common/ngram-map.h new file mode 100644 index 00000000000..9668bd5a7c5 --- /dev/null +++ b/common/ngram-map.h @@ -0,0 +1,117 @@ +#pragma once +// +// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams +// +// These structures are used to do a lookup of n-grams followed by m-grams in token history. +// +// There are two algorithms implemented: +// 1. ngram_simple: lookup of n-grams followed by m-grams in token history. +// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map. +// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams. +// +// ref: https://github.com/ggml-org/llama.cpp/pull/18471 +// + +#include "llama.h" +#include "common.h" + +#include + +// n-gram simple +// + +// config of n-gram simple. +struct common_ngram_simple_config { + uint16_t size_ngram; // size of n-grams to lookup in self-mode + uint16_t size_mgram; // size of m-grams to draft in self-mode + uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token +}; + +// Searches for a n-gram in the history and checks whether a draft sequence should be generated. +llama_tokens common_ngram_simple_draft( + const common_ngram_simple_config & config, + const llama_tokens & tokens, llama_token sampled); + + +// n-gram map +// + +// maximum number of m-gram values stored for each key n-gram. +#define COMMON_NGRAM_MAX_VALUES 4 + +// number of entries in the (optional, size 0 to disable) map from ngram-hash to ngram-index. +#define COMMON_NGRAM_HASH_MAP_SIZE 262144 + +// statistics of a m-gram after a known n-gram +struct common_ngram_map_value { + size_t value_idx = 0; // index of value m-gram in token-history (0 if unused) + uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot) + int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused) +}; + +// statistics of a n-gram +struct common_ngram_map_key { + size_t key_idx; // index of key n-gram in token-history + size_t stat_idx; // index of last token of stastistics computation (key_num, values) + + uint16_t key_num; // number of occurences of this key n-gram in token-history + common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key +}; + +// map from n-grams to following m-grams in token-history +struct common_ngram_map { + uint16_t size_key; // size of key n-grams + uint16_t size_value; // size of value m-grams + + bool key_only; // true if only key n-grams are used, no values. + + std::vector keys; // key n-grams which occur several times in token-history + uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token + uint16_t min_hits; // minimum number of key hits to consider a draft + + bool show_key_map_stats = false; // true, if statitics of the key_map should be printed. + + common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys, + uint16_t check_rate, uint16_t min_hits) + : size_key(sz_key), size_value(sz_value), key_only(only_keys), + check_rate(check_rate), min_hits(min_hits) { + key_map.resize(COMMON_NGRAM_HASH_MAP_SIZE); // 2^18 hash entries, 0 entries if key_map shouldn't be used + } + + // In reasoning chats the previous reasoning block will be removed from context history. + // A rebuild of the ngram map is needed after that. + + size_t size_last_begin = 0; // number of tokens at previous start of generation + + bool last_draft_created = false; // true if a draft was created at last call. + size_t last_draft_key_idx = 0; // index of last key used for draft generation (0 = no draft) + uint16_t last_draft_value_idx = 0; // index of last value used for draft generation. + + size_t idx_last_check = 0; // index of last check in context history + + // optional map "hash to ngram-index" for faster lookup of n-grams. map is empty if unused. + // + // uint32_t instead of size_t (size of current histories is << UINT32_MAX) + std::vector key_map; // key_map[hash] = index of ngram in context window + uint32_t key_map_last_idx = 0; // index of the last ngram added to key_map +}; + +// Initialize the n-gram map with the given token history. +// map: the ngram map to initialize. +// tokens: the token history to base the map on. +void common_ngram_map_begin( + common_ngram_map & map, + const llama_tokens & tokens); + +// Searches for the n-gram in the history and checks whether a draft sequence should be generated. +// map: the ngram map to search in. +// inp: the tokens generated so far. +// sampled: the token that was just sampled. +// draft: vector to store the draft tokens, initially empty. +void common_ngram_map_draft( + common_ngram_map & map, + const llama_tokens & inp, llama_token sampled, + llama_tokens & draft); + +// Update the statistics of a value after a draft was processed. +void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted); diff --git a/common/ngram-mod.cpp b/common/ngram-mod.cpp new file mode 100644 index 00000000000..76f7257f611 --- /dev/null +++ b/common/ngram-mod.cpp @@ -0,0 +1,60 @@ +#include "ngram-mod.h" + +// +// common_ngram_mod +// + +common_ngram_mod::common_ngram_mod(uint16_t n, size_t size) : n(n), used(0) { + entries.resize(size); + + reset(); +} + +size_t common_ngram_mod::idx(const entry_t * tokens) const { + size_t res = 0; + + for (size_t i = 0; i < n; ++i) { + res = res*6364136223846793005ULL + tokens[i]; + } + + res = res % entries.size(); + + return res; +} + +void common_ngram_mod::add(const entry_t * tokens) { + const size_t i = idx(tokens); + + if (entries[i] == EMPTY) { + used++; + } + + entries[i] = tokens[n]; +} + +common_ngram_mod::entry_t common_ngram_mod::get(const entry_t * tokens) const { + const size_t i = idx(tokens); + + return entries[i]; +} + +void common_ngram_mod::reset() { + std::fill(entries.begin(), entries.end(), EMPTY); + used = 0; +} + +size_t common_ngram_mod::get_n() const { + return n; +} + +size_t common_ngram_mod::get_used() const { + return used; +} + +size_t common_ngram_mod::size() const { + return entries.size(); +} + +size_t common_ngram_mod::size_bytes() const { + return entries.size() * sizeof(entries[0]); +} diff --git a/common/ngram-mod.h b/common/ngram-mod.h new file mode 100644 index 00000000000..7af92e9dde4 --- /dev/null +++ b/common/ngram-mod.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +// +// common_ngram_mod +// ref: https://github.com/ggml-org/llama.cpp/pull/19164 +// + +// basic n-gram hasher +struct common_ngram_mod { + using entry_t = int32_t; + + static constexpr entry_t EMPTY = -1; + + common_ngram_mod(uint16_t n, size_t size); + + size_t idx(const entry_t * tokens) const; + void add(const entry_t * tokens); + entry_t get(const entry_t * tokens) const; // return -1 if not found + + void reset(); + + size_t get_n() const; + size_t get_used() const; + + size_t size() const; + size_t size_bytes() const; + +private: + size_t n; // ngram size to hash + + size_t used; + + std::vector entries; +}; diff --git a/common/speculative.cpp b/common/speculative.cpp index 3e83b0964c8..84d2556ceba 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1,99 +1,57 @@ #include "speculative.h" +#include "common.h" #include "ggml.h" #include "llama.h" #include "log.h" -#include "common.h" +#include "ngram-cache.h" +#include "ngram-map.h" +#include "ngram-mod.h" #include "sampling.h" -#include #include +#include +#include #include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 -struct common_speculative { - struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft - struct llama_context * ctx_dft; - struct common_sampler * smpl; - - llama_batch batch; - llama_tokens prompt_dft; - bool vocab_dft_compatible = true; // whether retokenization is needed - std::map tgt_dft_replacements = {}; +const std::vector common_speculative_types = { + COMMON_SPECULATIVE_TYPE_NONE, + COMMON_SPECULATIVE_TYPE_DRAFT, + COMMON_SPECULATIVE_TYPE_EAGLE3, + COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, + COMMON_SPECULATIVE_TYPE_NGRAM_MOD, + COMMON_SPECULATIVE_TYPE_NGRAM_CACHE }; -struct common_speculative * common_speculative_init( - struct llama_context * ctx_tgt, - struct llama_context * ctx_dft) { - auto * result = new common_speculative { - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .smpl = */ nullptr, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), - /* .prompt_dft = */ {}, - /* .vocab_dft_compatible = */ false, - }; - - // TODO: optimize or pass from outside? -#if 0 - { - common_params_sampling params; - params.no_perf = false; - - params.top_k = 40; - params.top_p = 0.9; - - params.samplers = { - COMMON_SAMPLER_TYPE_TOP_K, - COMMON_SAMPLER_TYPE_TOP_P, - COMMON_SAMPLER_TYPE_INFILL, - }; - - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); - } -#else - { - common_params_sampling params; - params.no_perf = false; - - params.top_k = 10; - - params.samplers = { - COMMON_SAMPLER_TYPE_TOP_K, - }; - - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); - } -#endif - - result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft); - LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible); - - return result; -} - -void common_speculative_free(struct common_speculative * spec) { - if (spec == nullptr) { - return; - } - - common_sampler_free(spec->smpl); - - llama_batch_free(spec->batch); +const std::map common_speculative_type_from_name_map = { + {"none", COMMON_SPECULATIVE_TYPE_NONE}, + {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, + {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, + {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, + {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, + {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, + {"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, + {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} +}; - delete spec; -} +struct common_speculative_config { + common_speculative_type type; + common_params_speculative params; -bool common_speculative_are_compatible( - const struct llama_context * ctx_tgt, - const struct llama_context * ctx_dft) { - const struct llama_model * model_tgt = llama_get_model(ctx_tgt); - const struct llama_model * model_dft = llama_get_model(ctx_dft); + common_speculative_config(common_speculative_type t, + const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {} +}; - const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); - const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); +static bool common_speculative_are_compatible( + const llama_model * model_tgt, + const llama_model * model_dft) { + const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); @@ -134,11 +92,12 @@ bool common_speculative_are_compatible( for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__); LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i, - common_token_to_piece(ctx_tgt, i).c_str(), - common_token_to_piece(ctx_dft, i).c_str()); + common_token_to_piece(vocab_tgt, i).c_str(), + common_token_to_piece(vocab_dft, i).c_str()); return false; } } @@ -147,215 +106,976 @@ bool common_speculative_are_compatible( return true; } -void common_speculative_add_replacement_tgt_dft( - struct common_speculative * spec, - const char *source, const char *dest) { - spec->tgt_dft_replacements[source] = dest; -} +// state of an implementation of speculative decoding +// +// each implementation has a unique type and a state that is implementation-specific +// in a subclass of common_speculative_state +struct common_speculative_state { + const enum common_speculative_type type; + + // TODO: rename to n_call_draft, n_gen_drafts, n_acc_drafts, n_gen_tokens, n_acc_tokens + // TODO: add n_call_begin, n_call_accept + size_t drafts_call_count = 0; // number of times this implementation was called. + size_t drafts_generated_count = 0; // number of times a draft or part was generated by this implementation. + size_t drafts_accepted_count = 0; // number of times a draft or part was accepted by the target model. + size_t drafts_generated_tokens = 0; // number of tokens generated by this implementation. + size_t drafts_accepted_tokens = 0; // number of tokens accepted by the target model. + + // TODO: track performance of most recent calls + const bool gen_perf = true; // whether to generate performance stats. + + int64_t t_begin_us = 0; // total time spent in refresh of this implementation in microseconds. + int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds. + int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds. -static std::string replace_to_dft( - struct common_speculative * spec, - const std::string& input) { - std::string result = input; - for (const auto & pair : spec->tgt_dft_replacements) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + pair.second.length()); + common_speculative_state(enum common_speculative_type type) : type(type) {} + + virtual ~common_speculative_state() = default; + + virtual void begin(const llama_tokens & prompt) = 0; + + virtual void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) = 0; + + virtual void accept(uint16_t n_accepted) = 0; +}; + +struct common_speculative_state_draft : public common_speculative_state { + llama_context * ctx_tgt; // only used for retokenizing from ctx_dft + llama_context * ctx_dft; + + common_sampler * smpl; + + llama_batch batch; + llama_tokens prompt_dft; + + bool vocab_cmpt = true; // whether retokenization is needed + std::unordered_map vocab_map; + + common_speculative_state_draft( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft, + const std::vector> & replacements) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + { + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); + smpl = nullptr; + + // TODO: optimize or pass from outside? + // { + // common_params_sampling params; + // params.no_perf = false; + // + // params.top_k = 40; + // params.top_p = 0.9; + // + // params.samplers = { + // COMMON_SAMPLER_TYPE_TOP_K, + // COMMON_SAMPLER_TYPE_TOP_P, + // COMMON_SAMPLER_TYPE_INFILL, + // }; + // + // result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + // } + { + common_params_sampling params; + params.no_perf = false; + params.top_k = 10; + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + + smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } + + vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + + if (!vocab_cmpt) { + LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + + for (const auto & pair : replacements) { + vocab_map[pair.first] = pair.second; + } } } - return result; -} -static std::string replace_to_tgt( - struct common_speculative * spec, - const std::string& input) { - std::string result = input; - for (const auto& pair : spec->tgt_dft_replacements) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); + ~common_speculative_state_draft() override { + llama_perf_context_print(ctx_dft); + + llama_free(ctx_dft); + + common_sampler_free(smpl); + + llama_batch_free(batch); + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + auto * spec = this; + + auto & batch = spec->batch; + auto & ctx_tgt = spec->ctx_tgt; + auto & ctx_dft = spec->ctx_dft; + auto & smpl = spec->smpl; + auto & prompt_dft = spec->prompt_dft; + + auto * mem_dft = llama_get_memory(ctx_dft); + + int reuse_i = 0; + int reuse_n = 0; + + const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; + + llama_tokens prompt_cnv; + if (!spec->vocab_cmpt) { + std::string text; + + text = common_detokenize(ctx_tgt, prompt_tgt, true); + text = replace_to_dft(text); + + LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); + + prompt_cnv = common_tokenize(ctx_dft, text, false, true); + + // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation + const auto * model_tgt = llama_get_model(ctx_tgt); + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + + int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); + GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); + + text.resize(-n_chars); + llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); + text = replace_to_dft(text); + + LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); + id_last = common_tokenize(ctx_dft, text, false, true)[0]; + } + + const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; + + const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); + + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt + for (int i = 0; i < (int) prompt_dft.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt_cur.size() && + i + cur < (int) prompt_dft.size() && + prompt_cur[i_start + cur] == prompt_dft[i + cur]) { + cur++; + } + + if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } + } + + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + + result.clear(); + result.reserve(params.n_max); + + if (reuse_n == 0) { + llama_memory_clear(mem_dft, false); + prompt_dft.clear(); + } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. in this case, we simply pass back the previous results to save compute + if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { + result.push_back(prompt_dft[i]); + + if (params.n_max <= (int) result.size()) { + break; + } + } + + return; + } + + if (reuse_i > 0) { + llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); + + prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); + } + + if (reuse_n < (int) prompt_dft.size()) { + llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + } + } + + // prepare a batch to evaluate any new tokens in the prompt + common_batch_clear(batch); + + for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); + common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); + + prompt_dft.push_back(prompt_cur[i]); + } + + // we should rarely end-up here during normal decoding + if (batch.n_tokens > 0) { + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + + llama_decode(ctx_dft, batch); + } + + const llama_pos n_past = prompt_dft.size(); + + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + + common_batch_clear(batch); + common_batch_add (batch, id_last, n_past, { 0 }, true); + + prompt_dft.push_back(id_last); + + LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); + + llama_decode(ctx_dft, batch); + + common_sampler_reset(smpl); + + // sample n_draft tokens from the draft model + for (int i = 0; i < params.n_max; ++i) { + common_batch_clear(batch); + + common_sampler_sample(smpl, ctx_dft, 0, true); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, id, true); + + result.push_back(id); + + if (params.n_max <= (int) result.size()) { + break; + } + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + break; + } + + common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(ctx_dft, batch); + + prompt_dft.push_back(id); + } + + if (!spec->vocab_cmpt) { + std::string detokenized = common_detokenize(ctx_dft, result, true); + detokenized = replace_to_tgt(detokenized); + LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); + result = common_tokenize(ctx_tgt, detokenized, false, true); + if (result.size() > (size_t)params.n_max) { + result.resize(params.n_max); + } } } - return result; -} + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } -llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt_tgt_main_model, // specified in target model vocab - llama_token id_last) { - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; + std::string replace_to_dft(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.first); + while (pos != std::string::npos) { + result.replace(pos, pair.first.length(), pair.second); + pos = result.find(pair.first, pos + pair.second.length()); + } + } + + return result; + } + + std::string replace_to_tgt(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.second); + while (pos != std::string::npos) { + result.replace(pos, pair.second.length(), pair.first); + pos = result.find(pair.second, pos + pair.first.length()); + } + } + + return result; + } +}; + +struct common_speculative_state_eagle3 : public common_speculative_state { + common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & draft_tokens) override { + // TODO: implement + GGML_UNUSED(params); + GGML_UNUSED(prompt_tgt); + GGML_UNUSED(id_last); + GGML_UNUSED(draft_tokens); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; + +// state of self-speculation (simple implementation, not ngram-map) +struct common_speculative_state_ngram_simple : public common_speculative_state { + common_ngram_simple_config config; + + uint16_t check_id = 0; // used to control the frequency of generating drafts + + common_speculative_state_ngram_simple( + enum common_speculative_type type, + common_ngram_simple_config config) + : common_speculative_state(type), config(config) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } - auto * mem_dft = llama_get_memory(ctx_dft); + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + ++check_id; + if (check_id < config.check_rate) { + return; + } + check_id = 0; - int reuse_i = 0; - int reuse_n = 0; + result = common_ngram_simple_draft(config, prompt_tgt, id_last); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; - const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft; +struct common_speculative_state_ngram_map_k : public common_speculative_state { + // draft ngram map for speculative decoding without draft model + common_ngram_map map; - llama_tokens prompt_tgt_draft_model; - if (!spec->vocab_dft_compatible) { - std::string text; - text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true); - text = replace_to_dft(spec, text); - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); - prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true); + common_speculative_state_ngram_map_k( + enum common_speculative_type type, + common_ngram_map map) + : common_speculative_state(type), map(std::move(map)) {} - // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + void begin(const llama_tokens & prompt) override { + common_ngram_map_begin(map, prompt); + } - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(spec, text); + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + common_ngram_map_draft(map, prompt_tgt, id_last, result); + GGML_UNUSED(params); + } - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; + void accept(uint16_t n_accepted) override { + common_ngram_map_accept(map, n_accepted); } - // prompt_tgt's tokens will always be compatible with ctx_dft - const llama_tokens &prompt_tgt = - spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model; +}; + +struct common_speculative_state_ngram_mod : public common_speculative_state { + common_ngram_mod & mod; - const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); + // the last position in the prompt that was added to the ngram container + size_t i_last = 0; + + // length of the last drafted n‑gram (number of tokens returned by draft) + size_t n_draft_last = 0; + + // consecutive accept rounds with low acceptance fraction (< 0.5) + int n_low = 0; + + // enable trace logging if LLAMA_TRACE is set + const bool verbose; + + common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) + : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { + static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); + } - // reuse as much as possible from the old draft context - // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt_dft.size(); ++i) { - int cur = 0; - while (i_start + cur < (int) prompt_tgt.size() && - i + cur < (int) prompt_dft.size() && - prompt_tgt[i_start + cur] == prompt_dft[i + cur]) { - cur++; + void begin(const llama_tokens & prompt) override { + i_last = 0; + + n_draft_last = 0; + + const size_t n = mod.get_n(); + + if (prompt.size() < n) { + return; + } + + for (size_t i = 0; i < prompt.size() - n; ++i) { + mod.add(prompt.data() + i); } - if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) { - reuse_i = i; - reuse_n = cur; + i_last = prompt.size() - n; + + const double f = (double)mod.get_used() / (double)mod.size(); + LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); + + constexpr double f_thold = 0.25; + if (f > f_thold) { + LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); + + mod.reset(); } } - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(params); - llama_tokens result; - result.reserve(params.n_draft); - - if (reuse_n == 0) { - llama_memory_clear(mem_dft, false); - prompt_dft.clear(); - } else { - // this happens when a previous draft has been discarded (for example, due to being too small), but the - // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { - result.push_back(prompt_dft[i]); - - if (params.n_draft <= (int) result.size()) { - break; + n_draft_last = 0; + + const size_t cur_len = prompt_tgt.size(); + if (cur_len < mod.get_n()) { + return; + } + + const size_t n = mod.get_n(); + + // add new ngrams in chunks + if (i_last + 32 < cur_len) { + for (size_t i = i_last; i < cur_len - n; ++i) { + mod.add(prompt_tgt.data() + i); + } + + i_last = cur_len - n; + } + + result.resize(n + params.n_max); + for (size_t i = 0; i < n - 1; ++i) { + result[i] = prompt_tgt[cur_len - n + 1 + i]; + } + result[n - 1] = id_last; + + for (int i = 0; i < params.n_max; ++i) { + const llama_token token = mod.get(result.data() + i); + if (token == common_ngram_mod::EMPTY) { + if (i < params.n_min) { + result.clear(); + return; } + + result.resize(n + i); + break; } + result[n + i] = token; + } - return result; + // only return the m tokens that were drafted + for (size_t i = 0; n + i < result.size(); ++i) { + result[i] = result[n + i]; } + result.resize(result.size() - n); - if (reuse_i > 0) { - llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); - llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); + // store length of drafted n‑gram for later acceptance analysis + n_draft_last = result.size(); + } - prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); + void accept(uint16_t n_accepted) override { + if (verbose) { + LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last); } - if (reuse_n < (int) prompt_dft.size()) { - llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + // compute acceptance fraction if we have a recorded draft length + if (n_draft_last > 0) { + const double f_acc = (double)n_accepted / (double)n_draft_last; + if (f_acc < 0.5) { + n_low++; + if (n_low >= 3) { + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); + + mod.reset(); + n_low = 0; + } + } else { + n_low = 0; + } } } +}; - // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); +struct common_speculative_state_ngram_cache : public common_speculative_state { + uint16_t n_draft; + bool save_dynamic; + bool save_static; + + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + + size_t cache_size = 0; // number of tokens in n-gram cache + + common_speculative_state_ngram_cache( + const enum common_speculative_type type, + const std::string & path_static, + const std::string & path_dynamic, + uint16_t n_draft, + bool save_dynamic, + bool save_static) + : common_speculative_state(type) + , n_draft(n_draft) + , save_dynamic(save_dynamic) + , save_static(save_static) + { + if (!path_static.empty()) { + try { + ngram_cache_static = common_ngram_cache_load(path_static); + } catch (...) { + LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); + GGML_ABORT("Couldn't read static lookup cache"); + } + } - for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); - common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); + if (!path_dynamic.empty()) { + try { + ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + } catch (...) { + LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); + GGML_ABORT("Couldn't read dynamic lookup cache"); + } + } + } - prompt_dft.push_back(prompt_tgt[i]); + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); } - // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(params); + + if (cache_size < prompt_tgt.size() + 1) { + llama_tokens tokens_new; + tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); + for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { + tokens_new.push_back(prompt_tgt[j]); + } + tokens_new.push_back(id_last); // add the last token - llama_decode(ctx_dft, batch); + // Update context ngram cache with new prompt_tgt: + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + tokens_new, tokens_new.size(), false); + cache_size = prompt_tgt.size() + 1; + } + + llama_tokens inp; + inp.reserve(prompt_tgt.size() + 1); + for (size_t j = 0; j < prompt_tgt.size(); ++j) { + inp.push_back(prompt_tgt[j]); + } + inp.push_back(id_last); + + result.push_back(id_last); + + common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + ngram_cache_context, + ngram_cache_dynamic, + ngram_cache_static); + + if (result.size() > 0) { + // delete first token in result (which is the id_last token) + result.erase(result.begin()); + } } - const llama_pos n_past = prompt_dft.size(); + void accept(uint16_t n_accepted) override { + // TODO: noop + GGML_UNUSED(n_accepted); + } +}; - LOG_DBG("%s: n_past = %d\n", __func__, n_past); +struct common_speculative { + std::vector> impls; // list of implementations to use and their states + common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) +}; - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); +static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { + uint16_t size_key = config.params.ngram_size_n; + uint16_t size_value = config.params.ngram_size_m; + bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); + uint16_t check_rate = config.params.ngram_check_rate; + uint16_t min_hits = config.params.ngram_min_hits; - prompt_dft.push_back(id_last); + return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits); +} - LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); +static common_speculative_state_ngram_cache create_state_ngram_cache( + const std::string & path_static, const std::string & path_dynamic, + const common_speculative_config & config) { + uint16_t n_draft = 8; // TODO get from config? - llama_decode(ctx_dft, batch); + // TODO bool param in common/common.h to set save_static/save_dynamic? + bool save_static = false; + bool save_dynamic = false; - common_sampler_reset(smpl); + common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); - // sample n_draft tokens from the draft model - for (int i = 0; i < params.n_draft; ++i) { - common_batch_clear(batch); + return state; +} - common_sampler_sample(smpl, ctx_dft, 0, true); +std::string common_speculative_type_name_str() { + std::string result; + for (size_t i = 0; i < common_speculative_types.size(); i++) { + if (i > 0) { + result += ", "; + } + result += common_speculative_type_to_str(common_speculative_types[i]); + } + return result; +} - const auto * cur_p = common_sampler_get_candidates(smpl, true); +std::string common_speculative_type_to_str(enum common_speculative_type type) { + switch (type) { + case COMMON_SPECULATIVE_TYPE_NONE: return "none"; + case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; + case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + default: return "unknown"; + } +} - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); +enum common_speculative_type common_speculative_type_from_name(const std::string & name) { + const auto it = common_speculative_type_from_name_map.find(name); + if (it == common_speculative_type_from_name_map.end()) { + return COMMON_SPECULATIVE_TYPE_COUNT; + } + return it->second; +} + +bool common_speculative_is_compat(llama_context * ctx_tgt) { + auto * mem = llama_get_memory(ctx_tgt); + if (mem == nullptr) { + return false; + } + + bool res = true; + + llama_memory_clear(mem, true); + + // eval 2 tokens to check if the context is compatible + std::vector tmp; + tmp.push_back(0); + tmp.push_back(0); + + int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size())); + if (ret != 0) { + LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); + res = false; + goto done; + } + + // try to remove the last tokens + if (!llama_memory_seq_rm(mem, 0, 1, -1)) { + LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + res = false; + goto done; + } + +done: + llama_memory_clear(mem, true); + llama_synchronize(ctx_tgt); + + return res; +} + +// initialization of the speculative decoding system +// +common_speculative * common_speculative_init( + common_params_speculative & params, + llama_context * ctx_tgt) { + llama_context * ctx_dft = nullptr; + if (params.model_dft) { + ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft); + if (ctx_dft == nullptr) { + LOG_ERR("%s", "failed to create draft context\n"); + return nullptr; } + } - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; + // Compute the implementations to use based on the config and their order of preference + std::vector configs = {}; // list of speculative configs to try + { + bool has_draft = !params.mparams_dft.path.empty(); + bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 + + bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); + bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); + bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); + bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); + bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); + + // In a more complex implementation we could use the same implementation but with different parameters. + // This was initially used in PR-18471 but removed to simplify the code. + if (has_ngram_simple) { + // This implementation can guess a lot of tokens without any draft model. + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); + } + if (has_ngram_map_k) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params)); + } + if (has_ngram_map_k4v) { + // This implementation can guess tokens with high acceptance rate but is more expensive. + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); + } + if (has_ngram_mod) { + // shared instance for all speculative decoding contexts + if (!params.ngram_mod) { + params.ngram_mod = std::make_shared(params.ngram_size_n, 4*1024*1024); - common_sampler_accept(smpl, id, true); + LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__, + params.ngram_size_n, params.ngram_mod->size(), + (float)(params.ngram_mod->size_bytes())/1024/1024); - result.push_back(id); + if (params.ngram_size_n < 16) { + LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, params.ngram_size_n); + } + } - if (params.n_draft <= (int) result.size()) { - break; + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); + } + if (has_ngram_cache) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); + } + if (has_draft) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } + if (has_draft_eagle3) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); + } + } - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < params.p_min) { - break; + std::vector> impls = {}; + + for (const common_speculative_config & config : configs) { + LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); + switch (config.type) { + case COMMON_SPECULATIVE_TYPE_NONE: + break; + case COMMON_SPECULATIVE_TYPE_DRAFT: { + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .replacements = */ params.replacements + )); + break; + } + case COMMON_SPECULATIVE_TYPE_EAGLE3: { + impls.push_back(std::make_unique(config.type)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { + common_ngram_map ngram_map = get_common_ngram_map(config); + + uint16_t ngram_size_key = ngram_map.size_key; + uint16_t mgram_size_value = ngram_map.size_value; + uint16_t check_rate = ngram_map.check_rate; + + auto config_simple = common_ngram_simple_config { + /* .size_ngram = */ ngram_size_key, + /* .size_mgram = */ mgram_size_value, + /* .check_rate = */ check_rate + }; + auto state = std::make_unique( + /* .type = */ config.type, + /* .state = */ config_simple + ); + impls.push_back(std::move(state)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { + impls.push_back(std::make_unique( + (config.type), + get_common_ngram_map(config) + )); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { + GGML_ASSERT(config.params.ngram_mod); + impls.push_back(std::make_unique(config.type, *config.params.ngram_mod)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { + auto state = create_state_ngram_cache( + params.lookup_cache_static, params.lookup_cache_dynamic, config); + impls.push_back(std::make_unique(state)); + break; + } + default: + break; } + } - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + if (impls.empty()) { + LOG_WRN("%s", "no implementations specified for speculative decoding\n"); + return nullptr; + } - // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch); + auto * result = new common_speculative { + /* .impls = */ std::move(impls) + }; - prompt_dft.push_back(id); + return result; +} + +void common_speculative_free(common_speculative * spec) { + if (spec == nullptr) { + return; } - if (!spec->vocab_dft_compatible) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(spec, detokenized); - LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, detokenized, false, true); - if (result.size() > (size_t)params.n_draft) { - result.resize(params.n_draft); + delete spec; +} + +void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { + if (spec == nullptr) { + return; + } + + for (auto & impl : spec->impls) { + common_time_meas tm(impl->t_begin_us, !impl->gen_perf); + impl->begin(prompt); + } +} + +llama_tokens common_speculative_draft( + common_speculative * spec, + const common_params_speculative & params, + const llama_tokens & prompt_tgt, // specified in target model vocab + llama_token id_last) { + llama_tokens result; + + spec->curr_impl = nullptr; // reset current implementation + + for (auto & impl : spec->impls) { + { + common_time_meas tm(impl->t_draft_us, !impl->gen_perf); + impl->draft(params, prompt_tgt, id_last, result); + impl->drafts_call_count++; + } + + if (!result.empty()) { + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), + impl.get()->drafts_call_count, result.size()); + + spec->curr_impl = impl.get(); // set current implementation for stats + impl->drafts_generated_count++; + impl->drafts_generated_tokens += result.size(); + + break; // We have a draft, so break out of the loop and return it. } } + return result; } + +void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { + if (n_accepted == 0) { + return; + } + + common_speculative_state * impl = spec->curr_impl; + + GGML_ASSERT(impl); + + { + common_time_meas tm(impl->t_accept_us, !impl->gen_perf); + if (n_accepted > 0) { + impl->drafts_accepted_count++; + impl->drafts_accepted_tokens += n_accepted; + } + + impl->accept(n_accepted); + } +} + +void common_speculative_print_stats(const common_speculative * spec) { + if (spec == nullptr) { + return; + } + + for (const auto & impl : spec->impls) { + std::string str_perf; + if (impl->gen_perf) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(3) << impl->t_begin_us / 1000.0 << ", "; + oss << std::fixed << std::setprecision(3) << impl->t_draft_us / 1000.0 << ", "; + oss << std::fixed << std::setprecision(3) << impl->t_accept_us / 1000.0; + str_perf = ", dur(b,g,a) = " + oss.str() + " ms"; + } else { + str_perf = ""; + } + + LOG_INF("statistics %s: #calls = %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n", + common_speculative_type_to_str(impl->type).c_str(), + impl->drafts_call_count, + impl->drafts_generated_count, + impl->drafts_accepted_count, + impl->drafts_generated_tokens, + impl->drafts_accepted_tokens, + str_perf.c_str()); + } +} diff --git a/common/speculative.h b/common/speculative.h index e69d7aaa1eb..876cde3d180 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -5,31 +5,37 @@ struct common_speculative; -struct common_speculative_params { - int n_draft = 16; // max drafted tokens - int n_reuse = 256; +// comma separated list of all types +std::string common_speculative_type_name_str(); - float p_min = 0.75f; // min probability required to accept a token in the draft -}; +// convert string to type +enum common_speculative_type common_speculative_type_from_name(const std::string & name); -struct common_speculative * common_speculative_init( - struct llama_context * ctx_tgt, - struct llama_context * ctx_dft -); +// convert type to string +std::string common_speculative_type_to_str(enum common_speculative_type type); -void common_speculative_free(struct common_speculative * spec); +// check if the llama_context is compatible for speculative decoding +// note: clears the memory of the context +bool common_speculative_is_compat(llama_context * ctx_tgt); -bool common_speculative_are_compatible( - const struct llama_context * ctx_tgt, - const struct llama_context * ctx_dft); +common_speculative * common_speculative_init( + common_params_speculative & params, + llama_context * ctx_tgt); -void common_speculative_add_replacement_tgt_dft( - struct common_speculative * spec, - const char *source, const char *dest); +void common_speculative_free(common_speculative * spec); + +// optionally call once at the beginning of a new generation +void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt); // sample up to n_draft tokens and add them to the batch using the draft model -llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt, - llama_token id_last); +llama_tokens common_speculative_draft( + common_speculative * spec, + const common_params_speculative & params, + const llama_tokens & prompt, + llama_token id_last); + +// informs the speculative decoder that n_accepted tokens were accepted by the target model +void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); + +// print statistics about the speculative decoding +void common_speculative_print_stats(const common_speculative * spec); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ab015dd2c3a..9be5bcb9ece 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -514,8 +514,7 @@ def set_gguf_parameters(self): raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses") def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - + del bid # unused return [(self.map_tensor_name(name), data_torch)] def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: @@ -587,6 +586,10 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.A_ENC_EMBD_POS, gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF, gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF, + # Kimi KDA conv weights should be F32 + gguf.MODEL_TENSOR.SSM_CONV1D_Q, + gguf.MODEL_TENSOR.SSM_CONV1D_K, + gguf.MODEL_TENSOR.SSM_CONV1D_V, ) ) or new_name[-7:] not in (".weight", ".lora_a", ".lora_b") @@ -904,10 +907,10 @@ def set_gguf_parameters(self): if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: self.gguf_writer.add_layer_norm_eps(f_norm_eps) logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") - if (n_experts := self.hparams.get("num_local_experts")) is not None: + if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None: self.gguf_writer.add_expert_count(n_experts) logger.info(f"gguf: expert count = {n_experts}") - if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)) is not None: self.gguf_writer.add_expert_used_count(n_experts_used) logger.info(f"gguf: experts used count = {n_experts_used}") if (n_expert_groups := self.hparams.get("n_group")) is not None: @@ -917,7 +920,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_expert_group_used_count(n_group_used) logger.info(f"gguf: expert groups used count = {n_group_used}") - if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None: + if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation", "moe_router_activation_func"], optional=True)) is not None: if score_func == "sigmoid": self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) elif score_func == "softmax": @@ -1809,7 +1812,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] + n_block_keys = ["layers", "n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1831,7 +1834,13 @@ def __init__(self, *args, **kwargs): if "audio_config" not in self.hparams: self.hparams["audio_config"] = {} text_config = {**self.hparams, **self.hparams["text_config"]} - self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0)) + n_embd_text = ( + text_config.get("hidden_size") + or text_config.get("n_embd") + or text_config.get("embed_dim") + or 0 + ) + self.n_embd_text = int(n_embd_text) if n_embd_text else 0 else: text_config = { k: v for k, v in self.hparams.items() if k not in ["vision_encoder", "audio_encoder"] @@ -1981,13 +1990,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) - tensors: list[tuple[str, Tensor]] = [] - if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name): # Map bloom-style qkv_linear to gpt-style qkv_linear # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa @@ -2014,9 +2019,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter ) logger.info("re-format attention.linear_qkv.bias") - tensors.append((self.map_tensor_name(name), data_torch)) - - return tensors + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("BloomForCausalLM", "BloomModel") @@ -2036,15 +2039,11 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) name = re.sub(r'transformer\.', '', name) - tensors: list[tuple[str, Tensor]] = [] - if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name): # Map bloom-style qkv_linear to gpt-style qkv_linear # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa @@ -2071,9 +2070,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter ) logger.info("re-format attention.linear_qkv.bias") - tensors.append((self.map_tensor_name(name), data_torch)) - - return tensors + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("MPTForCausalLM") @@ -2108,15 +2105,13 @@ def set_gguf_parameters(self): self.gguf_writer.add_max_alibi_bias(0.0) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if "scales" in name: new_name = self.map_tensor_name(name, try_suffixes=(".weight", ".bias", ".scales")) new_name = new_name.replace("scales", "act.scales") else: new_name = self.map_tensor_name(name, try_suffixes=(".weight", ".bias")) - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, new_name, bid) @ModelBase.register("OrionForCausalLM") @@ -2170,11 +2165,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) - tensors: list[tuple[str, Tensor]] = [] - if bid is not None and name == f"model.layers.{bid}.self_attn.W_pack.weight": logger.info(f"Unpacking and permuting layer {bid}") - tensors = [ + yield from [ (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), self._reverse_hf_permute_part(data_torch, 0, head_count, head_count)), (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), @@ -2183,9 +2176,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._reverse_hf_part(data_torch, 2)), ] else: - tensors = [(self.map_tensor_name(name), data_torch)] - - return tensors + yield from self.modify_tensors(data_torch, self.map_tensor_name(name), bid) def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: @@ -2266,8 +2257,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -2277,7 +2266,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("k_proj.weight"): data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: @@ -2314,8 +2303,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - # QKV tensor transform # The original query_key_value tensor contains n_head_kv "kv groups", # each consisting of n_head/n_head_kv query weights followed by one key @@ -2337,7 +2324,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head) data_torch = torch.cat((q, k, v)).reshape_as(data_torch) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("GPTBigCodeForCausalLM") @@ -2399,22 +2386,20 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_head_kv = 1 head_dim = self.hparams["n_embd"] // n_head - tensors: list[tuple[str, Tensor]] = [] - if bid is not None: if name == f"transformer.h.{bid}.attn.kv.weight": - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), data_torch[:n_head_kv * head_dim])) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), data_torch[n_head_kv * head_dim:])) - elif name == f"transformer.h.{bid}.attn.q.weight": - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), data_torch)) - elif name == f"transformer.h.{bid}.mlp.gate_up_proj.weight": - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim])) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:])) - - if len(tensors) == 0: - tensors.append((self.map_tensor_name(name), data_torch)) + yield from super().modify_tensors(data_torch[:n_head_kv * head_dim], self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), bid) + yield from super().modify_tensors(data_torch[n_head_kv * head_dim:], self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid) + return + if name == f"transformer.h.{bid}.attn.q.weight": + yield from super().modify_tensors(data_torch, self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), bid) + return + if name == f"transformer.h.{bid}.mlp.gate_up_proj.weight": + yield from super().modify_tensors(data_torch[:ff_dim], self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), bid) + yield from super().modify_tensors(data_torch[ff_dim:], self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), bid) + return - return tensors + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM") @@ -2461,7 +2446,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if len(self._q_norms[bid]) >= n_head: return self._stack_qk_norm(bid, n_head, self._q_norms[bid], "q_layernorm") else: - return [] + return if name.find("k_layernorm.norms") != -1: assert bid is not None @@ -2474,9 +2459,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if len(self._k_norms[bid]) >= n_kv_head: return self._stack_qk_norm(bid, n_kv_head, self._k_norms[bid], "k_layernorm") else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def _stack_qk_norm(self, bid: int, n_head: int, norms: dict[str, Tensor], layer_name: str = "q_layernorm"): datas: list[Tensor] = [] @@ -2488,9 +2473,8 @@ def _stack_qk_norm(self, bid: int, n_head: int, norms: dict[str, Tensor], layer_ data_torch = torch.stack(datas, dim=0) merged_name = f"model.layers.{bid}.self_attn.{layer_name}.weight" - new_name = self.map_tensor_name(merged_name) - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, merged_name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -2616,7 +2600,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter ) if is_multimodal_tensor: - return [] # skip vision tensors + return # skip vision tensors elif self.hf_arch == "LlamaModel": name = "model." + name elif name.startswith("model.text_model"): @@ -2642,8 +2626,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for wid in ["w1", "w2", "w3"]: datas: list[Tensor] = [] @@ -2657,14 +2639,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_params := self.rope_parameters.get("full_attention", self.rope_parameters): @@ -2755,8 +2735,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["gate_proj", "up_proj", "down_proj"]: datas: list[Tensor] = [] @@ -2768,17 +2746,16 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = torch.stack(datas, dim=0) merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) + yield from ModelBase.modify_tensors(self, data_torch, merged_name, bid) - return tensors + return else: - return [] + return if name.endswith(".expert_bias"): name = name.replace(".expert_bias", ".expert_bias.bias") - return [(self.map_tensor_name(name), data_torch)] + yield from ModelBase.modify_tensors(self, data_torch, name, bid) @ModelBase.register( @@ -2835,7 +2812,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused n_head = ( self.hparams["num_attention_heads"] if not self.is_mistral_format else self.find_vparam(["num_attention_heads"]) ) @@ -2856,7 +2832,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")) and not self.is_mistral_format: data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) + return embed_key = "embed_tokens.weight" if not self.is_mistral_format else "tok_embeddings.weight" if self.img_break_tok_id > 0 and embed_key in name: @@ -2864,9 +2841,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # for pixtral model, we need to extract the [IMG_BREAK] token embedding img_break_embd = data_torch[self.img_break_tok_id] name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK] - return [(self.map_tensor_name(name), img_break_embd)] + yield from super().modify_tensors(img_break_embd, name, bid) - return [] # skip other tensors + return # skip other tensors @ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration") @@ -2897,13 +2874,12 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return super().tensor_force_quant(name, new_name, bid, n_dims) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name if is_vision_tensor: - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) - return [] # skip other tensors + return # skip other tensors @ModelBase.register( @@ -2942,18 +2918,17 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): name_gate = name.replace("gate_up_proj", "gate_proj.weight") dim_half = data_torch.shape[-1] // 2 gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2) - return [ - (self.map_tensor_name(name_gate), gate_proj_weight), - (self.map_tensor_name(name_up), up_proj_weight) - ] + yield from super().modify_tensors(gate_proj_weight, name_gate, bid) + yield from super().modify_tensors(up_proj_weight, name_up, bid) + return if name.endswith("down_proj"): name += ".weight" data_torch = data_torch.transpose(-1, -2) if "multi_modal_projector" in name or "vision_model" in name: - return [] - return super().modify_tensors(data_torch, name, bid) + return + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Llama4ForConditionalGeneration") @@ -2967,19 +2942,21 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_use_gelu(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused if "multi_modal_projector" in name or "vision_model" in name: # process vision tensors if "positional_embedding_vlm" in name and ".weight" not in name: name += ".weight" if "multi_modal_projector.linear_1" in name: # despite the name with number postfix, this is a single fully connected layer - return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch)] - return [(self.map_tensor_name(name), data_torch)] - return [] + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch) + else: + yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Mistral3ForConditionalGeneration") +@ModelBase.register( + "Mistral3ForConditionalGeneration", + "Ministral3ForCausalLM", +) class Mistral3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.MISTRAL3 @@ -3005,9 +2982,9 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): name = name.replace("language_model.", "") if "multi_modal_projector" in name or "vision_tower" in name: - return [] + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("DeciLMForCausalLM") @@ -3146,7 +3123,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = DeciModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")): data_torch = DeciModel.permute(data_torch, n_head, n_kv_head) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_params := self.rope_parameters.get("full_attention", self.rope_parameters): @@ -3220,7 +3197,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # transform weight into 1/0/-1 (in fp32) data_torch = self.weight_quant(data_torch) - yield (new_name, data_torch) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM") @@ -3276,11 +3253,11 @@ def set_gguf_parameters(self): _cur_expert = "" def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - tensors: list[tuple[str, Tensor]] = [] + deferred: list[tuple[Tensor, str, int | None]] = [] is_expert = ".moe." in name or ".block_sparse_moe.experts." in name if not is_expert: - tensors.append((self.map_tensor_name(name), data_torch)) + deferred.append((data_torch, name, bid)) # process the experts separately if is_expert or self._cur_expert: @@ -3295,11 +3272,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name in self._experts[bid]: self._cur_expert = name self._experts[bid][name].append(data_torch) - return [] + return elif is_expert: self._cur_expert = name self._experts[bid][name] = [data_torch] - return [] + return else: self._cur_expert = "" @@ -3321,11 +3298,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight" - new_name = self.map_tensor_name(merged_name) + yield from super().modify_tensors(data_torch, merged_name, bid) - yield (new_name, data_torch) - - yield from tensors + for t in deferred: + yield from super().modify_tensors(*t) @ModelBase.register("DbrxForCausalLM") @@ -3357,8 +3333,6 @@ def set_gguf_parameters(self): logger.info(f"gguf: file type = {self.ftype}") def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - n_expert = self.hparams["ffn_config"]["moe_num_experts"] n_ff = self.hparams["ffn_config"]["ffn_hidden_size"] n_embd = self.hparams["d_model"] @@ -3389,7 +3363,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15 new_name = self.map_tensor_name(name if not experts else name + ".weight", try_suffixes=(".weight",)) - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, new_name, bid) def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid # unused @@ -3434,8 +3408,6 @@ def set_vocab(self): self._set_vocab_sentencepiece() def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") @@ -3445,7 +3417,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith(("k_proj.weight")): data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("MiniCPM3ForCausalLM") @@ -3555,7 +3527,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter or name.startswith("vision_model") or name.startswith("audio_tower") \ or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"): # skip vision and audio tensors - return [] + return yield from super().modify_tensors(data_torch, name, bid) @@ -3752,23 +3724,20 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter total_k_dim = num_kv_heads * head_dim total_v_dim = num_kv_heads * head_dim q_proj_weight, k_proj_weight, v_proj_weight = data_torch.split([total_q_dim, total_k_dim, total_v_dim], dim=0) - return [ - (self.map_tensor_name(name_q), q_proj_weight), - (self.map_tensor_name(name_k), k_proj_weight), - (self.map_tensor_name(name_v), v_proj_weight) - ] + yield from super().modify_tensors(q_proj_weight, name_q, bid) + yield from super().modify_tensors(k_proj_weight, name_k, bid) + yield from super().modify_tensors(v_proj_weight, name_v, bid) # split the up_gate_proj into gate and up # up_gate_proj shape: [2 * intermediate_size, hidden_size] - if "up_gate_proj" in name: + elif "up_gate_proj" in name: name_up = name.replace("up_gate_proj.weight", "up_proj.weight") name_gate = name.replace("up_gate_proj.weight", "gate_proj.weight") dim_half = data_torch.shape[0] // 2 gate_proj_weight, up_proj_weight = data_torch.split(dim_half, dim=0) - return [ - (self.map_tensor_name(name_gate), gate_proj_weight), - (self.map_tensor_name(name_up), up_proj_weight) - ] - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(gate_proj_weight, name_gate, bid) + yield from super().modify_tensors(up_proj_weight, name_up, bid) + else: + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Ernie4_5_MoeForCausalLM") @@ -3801,20 +3770,20 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # skip Multi-Token Prediction (MTP) layers (again, same as DeepseekV2) match = re.match(r"model.mtp_block.(\d+)", name) if match: - return [] + return # skip all other MTP tensors for now match = re.match(r"model.mtp_emb_norm.(\d+)", name) if match: - return [] + return match = re.match(r"model.mtp_hidden_norm.(\d+)", name) if match: - return [] + return match = re.match(r"model.mtp_linear_proj.(\d+)", name) if match: - return [] + return # process the experts separately if name.find("mlp.experts") != -1: @@ -3827,8 +3796,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["gate_proj", "up_proj", "down_proj"]: datas: list[Tensor] = [] @@ -3840,13 +3807,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = torch.stack(datas, dim=0) merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) - - return tensors - else: - return [] - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, merged_name, bid) + else: + yield from ModelBase.modify_tensors(self, data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -3877,14 +3840,13 @@ def set_vocab(self): self._set_vocab_gpt2() def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused if name.startswith("thinker."): name = name.replace("thinker.", "") if name.startswith("visual") or name.startswith("audio") or \ name.startswith("talker") or name.startswith("token2wav"): # skip multimodal tensors - return [] - return [(self.map_tensor_name(name), data_torch)] + return + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") @@ -3933,7 +3895,6 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return super().tensor_force_quant(name, new_name, bid, n_dims) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused if name.startswith("visual."): # process visual tensors # split QKV tensors if needed @@ -3947,23 +3908,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter wq = data_torch[:c] wk = data_torch[c: c * 2] wv = data_torch[c * 2:] - return [ - (self.map_tensor_name(name.replace("qkv", "q")), wq), - (self.map_tensor_name(name.replace("qkv", "k")), wk), - (self.map_tensor_name(name.replace("qkv", "v")), wv), - ] + yield from super().modify_tensors(wq, name.replace("qkv", "q"), bid) + yield from super().modify_tensors(wk, name.replace("qkv", "k"), bid) + yield from super().modify_tensors(wv, name.replace("qkv", "v"), bid) elif 'patch_embed.proj.weight' in name: # split Conv3D into Conv2Ds c1, c2, kt, kh, kw = data_torch.shape del c1, c2, kh, kw # unused assert kt == 2, "Current implmentation only support temporal_patch_size of 2" - return [ - (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" , data_torch[:, :, 0, ...]), - (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]), - ] + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" , data_torch[:, :, 0, ...]) + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]) else: - return [(self.map_tensor_name(name), data_torch)] - return [] # skip other tensors + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Qwen2_5OmniModel") @@ -4019,10 +3975,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if "audio_bos_eos_token" in name: # this tensor is left unused in transformers code # https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809 - return [] - return [(self.map_tensor_name(name), data_torch)] - - return super().modify_tensors(data_torch, name, bid) + return + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("InternVisionModel") @@ -4069,7 +4023,6 @@ def _mapping_interns1_name(self, name): return name def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused vision_prefix = ['vision_model', 'mlp', 'model.vision_tower', 'model.multi_modal_projector'] # deal with intern-s1 special case name = self._mapping_interns1_name(name) @@ -4091,13 +4044,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter wq = data_torch[:c] wk = data_torch[c: c * 2] wv = data_torch[c * 2:] - return [ - (self.map_tensor_name(name.replace("attn.qkv", "self_attn.q_proj")), wq), - (self.map_tensor_name(name.replace("attn.qkv", "self_attn.k_proj")), wk), - (self.map_tensor_name(name.replace("attn.qkv", "self_attn.v_proj")), wv), - ] - return [(self.map_tensor_name(name), data_torch)] - return [] # skip other tensors + yield from super().modify_tensors(wq, name.replace("attn.qkv", "self_attn.q_proj"), bid) + yield from super().modify_tensors(wk, name.replace("attn.qkv", "self_attn.k_proj"), bid) + yield from super().modify_tensors(wv, name.replace("attn.qkv", "self_attn.v_proj"), bid) + else: + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("WavTokenizerDec") @@ -4105,18 +4056,16 @@ class WavTokenizerDecModel(TextModel): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if \ name.endswith("codebook.cluster_size") or \ name.endswith("codebook.embed_avg") or \ name.endswith("codebook.inited"): logger.debug(f"Skipping {name!r}") - return [] + return logger.info(f"{self.map_tensor_name(name)} -> {data_torch.shape}") - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def set_vocab(self): self._set_vocab_none() @@ -4159,44 +4108,32 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # process the experts separately name = name.replace("language_model.", "") # InternVL - # handle aggregated expert tensors - # GGUF stores dimensions reversed from PyTorch, so: - # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A} - # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp) - # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down + # handle pre-packed expert tensors (e.g. Qwen3.5 MoE, Qwen3Next) + # HF stores these using nn.Linear convention: [n_expert, out_features, in_features] + # This matches the individual expert stacking path below (which stacks + # per-expert [out, in] weights into [n_expert, out, in]), so no permute is needed. if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"): mapped = f"{name}.weight" if not name.endswith(".weight") else name - # Input: (n_expert=128, n_ff_exp=768, n_embd=2048) - # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128} - # Need PyTorch: (128, 2048, 768) [reversed of GGML] - # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768) - permuted = data_torch.permute(0, 2, 1).contiguous() - return [(self.map_tensor_name(mapped), permuted)] + # HF: [n_expert, n_embd, n_ff] → GGML: {n_ff, n_embd, n_expert} ✓ + yield from super().modify_tensors(data_torch, mapped, bid) + return if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"): - if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0: - raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}") - split_dim = data_torch.shape[-1] // 2 - gate = data_torch[..., :split_dim].contiguous() - up = data_torch[..., split_dim:].contiguous() - # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768) - # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128} - # Need PyTorch: (128, 768, 2048) [reversed of GGML] - # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048) - base_name = name.removesuffix(".weight") - base = base_name.rsplit('.', 1)[0] - mapped_gate = f"{base}.gate_proj.weight" - mapped_up = f"{base}.up_proj.weight" - perm_gate = gate.permute(0, 2, 1).contiguous() - perm_up = up.permute(0, 2, 1).contiguous() - return [ - (self.map_tensor_name(mapped_gate), perm_gate), - (self.map_tensor_name(mapped_up), perm_up), - ] + # HF: [n_expert, 2*n_ff, n_embd] → split on dim=1 + n_ff = data_torch.shape[1] // 2 + gate = data_torch[:, :n_ff, :].contiguous() + up = data_torch[:, n_ff:, :].contiguous() + # gate/up: [n_expert, n_ff, n_embd] → GGML: {n_embd, n_ff, n_expert} ✓ + base_name = name.removesuffix(".weight").removesuffix(".gate_up_proj") + mapped_gate = f"{base_name}.gate_proj.weight" + mapped_up = f"{base_name}.up_proj.weight" + yield from super().modify_tensors(gate, mapped_gate, bid) + yield from super().modify_tensors(up, mapped_up, bid) + return if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"): # skip visual tensors - return [] + return if name.find("experts") != -1: n_experts = self.hparams["num_experts"] assert bid is not None @@ -4207,8 +4144,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -4222,14 +4157,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -4309,7 +4242,7 @@ def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if "model.vision_" in name: # skip multimodal tensors - return [] + return if self.is_rerank: is_tied_head = self.is_tied_embeddings and "embed_tokens" in name @@ -4319,13 +4252,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch), ) + yield cls_out_head if is_tied_head: - embed = (self.map_tensor_name(name), data_torch) - return [cls_out_head, embed] - if is_real_head: - return [cls_out_head] + yield from super().modify_tensors(data_torch, name, bid) + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Qwen3MoeForCausalLM") @@ -4363,7 +4295,7 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if name.startswith("mtp"): - return [] # ignore MTP layers for now + return # ignore MTP layers for now if name.endswith(".A_log"): data_torch = -torch.exp(data_torch) elif name.endswith(".dt_bias"): @@ -4406,6 +4338,40 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Qwen3_5ForCausalLM", "Qwen3_5TextForCausalLM") +class Qwen3_5Model(Qwen3NextModel): + model_arch = gguf.MODEL_ARCH.QWEN3_5 + + # Stores whichever of in_proj_a/in_proj_b is seen first, keyed by layer + _pending_ba: dict[int | None, tuple[str, Tensor]] = {} + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Handle split in_proj_b + in_proj_a → concatenated SSM_BETA_ALPHA + # safetensors sorts alphabetically so in_proj_a arrives before in_proj_b + if "in_proj_a.weight" in name or "in_proj_b.weight" in name: + which = "a" if "in_proj_a" in name else "b" + if bid not in self._pending_ba: + self._pending_ba[bid] = (which, data_torch) + return + prev_which, prev_tensor = self._pending_ba.pop(bid) + assert prev_which != which, f"duplicate in_proj_{which} for layer {bid}" + b_tensor = prev_tensor if prev_which == "b" else data_torch + a_tensor = prev_tensor if prev_which == "a" else data_torch + ba_combined = torch.cat([b_tensor, a_tensor], dim=0) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.SSM_BETA_ALPHA, bid, ".weight"), ba_combined) + return + else: + # Qwen3Next uses .qkvz tensor, so we use the super to get the other functionalities + # (norm correction, A_log to A etc.) for free + # Qwen2Moe already does the gate_up conversion properly, just use that + yield from super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("Qwen3_5MoeForCausalLM", "Qwen3_5MoeTextForCausalLM") +class Qwen3_5MoeModel(Qwen3_5Model): + model_arch = gguf.MODEL_ARCH.QWEN3_5_MOE + + @ModelBase.register("RND1") class RND1Model(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.RND1 @@ -4465,7 +4431,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter assert self.hparams_vision is not None # Skip text model tensors - they go in the text model file if name.startswith("model.language_model.") or name.startswith("lm_head."): - return [] + return if name.startswith("model.visual."): name = name.replace("model.visual.", "visual.", 1) @@ -4490,7 +4456,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter raise ValueError(f"Unexpected deepstack tensor: {name}") new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}") - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, new_name, bid) + return if name.startswith("visual.merger."): suffix = name.split(".", 2)[2] @@ -4510,7 +4477,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}") else: raise ValueError(f"Unexpected merger tensor: {name}") - return [(new_name, data_torch)] + yield (new_name, data_torch) + return if name == "visual.patch_embed.proj.weight": # split Conv3D into Conv2Ds along temporal dimension @@ -4518,20 +4486,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter del c1, c2 if kt != 2: raise ValueError("Current implementation only supports temporal_patch_size of 2") - return [ - (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]), - (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]), - ] + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]) + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]) + return if name == "visual.patch_embed.proj.bias": # Include the bias - it's used by the C++ code - return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)] + yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch) + return if name.startswith("visual."): - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) + return # Fall back to parent class for other tensors - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration") @@ -4554,8 +4523,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.startswith("model.visual."): name = name.replace("model.visual.", "visual.") if name.startswith("visual.merger."): - return [(self.map_tensor_name(name), data_torch)] - return super().modify_tensors(data_torch, name, bid) + yield from ModelBase.modify_tensors(self, data_torch, name, bid) + return + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Qwen3VLForConditionalGeneration") @@ -4573,9 +4543,9 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # Skip vision tensors - they go in the mmproj file if name.startswith("model.visual."): - return [] + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Qwen3VLMoeForConditionalGeneration") @@ -4591,9 +4561,9 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # Skip vision tensors - they go in the mmproj file if name.startswith("model.visual."): - return [] + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("GPT2LMHeadModel") @@ -4610,22 +4580,17 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - - tensors: list[tuple[str, Tensor]] = [] - # we don't need these if name.endswith((".attn.bias", ".attn.masked_bias")): - return tensors + yield from super().modify_tensors(data_torch, name, bid) + return if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")): data_torch = data_torch.transpose(1, 0) new_name = self.map_tensor_name(name) - tensors.append((new_name, data_torch)) - - return tensors + yield from super().modify_tensors(data_torch, new_name, bid) @ModelBase.register("PhiForCausalLM") @@ -4849,8 +4814,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["w1", "w2", "w3"]: datas: list[Tensor] = [] @@ -4864,14 +4827,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -4917,8 +4878,6 @@ def shuffle_attn_output_weight(self, data_torch): return data_torch def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - new_name = self.map_tensor_name(name) # shuffle for broadcasting of gqa in ggml_mul_mat @@ -4927,7 +4886,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter elif new_name.endswith("attn_output.weight"): data_torch = self.shuffle_attn_output_weight(data_torch) - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM") @@ -4988,8 +4947,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if name.endswith(".A_log"): data_torch = -torch.exp(data_torch) elif name.endswith(".dt_bias"): @@ -5018,9 +4975,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter elif name.endswith(".norm.weight"): data_torch += 1.0 - new_name = self.map_tensor_name(name) - - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Plamo3ForCausalLM", "PLaMo3ForCausalLM") @@ -5069,7 +5024,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter elif name.endswith(".norm.weight"): data_torch = data_torch + 1.0 - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("CodeShellForCausalLM") @@ -5090,40 +5045,255 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_factor(1.0) -@ModelBase.register("InternLM2ForCausalLM") -class InternLM2Model(TextModel): - model_arch = gguf.MODEL_ARCH.INTERNLM2 +@ModelBase.register("KimiLinearModel", "KimiLinearForCausalLM") +class KimiLinearModel(TextModel): + """Kimi-Linear model with hybrid MLA+KDA architecture""" + model_arch = gguf.MODEL_ARCH.KIMI_LINEAR + + _experts: list[dict[str, Tensor]] | None = None def set_vocab(self): - # (TODO): Is there a better way? - # Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character - # \x00 specially and convert it into an emoji character to prevent it from being mistakenly - # recognized as an empty string in C++. - from sentencepiece import SentencePieceProcessor - from sentencepiece import sentencepiece_model_pb2 as model + try: + self._set_vocab_gpt2() + return + except Exception: + pass - tokenizer_path = self.dir_model / 'tokenizer.model' + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + tokpre = self.get_vocab_base_pre(tokenizer) - tokens: list[bytes] = [] - scores: list[float] = [] - toktypes: list[int] = [] + if tokpre == "kimi-k2": + # Build merges list using the approach similar to HunYuanMoE + merges = [] + vocab = {} + mergeable_ranks = tokenizer.model._mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + if len(merged) == 2: + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + # Build token list + vocab_size = self.hparams["vocab_size"] + special_tokens = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} + tokens: list[str] = [] + toktypes: list[int] = [] - if not tokenizer_path.is_file(): - logger.error(f'Error: Missing {tokenizer_path}') - sys.exit(1) + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token = reverse_vocab[i] + tokens.append(token) + if i in special_tokens.values(): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) - sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] - sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) - add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_token_merges(merges) - tokenizer = SentencePieceProcessor() - tokenizer.LoadFromFile(str(tokenizer_path)) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.add_to_gguf(self.gguf_writer) + # override eos id in config.json with tiktoken eos id + self.gguf_writer.add_eos_token_id(tokenizer.eos_id) + else: + raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!") - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + def set_gguf_parameters(self): + # note: To enable MLA KV cache, attention needs to be converted into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 - for token_id in range(vocab_size): - piece = tokenizer.IdToPiece(token_id) - text = piece.encode("utf-8") + super().set_gguf_parameters() + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + # KDA & MLA params + # Get ssm_d_conv from linear_attn_config.short_conv_kernel_size or ssm_d_conv + linear_attn_config = self.hparams["linear_attn_config"] + # n_head == 0 for KDA layers, n_head > 0 for MLA layers + # full_attention_layers list will be used to distingush layer type + _num_kv_heads = list() + _full_attn_layers = linear_attn_config["full_attn_layers"] + for il in range(self.hparams["num_hidden_layers"]): + if il + 1 in _full_attn_layers: + _num_kv_heads.append(self.hparams["num_key_value_heads"]) + else: + _num_kv_heads.append(0) + assert len(_num_kv_heads) == self.hparams["num_hidden_layers"] + self.gguf_writer.add_head_count_kv(_num_kv_heads) + + if (ssm_d_conv := linear_attn_config.get("short_conv_kernel_size")) is not None: + self.gguf_writer.add_ssm_conv_kernel(ssm_d_conv) + if (kda_head_dim := linear_attn_config.get("head_dim")) is not None: + self.gguf_writer.add_kda_head_dim(kda_head_dim) + + # MLA params - use add_* methods that handle arch substitution + # Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv) + if (q_lora_rank := self.find_hparam(["q_lora_rank", "n_lora_q"], optional=True)) is not None: + self.gguf_writer.add_q_lora_rank(q_lora_rank) + # To enable MLA KV cache, MLA needs to be converted into MQA with larger heads, then decompresses to MHA + kv_lora_rank = self.find_hparam(["kv_lora_rank", "n_lora_kv"], optional=False) + self.gguf_writer.add_kv_lora_rank(kv_lora_rank) + + # MLA head dimensions + # Support HuggingFace naming: qk_nope_head_dim, qk_rope_head_dim, v_head_dim + qk_nope_head_dim = self.hparams.get("qk_nope_head_dim") + # Rotation - use qk_rope_head_dim for Kimi + qk_rope_head_dim = self.find_hparam(["qk_rope_head_dim", "n_rot"], optional=False) + self.gguf_writer.add_rope_dimension_count(qk_rope_head_dim) + self.gguf_writer.add_key_length(kv_lora_rank + qk_rope_head_dim) + v_head_dim = self.hparams.get("v_head_dim") + + # Calculate n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim + if (n_embd_head_k_mla := self.find_hparam(["n_embd_head_k_mla"], optional=True)) is not None: + self.gguf_writer.add_key_length_mla(n_embd_head_k_mla) + elif qk_nope_head_dim is not None: + n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim + self.gguf_writer.add_key_length_mla(n_embd_head_k_mla) + + # n_embd_head_v_mla = v_head_dim + if (n_embd_head_v_mla := self.hparams.get("n_embd_head_v_mla")) is not None: + self.gguf_writer.add_value_length_mla(n_embd_head_v_mla) + elif v_head_dim is not None: + self.gguf_writer.add_value_length_mla(v_head_dim) + + # moe_intermediate_size (1024 for Kimi) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) + # num_shared_experts (1 for Kimi) + self.gguf_writer.add_expert_shared_count(self.hparams["num_shared_experts"]) + # first_k_dense_replace (1 for Kimi - first layer uses dense MLP) + self.gguf_writer.add_leading_dense_block_count(self.hparams["first_k_dense_replace"]) + # Routed scaling factor (expert_weights_scale = 2.446 for Kimi) + self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"]) + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + logger.info(f"Processing {name}: shape before = {tuple(data_torch.shape)}") + + # Handle KDA conv1d weights + # HuggingFace/vLLM stores as [d_inner, d_conv] (2D), memory layout: conv_step changes fastest + # llama.cpp expects ggml ne = [d_conv, 1, d_inner, 1], memory layout: ne[0]=d_conv changes fastest + # GGUF reverses numpy shape when writing, so numpy (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1] + # Memory layouts match: both have conv_step (d_conv) changing fastest + if name.endswith((".q_conv1d.weight", ".k_conv1d.weight", ".v_conv1d.weight")): + # HF shape: [d_inner, d_conv] e.g. [4096, 4] + # Target numpy shape: (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1] + if data_torch.ndim == 2: + d_inner, d_conv = data_torch.shape + # Reshape to (1, d_inner, 1, d_conv) - memory layout preserved (d_conv fastest) + data_torch = data_torch.reshape(1, d_inner, 1, d_conv) + logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]") + elif data_torch.ndim == 3: + # Already 3D [d_inner, 1, d_conv] from unsqueeze + d_inner, _, d_conv = data_torch.shape + data_torch = data_torch.reshape(1, d_inner, 1, d_conv) + logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, 1, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]") + + # Kimi specific bias + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # Handle A_log: iHF stores as [1, 1, num_heads, 1] + # llama.cpp expects ggml ne = [1, num_heads, 1, 1] + # GGUF reverses numpy shape: numpy (1, 1, num_heads, 1) -> ggml ne = [1, num_heads, 1, 1] + if name.endswith(".A_log"): + data_torch = -torch.exp(data_torch) + if name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + logger.info("Changed dt_bias to dt_proj.bias") + + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=False) + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + # w1: gate, w2: down, w3: up + for wid, tname in [("w1", gguf.MODEL_TENSOR.FFN_GATE_EXP), + ("w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP), + ("w3", gguf.MODEL_TENSOR.FFN_UP_EXP)]: + datas: list[Tensor] = [] + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + data_torch = torch.stack(datas, dim=0) + new_name = self.format_tensor_name(tname, bid) + yield from super().modify_tensors(data_torch, new_name, bid) + return + + # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.find_hparam(["n_embd_head_v_mla", "v_head_dim"], optional=False) + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + logger.info("Split kv_b n_head_kv %d\n" % n_head_kv) + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + yield from super().modify_tensors(k_b, name_kb, bid) + yield from super().modify_tensors(v_b, name_vb, bid) + return + + yield from super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("InternLM2ForCausalLM") +class InternLM2Model(TextModel): + model_arch = gguf.MODEL_ARCH.INTERNLM2 + + def set_vocab(self): + # (TODO): Is there a better way? + # Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character + # \x00 specially and convert it into an emoji character to prevent it from being mistakenly + # recognized as an empty string in C++. + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'tokenizer.model' + + tokens: list[bytes] = [] + scores: list[float] = [] + toktypes: list[int] = [] + + if not tokenizer_path.is_file(): + logger.error(f'Error: Missing {tokenizer_path}') + sys.exit(1) + + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + for token_id in range(vocab_size): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") score = tokenizer.GetScore(token_id) if text == b"\x00": # (TODO): fixme @@ -5231,7 +5401,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.replace("language_model.", "") # InternVL if name.startswith("mlp") or name.startswith("vision_model"): # skip visual tensors - return [] + return if bid is not None and f"model.layers.{bid}.attention.wqkv" in name: qkv = data_torch @@ -5244,13 +5414,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads) v = v.reshape((-1, v.shape[-1])) - return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v), - ] + yield from super().modify_tensors(q, self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), bid) + yield from super().modify_tensors(k, self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), bid) + yield from super().modify_tensors(v, self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid) else: - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("InternLM3ForCausalLM") @@ -5302,12 +5470,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.replace("language_model.", "") # InternVL if name.startswith("mlp") or name.startswith("vision_model"): # skip visual tensors - return [] + return if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification") @@ -5362,8 +5530,6 @@ def phantom(tok, toktype): special_vocab.add_to_gguf(self.gguf_writer) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if name.startswith("bert."): name = name[5:] @@ -5375,13 +5541,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # we are only using BERT for embeddings so we don't need the pooling layer if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): - return [] # we don't need these + return # we don't need these if name.startswith("cls.predictions"): - return [] + return if name.startswith("cls.seq_relationship"): - return [] + return if self.cls_out_labels: # For BertForSequenceClassification (direct projection layer) @@ -5391,7 +5557,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name == "classifier.bias": name = "classifier.out_proj.bias" - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def _xlmroberta_tokenizer_init(self) -> None: # we need the pad_token_id to know how to chop down position_embd matrix @@ -5546,9 +5712,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # These layers act as MLM head, so we don't need them if name.startswith("vocab_"): - return [] + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("RobertaModel", "RobertaForSequenceClassification") @@ -5591,7 +5757,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if self._position_offset is not None: data_torch = data_torch[self._position_offset:,:] - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("NomicBertModel") @@ -5644,7 +5810,7 @@ def set_vocab(self) -> None: def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]: # If the tensor is an experts bias tensor, skip it by returning an empty list. if "mlp.experts.bias" in name: - return [] # Explicitly return an empty list. + return # Explicitly return. if "mlp.experts.mlp.w1" in name: data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) @@ -5655,7 +5821,7 @@ def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) - data_torch = data_torch.transpose(1, 2) name += ".weight" - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def set_gguf_parameters(self): super().set_gguf_parameters() @@ -5695,12 +5861,12 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch, name, bid): if name.startswith("decoder."): - return [] + return if name.startswith("model."): name = name[6:] - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") @@ -5757,7 +5923,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"): if name.startswith("pooler.dense"): - return [] + return num_loras = data_torch.size(0) assert num_loras == len(self._lora_names) @@ -5773,9 +5939,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b") lora_writer.add_tensor(new_name, data.float().numpy(), raw_dtype=gguf.GGMLQuantizationType.F32) - return [] + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) def set_gguf_parameters(self): super().set_gguf_parameters() @@ -5834,19 +6000,17 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - # lm_head is not used in llama.cpp, while autoawq will include this tensor in model # To prevent errors, skip loading lm_head.weight. if name == "lm_head.weight": logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.") - return [] + return # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89 if name.endswith("norm.weight"): data_torch = data_torch + 1 - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Gemma2ForCausalLM") @@ -5880,19 +6044,17 @@ def set_gguf_parameters(self): self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - # lm_head is not used in llama.cpp, while autoawq will include this tensor in model # To prevent errors, skip loading lm_head.weight. if name == "lm_head.weight": logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.") - return [] + return # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89 if name.endswith("norm.weight"): data_torch = data_torch + 1 - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration") @@ -5927,14 +6089,12 @@ def set_gguf_parameters(self): self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if "language_model." in name: name = name.replace("language_model.", "") elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ or name.startswith("multimodal_projector.") or name.startswith("vision_model."): - return [] # skip vision tensors + return # skip vision tensors # remove OOV (out-of-vocabulary) rows in token_embd if "embed_tokens.weight" in name: @@ -5950,7 +6110,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("norm.weight"): data_torch = data_torch + self.norm_shift - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Gemma3TextModel") @@ -6056,10 +6216,8 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return super().tensor_force_quant(name, new_name, bid, n_dims) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if "vision_model.head." in name: - return [] # skip redundant tensors for tinygemma3 + return # skip redundant tensors for tinygemma3 if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ or name.startswith("multimodal_projector.") or name.startswith("vision_model."): @@ -6073,9 +6231,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter logger.info(f"Correcting norm value for '{name}'") data_torch = data_torch + 1 - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) - return [] # skip other tensors + return # skip other tensors class ConformerAudioModel(MmprojModel): @@ -6100,7 +6258,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._batch_norm_tensors[bid][name] = data_torch if len(self._batch_norm_tensors[bid]) < 5: - return [] + return weight = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.weight"] bias = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.bias"] @@ -6110,10 +6268,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter a = weight / torch.sqrt(running_var + eps) b = bias - running_mean * a - return [ - (self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.weight"), a), - (self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.bias"), b), - ] + yield from super().modify_tensors(a, f"conformer.layers.{bid}.conv.batch_norm.weight", bid) + yield from super().modify_tensors(b, f"conformer.layers.{bid}.conv.batch_norm.bias", bid) + return # reshape conv weights if name.startswith("conformer.pre_encode.conv.") and name.endswith(".bias"): @@ -6125,7 +6282,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter assert data_torch.shape[2] == 1 data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1]) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Gemma3nForConditionalGeneration") @@ -6224,18 +6381,19 @@ def custom_map(self, name: str) -> str: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if (ConformerAudioModel.is_audio_tensor(name)): name = name.replace("model.audio_tower.conformer.", "conformer.layers.") - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) # Gemma3n uses # - model.embed_vision.* for projection layers # - model.vision_tower.* for vision encoder # Skip non-vision tensors if not (name.startswith("model.embed_vision.") or name.startswith("model.vision_tower.")): - return [] + return if name.startswith("model.vision_tower.timm_model.blocks."): # Double-indexed block tensors through custom logic - new_name = self.custom_map(name) + yield (self.custom_map(name), data_torch) + return else: # Route non-repeating (conv_stem, msfa, embedding, etc.) and un-catched through tensor_mapping.py new_name = self.map_tensor_name(name) @@ -6243,7 +6401,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if new_name.endswith("conv_stem.conv.bias") or new_name.endswith("layer_scale.gamma"): data_torch = data_torch.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # [1, C, 1, 1] - return [(new_name, data_torch)] + yield from ModelBase.modify_tensors(self, data_torch, new_name, bid) @ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration") @@ -6321,7 +6479,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # TODO: implement self.prediction_coefs.weight.clamp_(...) if "language_model." not in name: - return [] # skip non-language model tensors + return # skip non-language model tensors # Pad token embeddings for vision/audio special tokens (262144-262399) if "embed_tokens.weight" in name or "embed_tokens_per_layer" in name: @@ -6343,7 +6501,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # Continue with normal processing name = name.replace("language_model.", "") - return [(self.map_tensor_name(name), data_torch)] + yield from ModelBase.modify_tensors(self, data_torch, name, bid) + return if "altup_unembed_projections" in name: data_torch = data_torch.to(device="cpu") @@ -6359,9 +6518,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter raise ValueError(f"Unknown name: {name}") out = self._stack_matrices(self._altup_unembd) if out is not None: - return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)] + yield from ModelBase.modify_tensors(self, out, "model.altup_unembed_projections.weight", bid) + return else: - return [] + return if "altup_projections" in name: data_torch = data_torch.to(device="cpu") @@ -6375,11 +6535,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter raise ValueError(f"Unknown name: {name}") out = self._stack_matrices(self._altup_proj) if out is not None: - return [(self.map_tensor_name("model.altup_projections.weight"), out)] + yield from ModelBase.modify_tensors(self, out, "model.altup_projections.weight", bid) + return else: - return [] + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Starcoder2ForCausalLM") @@ -6762,11 +6923,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if self._tok_embd is not None and new_name == output_name: if torch.equal(self._tok_embd, data_torch): logger.debug(f"{output_name} is equivalent to {tok_embd_name}, omitting") - return [] + return elif new_name == tok_embd_name: self._tok_embd = data_torch - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, new_name, bid) @ModelBase.register("Mamba2ForCausalLM") @@ -7022,8 +7183,6 @@ def set_gguf_parameters(self): # Same as super class, but permuting q_proj, k_proj # Copied from: LlamaModel def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") @@ -7032,7 +7191,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("k_proj.weight"): data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("SeedOssForCausalLM") @@ -7088,8 +7247,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -7103,14 +7260,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) # Copied from: Qwen2MoeModel def prepare_tensors(self): @@ -7141,6 +7296,130 @@ def set_vocab(self): raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel') +@ModelBase.register("JinaCLIPModel") +class JinaCLIPTextModel(XLMRobertaModel): + model_arch = gguf.MODEL_ARCH.BERT + _text_prefix = "text_model.transformer." + + @staticmethod + def _load_json_file(path: Path) -> dict[str, Any]: + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + @staticmethod + def _load_hf_config_json(hf_name_or_path: str) -> dict[str, Any]: + p = Path(hf_name_or_path) + if p.is_dir(): + cfg_path = p / "config.json" + if cfg_path.is_file(): + return JinaCLIPTextModel._load_json_file(cfg_path) + + try: + from huggingface_hub import hf_hub_download + except Exception: + raise ImportError( + "huggingface_hub is required to fetch the text tower config.json for JinaClip; " + "install this package or provide a local path in text_config.hf_model_name_or_path." + ) + + try: + cfg_path = Path(hf_hub_download(repo_id=hf_name_or_path, filename="config.json", local_files_only=True)) + except Exception: + cfg_path = Path(hf_hub_download(repo_id=hf_name_or_path, filename="config.json", local_files_only=False)) + return JinaCLIPTextModel._load_json_file(cfg_path) + + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any): + jinaclip_hparams = ModelBase.load_hparams(dir_model, False) + text_cfg = jinaclip_hparams.get("text_config") or {} + hf_name = text_cfg.get("hf_model_name_or_path") + if not hf_name: + raise KeyError("JinaCLIPTextModel: missing text_config.hf_model_name_or_path in config.json") + + base_cfg = self._load_hf_config_json(str(hf_name)) + + overrides = text_cfg.get("hf_model_config_kwargs") or {} + if not isinstance(overrides, dict): + raise TypeError("JinaCLIPTextModel: text_config.hf_model_config_kwargs must be a dict") + + merged_hparams = {**base_cfg, **overrides} + + kwargs["hparams"] = merged_hparams + + super().__init__(dir_model, ftype, fname_out, **kwargs) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if not name.startswith(self._text_prefix): + return [] + + name = name[len(self._text_prefix):] + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("JinaCLIPModel") +class JinaCLIPVisionModel(MmprojModel): + + def set_gguf_parameters(self): + cfg = self.hparams + + width = int(self.find_hparam(["width"])) + head_width = int(self.find_hparam(["head_width"])) + layers = int(self.find_hparam(["layers"])) + image_size = int(self.find_hparam(["image_size"])) + patch_size = int(self.find_hparam(["patch_size"])) + + if width % head_width != 0: + raise ValueError( + f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})" + ) + n_head = width // head_width + + if "mlp_ratio" in cfg: + n_ff = int(width * float(cfg["mlp_ratio"])) + elif bool(cfg.get("naive_swiglu", False)): + n_ff = int((width * 8) // 3) + else: + raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json") + + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_clip_has_vision_encoder(True) + proj_dim = int(self.global_config.get("projection_dim") or cfg.get("embed_dim") or width) + self.gguf_writer.add_vision_projection_dim(proj_dim) + + self.gguf_writer.add_vision_image_size(image_size) + self.gguf_writer.add_vision_patch_size(patch_size) + self.gguf_writer.add_vision_embedding_length(width) + self.gguf_writer.add_vision_feed_forward_length(n_ff) + self.gguf_writer.add_vision_block_count(layers) + self.gguf_writer.add_vision_head_count(n_head) + + self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-6))) + + # JinaClip v2 uses mean/std in preprocessor_config.json + mean = self.preprocessor_config["mean"] + std = self.preprocessor_config["std"] + self.gguf_writer.add_vision_image_mean(mean) + self.gguf_writer.add_vision_image_std(std) + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2) + self.gguf_writer.add_vision_use_silu(True) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("vision_model."): + name = name[len("vision_model."):] + elif not (name.startswith("v.") or name.startswith("mm.")): + return [] + + if name == "pos_embed": + pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + ".weight" + return [(pos_name, data_torch)] + + try: + return [(self.map_tensor_name(name), data_torch)] + except Exception: + logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name) + return [] + + @ModelBase.register("OpenELMForCausalLM") class OpenELMModel(TextModel): model_arch = gguf.MODEL_ARCH.OPENELM @@ -7333,8 +7612,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for wid in ["w1", "w2", "w3"]: datas: list[Tensor] = [] @@ -7348,14 +7625,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -7422,8 +7697,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -7437,14 +7710,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -7580,9 +7851,9 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # skip vision tensors and remove "language_model." for Kimi-VL if "vision_tower" in name or "multi_modal_projector" in name: - return [] + return if name.startswith("siglip2.") or name.startswith("merger."): - return [] + return if name.startswith("language_model."): name = name.replace("language_model.", "") @@ -7590,7 +7861,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if self.hparams.get("tie_word_embeddings", False): if name == "lm_head.weight" or name == "model.lm_head.weight": logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)") - return [] + return # rename e_score_correction_bias tensors if name.endswith("e_score_correction_bias"): @@ -7600,7 +7871,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter block_count = self.hparams["num_hidden_layers"] match = re.match(r"model.layers.(\d+)", name) if match and int(match.group(1)) >= block_count: - return [] + return # process the experts separately if name.find("mlp.experts") != -1: @@ -7613,8 +7884,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -7628,12 +7897,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed if name.endswith("kv_b_proj.weight"): @@ -7650,12 +7917,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) k_b = k_b.transpose(1, 2) - return [ - (self.map_tensor_name(name_kb), k_b), - (self.map_tensor_name(name_vb), v_b) - ] + yield from super().modify_tensors(k_b, name_kb, bid) + yield from super().modify_tensors(v_b, name_vb, bid) + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -7697,9 +7963,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): # not enough expert weights to merge if len(expert_cache) < n_experts * len(expert_weights): - return [] + return - tensors: list[tuple[str, Tensor]] = [] for w_name in expert_weights: datas: list[Tensor] = [] @@ -7711,12 +7976,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): data_torch = torch.stack(datas, dim=0) merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) + yield from super().modify_tensors(data_torch, new_name, bid) del self._experts_cache[bid] - return tensors + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("MiMoV2FlashForCausalLM") @@ -7758,7 +8023,7 @@ def modify_tensors(self, data_torch, name, bid): # TODO: mimo v2 does not indicate the number of next-token-prediction layers, therefore we cannot do the same way as GLM4_MOE if "model.mtp." in name: - return [] + return # process the experts separately if name.find("mlp.experts") != -1: @@ -7771,8 +8036,6 @@ def modify_tensors(self, data_torch, name, bid): self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["gate_proj", "up_proj", "down_proj"]: datas: list[Tensor] = [] @@ -7784,13 +8047,12 @@ def modify_tensors(self, data_torch, name, bid): data_torch = torch.stack(datas, dim=0) merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] - return [(self.map_tensor_name(name), data_torch)] + return + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -7802,6 +8064,135 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("Step3p5ForCausalLM") +class Step35Model(TextModel): + model_arch = gguf.MODEL_ARCH.STEP35 + + def set_gguf_parameters(self): + rope_theta = self.hparams.get("rope_theta") + if isinstance(rope_theta, list): + self.hparams["rope_theta"] = float(rope_theta[0]) + self.hparams["local_rope_theta"] = float(rope_theta[1]) + self.rope_parameters["rope_theta"] = self.hparams["rope_theta"] + self.rope_parameters["sliding_attention"] = {"rope_theta": self.hparams["local_rope_theta"]} + + super().set_gguf_parameters() + + layer_types = self.hparams.get("layer_types") or [] + partial_rotary_factors = self.hparams.get("partial_rotary_factors") or [] + attn_other = self.hparams.get("attention_other_setting") or {} + + n_head_base = self.hparams["num_attention_heads"] + n_kv_base = self.hparams["num_attention_groups"] + + n_head_swa = attn_other.get("num_attention_heads", n_head_base) + n_kv_swa = attn_other.get("num_attention_groups", n_kv_base) + + layer_types = layer_types[: self.block_count] + partial_rotary_factors = partial_rotary_factors[: self.block_count] + assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors + head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types] + kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types] + swa_pat = [lt == "sliding_attention" for lt in layer_types] + + self.gguf_writer.add_head_count(head_arr) + self.gguf_writer.add_head_count_kv(kv_arr) + + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_sliding_window_pattern(swa_pat) + + self.gguf_writer.add_value_length(self.hparams["head_dim"]) + + # MoE params + self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["share_expert_dim"]) + + if (moe_router_scaling_factor := self.hparams.get("moe_router_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(moe_router_scaling_factor) + if (norm_expert_weight := self.hparams.get("norm_expert_weight")) is not None: + self.gguf_writer.add_expert_weights_norm(norm_expert_weight) + + # leading dense blocks + leading_dense = 0 + moe_layers_enum = self.hparams.get("moe_layers_enum") + if isinstance(moe_layers_enum, str) and moe_layers_enum.strip(): + moe_layers = sorted(int(i) for i in moe_layers_enum.strip().split(",")) + if moe_layers: + leading_dense = max(0, moe_layers[0]) + self.gguf_writer.add_leading_dense_block_count(leading_dense) + self.gguf_writer.add_moe_every_n_layers(int(self.hparams.get("moe_every_n_layer", 1))) + + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5)) + + # Optional per-layer SwiGLU clamps. + if (limits := self.hparams.get("swiglu_limits")) is not None: + limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]] + self.gguf_writer.add_swiglu_clamp_exp(limits_f) + if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None: + limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]] + self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + # remove mtp layers + if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None: + il = int(m.group(1)) + n_main = int(self.hparams.get("num_hidden_layers", self.block_count)) + if il >= n_main: + return + if name.endswith("norm.weight"): + data_torch += 1.0 + # Map router bias (expert selection bias) to a GGUF bias tensor + if name.endswith(".moe.router_bias"): + name += ".bias" + + if name.endswith((".self_attn.g_proj.weight", ".moe.gate.weight", ".moe.up_proj.weight", ".moe.gate_proj.weight", ".moe.down_proj.weight")): + data_torch = data_torch.squeeze().contiguous() + + yield from super().modify_tensors(data_torch, name, bid) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3"). + # llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS). + rope_params = self.rope_parameters.get("full_attention", self.rope_parameters) + rope_type = rope_params.get("rope_type") or "" + if rope_type.lower() != "llama3": + return + + # Step35 configs can carry per-layer rope_theta as a list; for llama3 rope factors we use the base value. + rope_theta = self.hparams.get("rope_theta", 10000.0) + if isinstance(rope_theta, list): + rope_theta = rope_theta[0] + base = float(rope_theta) + if (dim := self.hparams.get("head_dim")) is None: + dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + dim = int(dim) + + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = float(rope_params.get("factor", 8.0)) + low_freq_factor = float(rope_params.get("low_freq_factor", 1.0)) + high_freq_factor = float(rope_params.get("high_freq_factor", 4.0)) + old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192))) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + rope_factors: list[float] = [] + for freq in freqs: + wavelen = 2 * math.pi / float(freq) + if wavelen < high_freq_wavelen: + rope_factors.append(1.0) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth)) + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + + @ModelBase.register("PanguEmbeddedForCausalLM") class PanguEmbeddedModel(TextModel): model_arch = gguf.MODEL_ARCH.PANGU_EMBED @@ -7834,8 +8225,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name == "lm_head.weight": if self.hparams.get("tie_word_embeddings", False): logger.info("Skipping tied output layer 'lm_head.weight'") - return [] - return [(self.map_tensor_name(name), data_torch)] + return + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Dots1ForCausalLM") @@ -7857,8 +8248,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") if "shared_experts" in name: - return [(self.map_tensor_name(name), data_torch)] - return super().modify_tensors(data_torch, name, bid) + yield from ModelBase.modify_tensors(self, data_torch, name, bid) + else: + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("PLMForCausalLM") @@ -7877,9 +8269,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_value_length(hparams["v_head_dim"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - return [(self.map_tensor_name(name), data_torch)] - def prepare_tensors(self): super().prepare_tensors() @@ -8010,8 +8399,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder @@ -8022,9 +8409,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self.shared_token_embeddings_found = True else: logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.") - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("T5EncoderModel") @@ -8146,8 +8533,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder @@ -8158,9 +8543,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self.shared_token_embeddings_found = True else: logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.") - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("JAISLMHeadModel") @@ -8208,13 +8593,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - - tensors: list[tuple[str, Tensor]] = [] - # we don't need these if name.endswith((".attn.bias")): - return tensors + return if name.endswith(("relative_pe.slopes")): # Calculate max ALiBi bias (this is the inverse of the ALiBi calculation) @@ -8225,7 +8606,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter first_val = float(data_torch[0].item()) self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2) - return tensors + return if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")): data_torch = data_torch.transpose(1, 0) @@ -8233,13 +8614,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter new_name = self.map_tensor_name(name) if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD): - tensors.append((new_name, data_torch * self.embeddings_scale)) + yield from super().modify_tensors(data_torch * self.embeddings_scale, new_name, bid) elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT): - tensors.append((new_name, data_torch * self.width_scale)) + yield from super().modify_tensors(data_torch * self.width_scale, new_name, bid) else: - tensors.append((new_name, data_torch)) - - return tensors + yield from super().modify_tensors(data_torch, new_name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -8305,7 +8684,7 @@ def normal_to_neox(weights: Tensor, n_head: int, n_head_kv: int, head_dim: int, def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if name.startswith("model.visual."): # ignore visual part of Glm4v - return [] + return elif name.startswith("model.language_model."): name = name.replace("language_model.", "") # for Glm4v if self.use_mrope: @@ -8318,7 +8697,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor) if name.endswith(("k_proj.weight", "k_proj.bias")): data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_kv_head, head_dim, self.partial_rotary_factor) - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration") @@ -8393,13 +8772,14 @@ def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None ) -> Iterable[tuple[str, Tensor]]: if name.startswith("model.visual."): # ignore visual part - return [] + return elif name.startswith("model.language_model."): name = name.replace("language_model.", "") # for multimodal variants # Handle main token embedding (but not layer-specific NextN embeddings) if name == "model.embed_tokens.weight" and ".layers." not in name: - return [(self.map_tensor_name("token_embd.weight"), data_torch)] + yield from super().modify_tensors(data_torch, "token_embd.weight", bid) + return # Handle routed experts if name.find("mlp.experts") != -1: @@ -8412,8 +8792,6 @@ def modify_tensors( self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -8427,18 +8805,15 @@ def modify_tensors( merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") - new_name = self.map_tensor_name(name) - - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -8621,13 +8996,11 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(rope_freq) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if name.endswith(".rotary_pos_emb.inv_freq") or name.startswith("model.vision."): - return [] + return name = name.removeprefix("transformer.") - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("NemotronForCausalLM") @@ -8668,7 +9041,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("norm.weight"): data_torch = data_torch + 1 - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("ExaoneForCausalLM") @@ -8824,11 +9197,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter new_name = remapper[_n.stem] + _n.suffix # set shared weights for all NextN/MTP layers - tensors = [] for bid in range(self.hparams['num_hidden_layers'], self.block_count): - new_name = new_name.format(bid=bid) - tensors.append((self.map_tensor_name(new_name), data_torch)) - return tensors + yield from super().modify_tensors(data_torch, new_name.format(bid=bid), bid) + return if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") @@ -8843,8 +9214,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -8860,12 +9229,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, new_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -8935,10 +9304,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter ffn_dim = self.hparams["intermediate_size"] assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size" gate, up = data_torch.split(ffn_dim, dim=-2) - return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate), - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up), - ] + yield from ModelBase.modify_tensors(self, gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), bid) + yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), bid) + return has_experts = bool(self.hparams.get('num_local_experts')) @@ -8947,21 +9315,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" gate, up = data_torch.split(ffn_dim, dim=-2) if has_experts: - return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), - ] - return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), gate), - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), up), - ] + yield from ModelBase.modify_tensors(self, gate,self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), bid) + yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), bid) + return + yield from ModelBase.modify_tensors(self, gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), bid) + yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), bid) + return if not has_experts and name.endswith("shared_mlp.output_linear.weight"): - return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), data_torch) - ] + yield from ModelBase.modify_tensors(self, data_torch, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), bid) + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("GraniteMoeHybridForCausalLM", "BambaForCausalLM") @@ -9048,14 +9413,17 @@ def modify_tensors( name.endswith("block_sparse_moe.input_linear.weight") or "shared_mlp" in name ): - return GraniteMoeModel.modify_tensors(self, data_torch, name, bid) + yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid) + return # Determine whether this is a mamba layer or an attention layer if bid in self._ssm_layers: - return Mamba2Model.modify_tensors(self, data_torch, name, bid) + yield from Mamba2Model.modify_tensors(self, data_torch, name, bid) + return elif bid in self._attn_layers: - return GraniteMoeModel.modify_tensors(self, data_torch, name, bid) - return [(self.map_tensor_name(name), data_torch)] + yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid) + return + yield from ModelBase.modify_tensors(self, data_torch, name, bid) def set_gguf_parameters(self): """This method merges params from both parents and some that are @@ -9187,34 +9555,34 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if self.is_moe and bid is not None: if name.endswith("mixer.gate.e_score_correction_bias"): new_name = name.replace("e_score_correction_bias", "e_score_correction.bias") - mapped_name = self.map_tensor_name(new_name) - return [(mapped_name, data_torch)] + yield from ModelBase.modify_tensors(self, data_torch, new_name, bid) + return if name.endswith("mixer.dt_bias"): new_name = name.replace("dt_bias", "dt.bias") - mapped_name = self.map_tensor_name(new_name) - return [(mapped_name, data_torch)] + yield from ModelBase.modify_tensors(self, data_torch, new_name, bid) + return if name.endswith("mixer.conv1d.weight"): squeezed_data = data_torch.squeeze() - mapped_name = self.map_tensor_name(name) - return [(mapped_name, squeezed_data)] + yield from ModelBase.modify_tensors(self, squeezed_data, name, bid) + return if name.endswith("mixer.A_log"): transformed_data = -torch.exp(data_torch) reshaped_data = transformed_data.squeeze().reshape(-1, 1) - mapped_name = self.map_tensor_name(name) - return [(mapped_name, reshaped_data)] + yield from ModelBase.modify_tensors(self, reshaped_data, name, bid) + return if name.endswith("mixer.D"): reshaped_data = data_torch.squeeze().reshape(-1, 1) - mapped_name = self.map_tensor_name(name) - return [(mapped_name, reshaped_data)] + yield from ModelBase.modify_tensors(self, reshaped_data, name, bid) + return if name.endswith("mixer.norm.weight"): reshaped_data = data_torch.reshape(self.n_group, -1) - mapped_name = self.map_tensor_name(name) - return [(mapped_name, reshaped_data)] + yield from ModelBase.modify_tensors(self, reshaped_data, name, bid) + return if name.find("mixer.experts") != -1: n_experts = self.hparams["n_routed_experts"] @@ -9227,7 +9595,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if len(self._experts[bid]) >= n_experts * 2: # merge the experts into a single tensor - tensors: list[tuple[str, Tensor]] = [] for w_name in ["down_proj", "up_proj"]: datas: list[Tensor] = [] @@ -9238,14 +9605,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = torch.stack(datas, dim=0) merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) - return tensors + yield from ModelBase.modify_tensors(self, data_torch, merged_name, bid) + return else: - return [] + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -9304,21 +9670,19 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) if name.endswith("attention.dense.weight"): - return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), data_torch)] + yield from super().modify_tensors(data_torch, self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), bid) + return elif name.endswith("query_key_value.weight"): q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2) - return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeModel.permute(q, n_head, n_head)), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeModel.permute(k, n_head, n_kv_head)), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v) - ] + yield from super().modify_tensors(BailingMoeModel.permute(q, n_head, n_head), self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), bid) + yield from super().modify_tensors(BailingMoeModel.permute(k, n_head, n_kv_head), self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), bid) + yield from super().modify_tensors(v,self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid) + return elif name.find("mlp.experts") != -1: n_experts = self.hparams["num_experts"] assert bid is not None - tensors: list[tuple[str, Tensor]] = [] - if self._experts is None: self._experts = [{} for _ in range(self.block_count)] @@ -9340,9 +9704,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) + yield from super().modify_tensors(data_torch, new_name, bid) - return tensors + return new_name = self.map_tensor_name(name) @@ -9350,7 +9714,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = data_torch.float() data_torch /= torch.norm(data_torch, p=2, dim=0, keepdim=True) + 1e-7 - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, new_name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -9401,8 +9765,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_experts = self.hparams["num_experts"] assert bid is not None - tensors: list[tuple[str, Tensor]] = [] - if self._experts is None: self._experts = [{} for _ in range(self.block_count)] @@ -9422,16 +9784,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return if name.endswith(".expert_bias"): name = name.replace(".expert_bias", ".expert_bias.bias") - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -9467,7 +9826,7 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if name.endswith(".expert_bias"): # FIXME?: Unused https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L303 - return [] + return # process the experts separately if name.find("chunk_experts") != -1: @@ -9480,8 +9839,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._chunk_experts[bid][name] = data_torch if len(self._chunk_experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -9495,12 +9852,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.mlp.chunk_experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return elif name.find("experts") != -1: n_experts = self.hparams["num_experts"] assert bid is not None @@ -9511,8 +9866,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -9526,14 +9879,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -9567,7 +9918,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # ignore image tokenizer for now # TODO: remove this once image support is implemented for Chameleon if name.startswith("model.vqmodel"): - return [] + return n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") @@ -9582,7 +9933,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith(("k_norm.weight", "k_norm.bias")): data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) # see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203 @staticmethod @@ -9627,11 +9978,9 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return super().tensor_force_quant(name, new_name, bid, n_dims) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if name.startswith("model.") or name.startswith("lm_head."): # skip language model tensors - return [] + return if name.startswith("audio_encoder.whisper."): name = name.replace("audio_encoder.whisper.","audio_tower.") @@ -9639,7 +9988,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.replace("audio_encoder.", "audio_encoder.adapting.") if name.startswith("audio_encoder.audio_bos_eos_token."): - return [(self.map_tensor_name("model.vision.boi"), data_torch[0]), (self.map_tensor_name("model.vision.eoi"), data_torch[1])] + yield from super().modify_tensors(data_torch[0], "model.vision.boi", bid) + yield from super().modify_tensors(data_torch[1], "model.vision.eoi", bid) + return if name.startswith("audio_encoder.adapting."): name = name.replace("audio_encoder.adapting.","audio.multi_modal_projector.") @@ -9650,13 +10001,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if ".2." in name: name = name.replace(".2.", ".linear_2.") if ".proj." in name: - return [] + return if "conv1.bias" in name or "conv2.bias" in name: # transpose conv1 and conv2 bias data_torch = data_torch.unsqueeze(-1) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("Qwen2AudioForConditionalGeneration") @@ -9683,11 +10034,9 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return super().tensor_force_quant(name, new_name, bid, n_dims) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if name.startswith("language_model."): # skip language model tensors - return [] + return # prevent clash naming with vision tensors if name.startswith("multi_modal_projector"): @@ -9697,7 +10046,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # transpose conv1 and conv2 bias data_torch = data_torch.unsqueeze(-1) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("UltravoxModel") @@ -9941,7 +10290,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name == "lm_head.weight": if self.hparams.get("tie_word_embeddings", False): logger.info("Skipping tied output layer 'lm_head.weight'") - return [] + return if name.find("mlp.experts") != -1: n_experts = self.hparams["num_experts"] @@ -9954,7 +10303,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if len(self._experts[bid]) >= n_experts * 3: # merge the experts into a single 3d tensor - tensors: list[tuple[str, Tensor]] = [] for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -9965,14 +10313,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = torch.stack(datas, dim=0) merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -10017,8 +10364,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] @@ -10032,14 +10377,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) # Copied from: Qwen2MoeModel def prepare_tensors(self): @@ -10138,9 +10481,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name == "lm_head.weight": if self.hparams.get("tie_word_embeddings", False): logger.info("Skipping tied output layer 'lm_head.weight'") - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("SmolLM3ForCausalLM") @@ -10220,8 +10563,6 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: return [] def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if "sinks" in name: name += ".weight" @@ -10235,7 +10576,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = data_torch.transpose(-1, -2) else: # otherwise, it should already be repacked to ggml MXFP4 format - return [] + return # split the gate_up into gate and up if "gate_up_proj" in name: @@ -10243,25 +10584,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name_up = name.replace("gate_up_proj_bias", "up_proj.bias") name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias") gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2] - return [ - (self.map_tensor_name(name_gate), gate_proj_bias), - (self.map_tensor_name(name_up), up_proj_bias) - ] + yield from super().modify_tensors(gate_proj_bias, name_gate, bid) + yield from super().modify_tensors(up_proj_bias, name_up, bid) elif "_blocks" not in name and "_scales" not in name: logger.warning(f"{name} is not in MXFP4, performance may be degraded") name_up = name.replace("gate_up_proj", "up_proj.weight") name_gate = name.replace("gate_up_proj", "gate_proj.weight") data_torch = data_torch.transpose(-1, -2) gate_proj_weight, up_proj_weight = data_torch[:, ::2, :], data_torch[:, 1::2, :] - return [ - (self.map_tensor_name(name_gate), gate_proj_weight), - (self.map_tensor_name(name_up), up_proj_weight) - ] - else: - # otherwise, it should already be repacked to ggml MXFP4 format - return [] - - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(gate_proj_weight, name_gate, bid) + yield from super().modify_tensors(up_proj_weight, name_up, bid) + else: + yield from super().modify_tensors(data_torch, name, bid) def set_vocab(self): self._set_vocab_gpt2() @@ -10309,7 +10643,7 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if self._is_vision_tensor(name) or ConformerAudioModel.is_audio_tensor(name): # skip multimodal tensors - return [] + return name = name.replace("language_model.", "") # vision name = name.replace("lfm.", "model.") # audio @@ -10318,7 +10652,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if 'conv.conv' in name: data_torch = data_torch.squeeze(1) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def _is_vision_tensor(self, name: str) -> bool: return "vision_tower" in name or "multi_modal_projector" in name @@ -10333,7 +10667,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if not name.startswith(self.dense_tensor_name): name = "model." + name - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: # dense tensor is stored in a separate safetensors file @@ -10388,9 +10722,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # not enough expert weights to merge if len(expert_cache) < n_experts * len(expert_weights): - return [] + return - tensors: list[tuple[str, Tensor]] = [] for w_name in expert_weights: datas: list[Tensor] = [] @@ -10401,13 +10734,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = torch.stack(datas, dim=0) merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) + + yield from super().modify_tensors(data_torch, merged_name, bid) del self._experts_cache[bid] - return tensors + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -10433,7 +10766,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys) - vision_feature_layers_to_drop) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name if is_vision_tensor: @@ -10444,9 +10776,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if "patch_embedding.weight" in name: data_torch = data_torch.view(data_torch.shape[0], 16, 16, 3).permute(0, 3, 1, 2) - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) + return - return [] # skip other tensors + return # skip other tensors @ModelBase.register("Lfm2AudioForConditionalGeneration") @@ -10471,17 +10804,17 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch, name, bid): # skip language model tensors if name.startswith("lfm."): - return [] + return # for training only if any(p in name for p in ["audio_loss_weight"]): - return [] + return # for audio output if any(p in name for p in ["codebook_offsets", "depth_embeddings", "depth_linear", "depthformer"]): - return [] + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("SmallThinkerForCausalLM") @@ -10526,8 +10859,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._experts[bid][name] = data_torch if len(self._experts[bid]) >= n_experts * 3: - tensors: list[tuple[str, Tensor]] = [] - # merge the experts into a single 3d tensor for w_name in ["down", "gate", "up"]: datas: list[Tensor] = [] @@ -10541,14 +10872,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - - tensors.append((new_name, data_torch)) - return tensors + yield from super().modify_tensors(data_torch, merged_name, bid) + return else: - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) def prepare_tensors(self): super().prepare_tensors() @@ -10581,12 +10910,12 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # these layers act as MLM head, so we don't need them if name.startswith("decoder."): - return [] + return if name.startswith("model."): name = name[6:] - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("ApertusForCausalLM") @@ -10606,24 +10935,24 @@ def modify_tensors(self, data_torch, name, bid): self._alpha_n[bid] = data_torch.to("cpu").float().item() if (len(self._alpha_n) == n_layers): self.gguf_writer.add_xielu_alpha_n([self._alpha_n[k] for k in sorted(self._alpha_n)]) - return [] + return if name.endswith(".act_fn.alpha_p"): self._alpha_p[bid] = data_torch.to("cpu").float().item() if (len(self._alpha_p) == n_layers): self.gguf_writer.add_xielu_alpha_p([self._alpha_p[k] for k in sorted(self._alpha_p)]) - return [] + return if name.endswith(".act_fn.beta"): self._beta[bid] = data_torch.to("cpu").float().item() if (len(self._beta) == n_layers): self.gguf_writer.add_xielu_beta([self._beta[k] for k in sorted(self._beta)]) - return [] + return if name.endswith(".act_fn.eps"): self._eps[bid] = data_torch.to("cpu").float().item() if (len(self._eps) == n_layers): self.gguf_writer.add_xielu_eps([self._eps[k] for k in sorted(self._eps)]) - return [] + return - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) class MistralModel(LlamaModel): @@ -10786,7 +11115,7 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name: - return [] + return # rename certain tensors so that we can reuse DeepseekV2Model modify_tensors logic if name.endswith(".qscale_act"): @@ -10802,7 +11131,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): name = name.replace(".w3.", ".up_proj.") name = "model." + name - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) class PixtralModel(LlavaVisionModel): @@ -10847,7 +11176,7 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): name = name.replace("model.vision_encoder.", "vision_tower.") name = name.replace("model.vision_projection.", "multi_modal_projector.") - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("KimiVLForConditionalGeneration") @@ -10867,24 +11196,20 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5)) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name if is_vision_tensor: if "pos_emb.weight" in name: data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2]) - elif "wqkv" in name: + + if "wqkv" in name: split_dim = 0 if "weight" in name else -1 wq, wk, wv = data_torch.chunk(3, dim=split_dim) - return [ - (self.map_tensor_name(name.replace("wqkv", "wq")), wq), - (self.map_tensor_name(name.replace("wqkv", "wk")), wk), - (self.map_tensor_name(name.replace("wqkv", "wv")), wv) - ] - - return [(self.map_tensor_name(name), data_torch)] - - return [] # skip other tensors + yield from super().modify_tensors(wq, name.replace("wqkv", "wq"), bid) + yield from super().modify_tensors(wk, name.replace("wqkv", "wk"), bid) + yield from super().modify_tensors(wv, name.replace("wqkv", "wv"), bid) + else: + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("CogVLMForCausalLM") @@ -10896,12 +11221,10 @@ def set_gguf_parameters(self): self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.COGVLM) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if not name.startswith("model.vision."): - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("CogVLMForCausalLM") @@ -10909,13 +11232,11 @@ class CogVLMModel(LlamaModel): model_arch = gguf.MODEL_ARCH.COGVLM def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - # block vision tensors if name.startswith("model.vision."): - return [] + return - return [(self.map_tensor_name(name), data_torch)] + yield from ModelBase.modify_tensors(self, data_torch, name, bid) @ModelBase.register("JanusForConditionalGeneration") @@ -10933,14 +11254,14 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter 'model.generation_head.', ) if name.startswith(skip_prefixes): - return [] + return if name.startswith('model.language_model.'): name = name.replace('model.language_model.', 'model.') elif name.startswith('language_model.'): name = name.replace('language_model.', '') - return super().modify_tensors(data_torch, name, bid) + yield from super().modify_tensors(data_torch, name, bid) @ModelBase.register("JanusForConditionalGeneration") @@ -10993,11 +11314,9 @@ def _map_aligner_tensor(self, data_torch: Tensor, name: str) -> Iterable[tuple[s return [(tensor_name, data_torch)] def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - # Skip language model tensors as they will be handled by `JanusProModel` if name.startswith(('model.language_model.', 'language_model.')): - return [] + return # Skip generation-related components skip_generation_prefixes = ( @@ -11011,17 +11330,19 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter 'generation_head.', ) if name.startswith(skip_generation_prefixes): - return [] + return # Handle aligner tensors if name.startswith(('model.aligner.', 'aligner.')): - return list(self._map_aligner_tensor(data_torch, name)) + yield from self._map_aligner_tensor(data_torch, name) + return # Handle vision tensors if name.startswith(('model.vision_model.', 'vision_model.')): - return [(self.map_tensor_name(name), data_torch)] + yield from super().modify_tensors(data_torch, name, bid) + return - return [] + return @ModelBase.register("YoutuVLForConditionalGeneration") @@ -11060,21 +11381,18 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_wa_layer_indexes(layers=fullatt_block_indexes) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - # Skip language model tensors skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.') if name.startswith(skip_prefixes): - return [] + return # Try to map the tensor using TensorNameMap (handles vision encoder and projector) try: - new_name = self.map_tensor_name(name) - return [(new_name, data_torch)] + yield from super().modify_tensors(data_torch, name, bid) except ValueError: # If mapping fails, log warning and skip logger.warning(f"Cannot map tensor: {name}") - return [] + return @ModelBase.register("SolarOpenForCausalLM") diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index bcb3ce6743d..b3cff96604e 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -22,12 +22,11 @@ - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers. - **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*. - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over Intel iGPUs and dGPUs. -- **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets. ### Llama.cpp + SYCL The llama.cpp SYCL backend is primarily designed for **Intel GPUs**. -SYCL cross-platform capabilities enable support for Nvidia GPUs as well, with limited support for AMD. +SYCL cross-platform capabilities enable support for other vendor GPUs as well. ## Recommended Release @@ -35,13 +34,16 @@ The following releases are verified and recommended: |Commit ID|Tag|Release|Verified Platform| Update date| |-|-|-|-|-| -|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |ArcB580/Linux/oneAPI 2025.1
LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15| -|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19| -|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1|| +|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |Arc B580/Linux/oneAPI 2025.1
LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15| +|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19| +|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1|| ## News +- 2026.02 + - Remove support for Nvidia & AMD GPU, because the oneAPI plugin for Nvidia & AMD GPU is unavailable: download/installation channels are out of work. User can't build up the software for Nvidia & AMD GPU. + - 2025.11 - Support malloc memory on device more than 4GB. @@ -51,7 +53,7 @@ The following releases are verified and recommended: |-|-|-|-| |PVC 1550|39|73|+87%| |Flex 170|39|50|+28%| - |Arc770|42|55|+30%| + |Arc A770|42|55|+30%| |MTL|13|16|+23%| |ARL-H|14|17|+21%| @@ -62,7 +64,7 @@ The following releases are verified and recommended: - Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs. - 2024.5 - - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770. + - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc A770. - Arch Linux is verified successfully. - 2024.4 @@ -111,14 +113,15 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the |-------------------------------|---------|---------------------------------------| | Intel Data Center Max Series | Support | Max 1550, 1100 | | Intel Data Center Flex Series | Support | Flex 170 | -| Intel Arc Series | Support | Arc 770, 730M, Arc A750, B580 | +| Intel Arc A-Series | Support | Arc A770, Arc A730M, Arc A750 | +| Intel Arc B-Series | Support | Arc B580 | | Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake | | Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 | *Notes:* - **Memory** - - The device memory is a limitation when running a large model. The loaded model size, *`llm_load_tensors: buffer_size`*, is displayed in the log when running `./bin/llama-cli`. + - The device memory is a limitation when running a large model. The loaded model size, *`llm_load_tensors: buffer_size`*, is displayed in the log when running `./bin/llama-completion`. - Please make sure the GPU shared memory from the host is large enough to account for the model's size. For e.g. the *llama-2-7b.Q4_0* requires at least 8.0GB for integrated GPU and 4.0GB for discrete GPU. - **Execution Unit (EU)** @@ -126,20 +129,7 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the ### Other Vendor GPU -**Verified devices** - -| Nvidia GPU | Status | Verified Model | -|--------------------------|-----------|----------------| -| Ampere Series | Supported | A100, A4000 | -| Ampere Series *(Mobile)* | Supported | RTX 40 Series | - -| AMD GPU | Status | Verified Model | -|--------------------------|--------------|----------------| -| Radeon Pro | Experimental | W6800 | -| Radeon RX | Experimental | 6700 XT | - -Note: AMD GPU support is highly experimental and is incompatible with F16. -Additionally, it only supports GPUs with a sub_group_size (warp size) of 32. +NA ## Docker @@ -148,11 +138,11 @@ The docker build option is currently limited to *Intel GPU* targets. ### Build image ```sh -# Using FP16 -docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile . - # Using FP32 docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=OFF" --target light -f .devops/intel.Dockerfile . + +# Using FP16 +docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile . ``` *Notes*: @@ -211,14 +201,6 @@ Platform #0: Intel(R) OpenCL HD Graphics `-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49] ``` -- **Nvidia GPU** - -In order to target Nvidia GPUs through SYCL, please make sure the CUDA/CUBLAS native requirements *-found [here](README.md#cuda)-* are installed. - -- **AMD GPU** - -To target AMD GPUs with SYCL, the ROCm stack must be installed first. - 2. **Install Intel® oneAPI Base toolkit** SYCL backend depends on: @@ -247,23 +229,6 @@ Upon a successful installation, SYCL is enabled for the available intel devices, |2025.1| |2024.1| -- **Adding support to Nvidia GPUs** - -**oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup. - -**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target: - -```sh -git clone https://github.com/oneapi-src/oneDNN.git -cd oneDNN -cmake -GNinja -Bbuild-nvidia -DDNNL_CPU_RUNTIME=DPCPP -DDNNL_GPU_RUNTIME=DPCPP -DDNNL_GPU_VENDOR=NVIDIA -DONEDNN_BUILD_GRAPH=OFF -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -cmake --build build-nvidia --config Release -``` - -- **Adding support to AMD GPUs** - -**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit. - 3. **Verify installation and environment** In order to check the available SYCL devices on the machine, please use the `sycl-ls` command. @@ -284,25 +249,6 @@ When targeting an intel GPU, the user should expect one or more devices among th [opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) UHD Graphics 730 OpenCL 3.0 NEO [24.39.31294] ``` -- **Nvidia GPU** - -Similarly, user targeting Nvidia GPUs should expect at least one SYCL-CUDA device [`cuda:gpu`] as below: - -``` -[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.12.0.12_195853.xmain-hotfix] -[opencl:cpu][opencl:1] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz OpenCL 3.0 (Build 0) [2023.16.12.0.12_195853.xmain-hotfix] -[cuda:gpu][cuda:0] NVIDIA CUDA BACKEND, NVIDIA A100-PCIE-40GB 8.0 [CUDA 12.5] -``` - -- **AMD GPU** - -For AMD GPUs we should expect at least one SYCL-HIP device [`hip:gpu`]: - -``` -[opencl:cpu][opencl:0] Intel(R) OpenCL, 12th Gen Intel(R) Core(TM) i9-12900K OpenCL 3.0 (Build 0) [2024.18.6.0.02_160000] -[hip:gpu][hip:0] AMD HIP BACKEND, AMD Radeon PRO W6800 gfx1030 [HIP 60140.9] -``` - ### II. Build llama.cpp #### Intel GPU @@ -331,47 +277,6 @@ It is possible to come across some precision issues when running tests that stem instructions, which can be circumvented by setting the environment variable `SYCL_PROGRAM_COMPILE_OPTIONS` as `-cl-fp32-correctly-rounded-divide-sqrt` -#### Nvidia GPU - -The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices. -By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`. - -```sh -# Build LLAMA with Nvidia BLAS acceleration through SYCL -# Setting GGML_SYCL_DEVICE_ARCH is optional but can improve performance -GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture - -# Option 1: Use FP32 (recommended for better performance in most cases) -cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl - -# Option 2: Use FP16 -cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl - -# build all binary -cmake --build build --config Release -j -v -``` - -It is possible to come across some precision issues when running tests that stem from using faster -instructions, which can be circumvented by passing the `-fno-fast-math` flag to the compiler. - -#### AMD GPU - -The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices. -By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`. - -```sh -# Build LLAMA with rocBLAS acceleration through SYCL - -## AMD -# Use FP32, FP16 is not supported -# Find your GGML_SYCL_DEVICE_ARCH with rocminfo, under the key 'Name:' -GGML_SYCL_DEVICE_ARCH=gfx90a # Example architecture -cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=AMD -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx - -# build all binary -cmake --build build --config Release -j -v -``` - ### III. Run the inference #### Retrieve and prepare model @@ -422,16 +327,12 @@ Choose one of following methods to run. - Use device 0: ```sh -./examples/sycl/run-llama2.sh 0 -# OR -./examples/sycl/run-llama3.sh 0 +./examples/sycl/test.sh -mg 0 ``` - Use multiple devices: ```sh -./examples/sycl/run-llama2.sh -# OR -./examples/sycl/run-llama3.sh +./examples/sycl/test.sh ``` 2. Command line @@ -454,13 +355,13 @@ Examples: - Use device 0: ```sh -ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm none -mg 0 +ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm none -mg 0 --mmap ``` - Use multiple devices: ```sh -ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm layer +ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm layer --mmap ``` *Notes:* @@ -576,13 +477,13 @@ Or, use CMake presets to build: ```sh cmake --preset x64-windows-sycl-release -cmake --build build-x64-windows-sycl-release -j --target llama-cli +cmake --build build-x64-windows-sycl-release -j --target llama-completion cmake -DGGML_SYCL_F16=ON --preset x64-windows-sycl-release -cmake --build build-x64-windows-sycl-release -j --target llama-cli +cmake --build build-x64-windows-sycl-release -j --target llama-completion cmake --preset x64-windows-sycl-debug -cmake --build build-x64-windows-sycl-debug -j --target llama-cli +cmake --build build-x64-windows-sycl-debug -j --target llama-completion ``` #### 3. Visual Studio @@ -607,7 +508,7 @@ You can use Visual Studio to open the `llama.cpp` folder directly as a CMake pro - For a minimal experimental setup, you can build only the inference executable using: ```Powershell - cmake --build build --config Release -j --target llama-cli + cmake --build build --config Release -j --target llama-completion ``` ##### - Generating a Visual Studio Solution @@ -713,13 +614,7 @@ Choose one of following methods to run. 1. Script ``` -examples\sycl\win-run-llama-2.bat -``` - -or - -``` -examples\sycl\win-run-llama-3.bat +examples\sycl\win-test.bat ``` 2. Command line @@ -743,13 +638,13 @@ Examples: - Use device 0: ``` -build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm none -mg 0 +build\bin\llama-completion.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm none -mg 0 --mmap ``` - Use multiple devices: ``` -build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm layer +build\bin\llama-completion.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm layer --mmap ``` @@ -775,15 +670,15 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | Name | Value | Function | |--------------------|---------------------------------------|---------------------------------------------| | GGML_SYCL | ON (mandatory) | Enable build with SYCL code path. | -| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. | -| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. | +| GGML_SYCL_TARGET | INTEL *(default)* | Set the SYCL target device type. | +| GGML_SYCL_DEVICE_ARCH | Optional | Set the SYCL device architecture. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. | | GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. (1.) | -| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). | +| GGML_SYCL_GRAPH | OFF *(default)* \|ON *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). | | GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. | | CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. | | CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. | -1. FP16 is recommended for better prompt processing performance on quantized models. Performance is equivalent in text generation but set `GGML_SYCL_F16=OFF` if you are experiencing issues with FP16 builds. +1. FP32 or FP16 have different performance impact to LLM. Recommended to test them for better prompt processing performance on your models. You need to rebuild the code after change `GGML_SYCL_F16=OFF/ON`. #### Runtime @@ -791,7 +686,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512 |-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------| | GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG | | GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for intel devices older than Gen 10) | -| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. | +| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because SYCL Graph is still on development, no better performance. | | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. | | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.
Recommended to use when --split-mode = layer | | UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.| diff --git a/docs/backend/hexagon/CMakeUserPresets.json b/docs/backend/snapdragon/CMakeUserPresets.json similarity index 72% rename from docs/backend/hexagon/CMakeUserPresets.json rename to docs/backend/snapdragon/CMakeUserPresets.json index 1f2676c0bcd..1faae2f3db7 100644 --- a/docs/backend/hexagon/CMakeUserPresets.json +++ b/docs/backend/snapdragon/CMakeUserPresets.json @@ -1,5 +1,5 @@ { - "version": 4, + "version": 5, "configurePresets": [ { "name": "arm64-android-snapdragon", @@ -16,7 +16,9 @@ "CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG", "CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g", "CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g", - "HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}", + "CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}", + "HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}", + "HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}", "PREBUILT_LIB_DIR": "android_aarch64", "GGML_OPENMP": "OFF", "GGML_LLAMAFILE": "OFF", @@ -31,7 +33,15 @@ "name": "arm64-windows-snapdragon", "inherits": [ "base", "arm64-windows-llvm" ], "cacheVariables": { - "HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}", + "CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE", + "CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE", + "CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG", + "CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG", + "CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g", + "CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g", + "CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}", + "HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}", + "HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}", "PREBUILT_LIB_DIR": "windows_aarch64", "GGML_OPENMP": "OFF", "GGML_LLAMAFILE": "OFF", diff --git a/docs/backend/hexagon/README.md b/docs/backend/snapdragon/README.md similarity index 84% rename from docs/backend/hexagon/README.md rename to docs/backend/snapdragon/README.md index 3befdf72258..8e1f37b2062 100644 --- a/docs/backend/hexagon/README.md +++ b/docs/backend/snapdragon/README.md @@ -1,6 +1,8 @@ -# Snapdragon-based Android devices +# Snapdragon-based devices -## How to Build +## Setup + +### Android The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain). This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc. @@ -12,7 +14,24 @@ This method works on Linux, macOS, and Windows. macOS and Windows users should i [d]/> cd /workspace ``` -The rest of the Android build process assumes that you're running inside the toolchain container. +Note: The rest of the **Android** build process assumes that you're running inside the toolchain container. + +### Windows On Snapdragon + +Native Windows 11 arm64 builds has the following tools dependencies: +- MS Visual Studio 2026 (Community Edition or Pro) + - MSVC arm64 standard and runtime libraries + - UCRT and Driver Kit +- LLVM core libraries and Clang compiler (winget) +- CMake, Git, Python (winget) +- Hexagon SDK Community Edition 6.4 or later (see windows.md) +- OpenCL SDK 2.3 or later (see windows.md) + +Note: The rest of the **Windows** build process assumes that you're running natively in Powershell. +Adapt below build commands accordingly. + +## How to Build + Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets: ``` @@ -49,24 +68,26 @@ Preset CMake variables: To generate an installable "package" simply use cmake --install: ``` -[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp +[d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon/llama.cpp -- Install configuration: "Release" --- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so --- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so --- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so --- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so --- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so --- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so --- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so --- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so +-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-cpu.so +-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-opencl.so +-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-hexagon.so +-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v73.so +-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v75.so +-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v79.so +-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v81.so +-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml.so ... --- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench --- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli +-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-bench +-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-cli ... ``` ## How to Install +### Android + For this step, your device needs to be configured for on-device development. Please see https://developer.android.com/studio/debug/dev-options for details. @@ -74,10 +95,10 @@ Once ADB is enabled, use `adb push` to install `pkg-snapdragon` on the device. **Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.** ``` -~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/ -pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s) -pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s) -pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s) +~/src/llama.cpp$ adb push pkg-snapdragon/llama.cpp /data/local/tmp/ +pkg-snapdragon/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s) +pkg-snapdragon/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s) +pkg-snapdragon/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s) 102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s) ``` @@ -92,6 +113,11 @@ At this point, you should also install some models: Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s) ``` +### Windows + +All artifacts are already installed in the `pkg-snapdragon` folder. +To run, adapt below instructions to use Powershell scrits in `scripts/snapdragon/windows`. + ## How to Run The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables. diff --git a/docs/backend/hexagon/developer.md b/docs/backend/snapdragon/developer.md similarity index 100% rename from docs/backend/hexagon/developer.md rename to docs/backend/snapdragon/developer.md diff --git a/docs/backend/snapdragon/windows.md b/docs/backend/snapdragon/windows.md new file mode 100644 index 00000000000..e9346ccadf1 --- /dev/null +++ b/docs/backend/snapdragon/windows.md @@ -0,0 +1,161 @@ +## Overview + +The document covers procedures for installing the latest GPU and NPU drivers, and OpenCL and Hexagon SDKs. + + +In order to use Hexagon NPU on Snapdragon Windows devices the underlying HTP Ops libraries (e.g libggml-htp-v73.so) +must be included in the .cat file digitally signed with a trusted certificate. + +This document covers details on how to generate personal certificate files (.pfx) and how to configure the system +to allow for test signatures (aka test-signing). + +## Install the latest Adreno OpenCL SDK + +Either use the trimmed down version (optimized for CI) from + + https://github.com/snapdragon-toolchain/opencl-sdk/releases/download/v2.3.2/adreno-opencl-sdk-v2.3.2-arm64-wos.tar.xz + +Or download the complete official version from + + https://softwarecenter.qualcomm.com/catalog/item/Adreno_OpenCL_SDK?version=2.3.2 + +Unzip/untar the archive into +``` +c:\Qualcomm\OpenCL_SDK\2.3.2 +``` + +## Install the latest Hexagon SDK Community Edition + +Either use the trimmed down version (optimized for CI) from + + https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v6.4.0.2/hexagon-sdk-v6.4.0.2-arm64-wos.tar.xz + +Or download the complete official version from + + https://softwarecenter.qualcomm.com/catalog/item/Hexagon_SDK?version=6.4.0.2 + +Unzip/untar the archive into +``` +c:\Qualcomm\Hexagon_SDK\6.4.0.2 +``` + +## Install the latest Adreno GPU driver + +Download the driver from + + https://softwarecenter.qualcomm.com/catalog/item/Windows_Graphics_Driver + +After the automated installation and reboot please make sure that the GPU device shows up in the `Device Manager` (under 'Display Adapters`) + +## Install the latest Qualcomm NPU driver + +Download the driver from + + https://softwarecenter.qualcomm.com/catalog/item/Qualcomm_HND + +After the automated installation and reboot please make sure that the Hexagon NPU device shows up in the `Device Manager` (under `Neural Processors`). + +If the device is not available you can try installing all components (`qcnspmcdm8380`, `qcnspmcdm8380_ext`) manually. +The components are extracted into +``` +c:\QCDrivers\qcnspmcdm... +``` + +## Enable NPU driver test signatures + +Please note that the following steps are required only for the Hexagon NPU. +Adreno GPU backend does not require test signatures. + +### Enable testsigning + +Use `bcdedit` to enable test-signing +``` +> bcdedit /set TESTSIGNING ON +``` +(Secure Boot may need to be disabled for this to work) + +Make sure test-signing is enabled after reboot +``` +> bcdedit /enum +... +testsigning Yes +... +``` +For additional details see Microsoft guide at + + https://learn.microsoft.com/en-us/windows-hardware/drivers/install/the-testsigning-boot-configuration-option + +### Create personal certificate + +The tools required for this procedure are available as part of Windows SDK and Windows Driver Kit which should be +installed as part of the MS Visual Studio. +They are typically located at +``` +c:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0 +``` +(replace 10.0.26100.0 with correct version). + +To create personal self-signed certificate run the following commands (either from cmd or power-shell): +``` +> cd c:\Users\MyUser +> mkdir Certs +> cd Certs +> makecert -r -pe -ss PrivateCertStore -n CN=GGML.HTP.v1 -eku 1.3.6.1.5.5.7.3.3 -sv ggml-htp-v1.pvk ggml-htp-v1.cer +> pvk2pfx.exe -pvk ggml-htp-v1.pvk -spc ggml-htp-v1.cer -pfx ggml-htp-v1.pfx +``` +(replace `MyUser` with your username). + +Add this certificate to `Trusted Root Certification Authorities` and `Trusted Publishers` stores. +This can be done using `certlm` Certificate Manager tool. +Right click on the certificate store, select `All Tasks -> Import` and follow the prompts to import the certificate from the +PFX file you created above. + +For additional details see Microsoft guide at + + https://learn.microsoft.com/en-us/windows-hardware/drivers/install/introduction-to-test-signing + +Make sure to save the PFX file, you will need it for the build procedures. +Please note that the same certificate can be used for signing any number of builds. + +## Build Hexagon backend with signed HTP ops libraries + +The overall Hexagon backend build procedure for Windows on Snapdragon is the same as for other platforms. +However, additional settings are required for generating and signing HTP Ops libraries. +``` +> $env:OPENCL_SDK_ROOT="C:\Qualcomm\OpenCL_SDK\2.3.2" +> $env:HEXAGON_SDK_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2" +> $env:HEXAGON_TOOLS_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2\tools\HEXAGON_Tools\19.0.04" +> $env:HEXAGON_HTP_CERT="c:\Users\MyUsers\Certs\ggml-htp-v1.pfx" +> $env:WINDOWS_SDK_BIN="C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\arm64" + +> cmake --preset arm64-windows-snapdragon-release -B build-wos +... +> cmake --install build-wos --prefix pkg-snapdragon +``` + +Once the build is complete HTP ops libraries will be installed like this +``` +> dir pkg-snapdragon/lib +... +-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v73.so +-a---- 1/22/2026 6:01 PM 191752 libggml-htp-v75.so +-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v79.so +-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v81.so +-a---- 1/22/2026 6:01 PM 4139 libggml-htp.cat +``` + +The .cat file, the signature and proper certicate installation can be verified with + +``` +> signtool.exe verify /v /pa .\pkg-snapdragon\lib\libggml-htp.cat +Verifying: .\pkg-snapdragon\lib\libggml-htp.cat + +Signature Index: 0 (Primary Signature) +Hash of file (sha256): 9820C664DA59D5EAE31DBB664127FCDAEF59CDC31502496BC567544EC2F401CF + +Signing Certificate Chain: + Issued to: GGML.HTP.v1 +... +Successfully verified: .\pkg-snapdragon\lib\libggml-htp.cat +... +``` diff --git a/docs/build.md b/docs/build.md index fce9361b2d6..fd447424c78 100644 --- a/docs/build.md +++ b/docs/build.md @@ -144,7 +144,7 @@ We also have a [guide](./backend/CUDA-FEDORA.md) for setting up CUDA toolkit in - ***Necessary*** for users of [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/); such as: [Silverblue](https://fedoraproject.org/atomic-desktops/silverblue/) and [Kinoite](https://fedoraproject.org/atomic-desktops/kinoite/). - (there are no supported CUDA packages for these systems) - ***Necessary*** for users that have a host that is not a: [Supported Nvidia CUDA Release Platform](https://developer.nvidia.com/cuda-downloads). - - (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your your host operating system) + - (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your host operating system) - ***Convenient*** For those running [Fedora Workstation](https://fedoraproject.org/workstation/) or [Fedora KDE Plasma Desktop](https://fedoraproject.org/spins/kde), and want to keep their host system clean. - *Optionally* toolbox packages are available: [Arch Linux](https://archlinux.org/), [Red Hat Enterprise Linux >= 8.5](https://www.redhat.com/en/technologies/linux-platforms/enterprise-linux), or [Ubuntu](https://ubuntu.com/download) @@ -248,6 +248,12 @@ You may set the [cuda environmental variables](https://docs.nvidia.com/cuda/cuda CUDA_VISIBLE_DEVICES="-0" ./build/bin/llama-server --model /srv/models/llama.gguf ``` +#### CUDA_SCALE_LAUNCH_QUEUES + +The environment variable [`CUDA_SCALE_LAUNCH_QUEUES`](https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/environment-variables.html#cuda-scale-launch-queues) controls the size of CUDA's command buffer, which determines how many GPU operations can be queued before the CPU must wait for the GPU to catch up. A larger buffer reduces CPU-side stalls and allows more work to be queued on a GPU. + +Consider setting `CUDA_SCALE_LAUNCH_QUEUES=4x`, which increases the CUDA command buffer to 4 times its default size. This optimization is particularly beneficial for **Multi-GPU setups with pipeline parallelism**, where it significantly improves prompt processing throughput by allowing more operations to be enqueued across GPUs. + ### Unified Memory The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`. @@ -487,6 +493,37 @@ Finally, after finishing your build, you should be able to do something like thi # ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32 ``` +### For Mac users: + +Generally, follow LunarG's [Getting Started with the MacOS Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/mac/getting_started.html) guide for installation and setup of the Vulkan SDK. There are two options of Vulkan drivers on macOS, both of which implement translation layers to map Vulkan to Metal. They can be hot-swapped by setting the `VK_ICD_FILENAMES` environment variable to point to the respective ICD JSON file. + +Check the box for "KosmicKrisp" during the LunarG Vulkan SDK installation. + +Set environment variable for the LunarG Vulkan SDK after installation (and optionally add to your shell profile for persistence): +```bash +source /path/to/vulkan-sdk/setup-env.sh +``` + +#### Using MoltenVK + +MoltenVK is the default Vulkan driver installed with the LunarG Vulkan SDK on macOS, so you can use the above environment variable settings as is. + +#### Using KosmicKrisp + +Override the environment variable for KosmicKrisp: +```bash +export VK_ICD_FILENAMES=$VULKAN_SDK/share/vulkan/icd.d/libkosmickrisp_icd.json +export VK_DRIVER_FILES=$VULKAN_SDK/share/vulkan/icd.d/libkosmickrisp_icd.json +``` + +#### Build + +This is the only step different from [above](#common-steps) instructions. +```bash +cmake -B build -DGGML_VULKAN=1 -DGGML_METAL=OFF +cmake --build build --config Release +``` + ## CANN This provides NPU acceleration using the AI cores of your Ascend NPU. And [CANN](https://www.hiascend.com/en/software/cann) is a hierarchical APIs to help you to quickly build AI applications and service based on Ascend NPU. diff --git a/docs/multimodal/minicpmo2.6.md b/docs/multimodal/minicpmo2.6.md index 5e74058e5d5..ce003b2ebc0 100644 --- a/docs/multimodal/minicpmo2.6.md +++ b/docs/multimodal/minicpmo2.6.md @@ -9,7 +9,7 @@ Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch m ### Build llama.cpp Readme modification time: 20250206 -If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) Clone llama.cpp: ```bash diff --git a/docs/multimodal/minicpmo4.0.md b/docs/multimodal/minicpmo4.0.md index 49125ea05e0..a5281779c2b 100644 --- a/docs/multimodal/minicpmo4.0.md +++ b/docs/multimodal/minicpmo4.0.md @@ -8,11 +8,11 @@ Download [MiniCPM-o-4](https://huggingface.co/openbmb/MiniCPM-o-4) PyTorch model ### Build llama.cpp Readme modification time: 20250206 -If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) Clone llama.cpp: ```bash -git clone https://github.com/ggerganov/llama.cpp +git clone https://github.com/ggml-org/llama.cpp cd llama.cpp ``` diff --git a/docs/multimodal/minicpmv2.5.md b/docs/multimodal/minicpmv2.5.md index 5eb87bc9693..096f070a1c9 100644 --- a/docs/multimodal/minicpmv2.5.md +++ b/docs/multimodal/minicpmv2.5.md @@ -8,7 +8,7 @@ Download [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V- ### Build llama.cpp Readme modification time: 20250206 -If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) Clone llama.cpp: ```bash diff --git a/docs/multimodal/minicpmv2.6.md b/docs/multimodal/minicpmv2.6.md index bc874bbd8cd..a7db9c58db0 100644 --- a/docs/multimodal/minicpmv2.6.md +++ b/docs/multimodal/minicpmv2.6.md @@ -8,7 +8,7 @@ Download [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch m ### Build llama.cpp Readme modification time: 20250206 -If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) Clone llama.cpp: ```bash diff --git a/docs/multimodal/minicpmv4.0.md b/docs/multimodal/minicpmv4.0.md index d04cb338cec..1d21b8cfdf9 100644 --- a/docs/multimodal/minicpmv4.0.md +++ b/docs/multimodal/minicpmv4.0.md @@ -8,11 +8,11 @@ Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model ### Build llama.cpp Readme modification time: 20250731 -If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) Clone llama.cpp: ```bash -git clone https://github.com/ggerganov/llama.cpp +git clone https://github.com/ggml-org/llama.cpp cd llama.cpp ``` diff --git a/docs/multimodal/minicpmv4.5.md b/docs/multimodal/minicpmv4.5.md index 8fea5e611da..a102c0fa510 100644 --- a/docs/multimodal/minicpmv4.5.md +++ b/docs/multimodal/minicpmv4.5.md @@ -8,11 +8,11 @@ Download [MiniCPM-V-4_5](https://huggingface.co/openbmb/MiniCPM-V-4_5) PyTorch m ### Build llama.cpp Readme modification time: 20250826 -If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) Clone llama.cpp: ```bash -git clone https://github.com/ggerganov/llama.cpp +git clone https://github.com/ggml-org/llama.cpp cd llama.cpp ``` diff --git a/docs/ops.md b/docs/ops.md index c066ab5a858..5754b0a96cd 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -22,7 +22,7 @@ Legend: | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | -| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | +| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | @@ -97,7 +97,7 @@ Legend: | SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | -| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | +| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ | | SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | @@ -113,8 +113,8 @@ Legend: | SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | -| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | -| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | +| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | +| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/ops/SYCL.csv b/docs/ops/SYCL.csv index 91b442bde8e..c1622cc6f0e 100644 --- a/docs/ops/SYCL.csv +++ b/docs/ops/SYCL.csv @@ -29,8 +29,8 @@ "SYCL0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","EXPM1","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL" "SYCL0","EXPM1","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL" -"SYCL0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL" -"SYCL0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL" +"SYCL0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" "SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" @@ -71,14 +71,14 @@ "SYCL0","EXP","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","EXPM1","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","EXPM1","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" -"SYCL0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" -"SYCL0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" +"SYCL0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL" +"SYCL0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL" "SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" -"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" -"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" +"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL" +"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" "SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" @@ -113,8 +113,8 @@ "SYCL0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","EXPM1","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL" "SYCL0","EXPM1","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL" -"SYCL0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL" -"SYCL0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL" +"SYCL0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" "SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" @@ -155,14 +155,14 @@ "SYCL0","EXP","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","EXPM1","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","EXPM1","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" -"SYCL0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" -"SYCL0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" +"SYCL0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL" +"SYCL0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL" "SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" -"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" -"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" +"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL" +"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" "SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" @@ -878,6 +878,54 @@ "SYCL0","POOL_2D","pool_type=max,type_input=f32,ne_input=[10,10,3,1],k0=3,k1=3,s0=2,s1=2,p0=0,p1=1","support","1","yes","SYCL" "SYCL0","POOL_2D","pool_type=max,type_input=f32,ne_input=[10,10,3,1],k0=3,k1=3,s0=2,s1=2,p0=1,p1=0","support","1","yes","SYCL" "SYCL0","POOL_2D","pool_type=max,type_input=f32,ne_input=[10,10,3,1],k0=3,k1=3,s0=2,s1=2,p0=1,p1=1","support","1","yes","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[10,3,2,1],k0=1,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[11,1,3,2],k0=1,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[128,2,1,3],k0=1,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[10,3,2,1],k0=1,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[11,1,3,2],k0=1,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[128,2,1,3],k0=1,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[10,3,2,1],k0=1,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[11,1,3,2],k0=1,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[128,2,1,3],k0=1,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[10,3,2,1],k0=1,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[11,1,3,2],k0=1,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[128,2,1,3],k0=1,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[10,3,2,1],k0=3,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[11,1,3,2],k0=3,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[128,2,1,3],k0=3,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[10,3,2,1],k0=3,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[11,1,3,2],k0=3,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[128,2,1,3],k0=3,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[10,3,2,1],k0=3,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[11,1,3,2],k0=3,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[128,2,1,3],k0=3,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[10,3,2,1],k0=3,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[11,1,3,2],k0=3,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=avg,type_input=f32,ne_input=[128,2,1,3],k0=3,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[10,3,2,1],k0=1,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[11,1,3,2],k0=1,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[128,2,1,3],k0=1,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[10,3,2,1],k0=1,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[11,1,3,2],k0=1,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[128,2,1,3],k0=1,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[10,3,2,1],k0=1,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[11,1,3,2],k0=1,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[128,2,1,3],k0=1,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[10,3,2,1],k0=1,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[11,1,3,2],k0=1,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[128,2,1,3],k0=1,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[10,3,2,1],k0=3,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[11,1,3,2],k0=3,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[128,2,1,3],k0=3,s0=1,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[10,3,2,1],k0=3,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[11,1,3,2],k0=3,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[128,2,1,3],k0=3,s0=1,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[10,3,2,1],k0=3,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[11,1,3,2],k0=3,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[128,2,1,3],k0=3,s0=2,p0=0","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[10,3,2,1],k0=3,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[11,1,3,2],k0=3,s0=2,p0=1","support","0","no","SYCL" +"SYCL0","POOL_1D","pool_type=max,type_input=f32,ne_input=[128,2,1,3],k0=3,s0=2,p0=1","support","0","no","SYCL" "SYCL0","IM2COL","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[3000,128,1,1],ne_kernel=[3,128,1280,1],s0=1,s1=0,p0=1,p1=0,d0=1,d1=0,is_2D=0","support","1","yes","SYCL" "SYCL0","IM2COL","type_input=f32,type_kernel=f16,dst_type=f32,ne_input=[3000,128,1,1],ne_kernel=[3,128,1280,1],s0=1,s1=0,p0=1,p1=0,d0=1,d1=0,is_2D=0","support","1","yes","SYCL" "SYCL0","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[3000,128,1,1],ne_kernel=[3,128,1280,1],s0=1,s1=0,p0=1,p1=0,d0=1,d1=0,is_2D=0","support","1","yes","SYCL" @@ -965,6 +1013,7 @@ "SYCL0","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","SYCL" "SYCL0","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","SYCL" "SYCL0","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[5,5,1,32],ne_kernel=[3,4,1,32],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","1","yes","SYCL" +"SYCL0","IM2COL","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[2,2,1536,729],ne_kernel=[2,2,1536,4096],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","1","yes","SYCL" "SYCL0","IM2COL_3D","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","SYCL" "SYCL0","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","SYCL" "SYCL0","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","SYCL" @@ -5696,35 +5745,58 @@ "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL" "SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" +"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000","support","1","yes","SYCL" +"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","SYCL" +"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000","support","1","yes","SYCL" +"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000000","support","1","yes","SYCL" +"SYCL0","L2_NORM","type=f32,ne=[1025,5,4,3]","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL" "SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" +"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL" +"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL" +"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL" +"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000001","support","1","yes","SYCL" +"SYCL0","L2_NORM","type=f32,ne=[1025,5,4,3]","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL" "SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" +"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL" +"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL" +"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL" +"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000100","support","1","yes","SYCL" +"SYCL0","L2_NORM","type=f32,ne=[1025,5,4,3]","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL" "SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" +"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL" +"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL" +"SYCL0","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL" +"SYCL0","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.100000","support","1","yes","SYCL" +"SYCL0","L2_NORM","type=f32,ne=[1025,5,4,3]","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL" -"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" -"SYCL0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","SYCL" -"SYCL0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","SYCL" -"SYCL0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","SYCL" -"SYCL0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","SYCL" -"SYCL0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","SYCL" -"SYCL0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","SYCL" -"SYCL0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","SYCL" -"SYCL0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","SYCL" -"SYCL0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[3,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[3,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[6,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[3,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[3,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[6,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[3,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","SYCL" "SYCL0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","SYCL" "SYCL0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","SYCL" "SYCL0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","SYCL" @@ -5734,6 +5806,15 @@ "SYCL0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","SYCL" "SYCL0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","SYCL" "SYCL0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[9,1024,1,1],ne_b=[9,1024,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[18,1024,1,1],ne_b=[9,1024,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[9,1024,4,1],ne_b=[9,1024,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[9,1536,1,1],ne_b=[9,1536,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[18,1536,1,1],ne_b=[9,1536,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[9,1536,4,1],ne_b=[9,1536,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[9,2048,1,1],ne_b=[9,2048,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[18,2048,1,1],ne_b=[9,2048,1,1]","support","1","yes","SYCL" +"SYCL0","SSM_CONV","type=f32,ne_a=[9,2048,4,1],ne_b=[9,2048,1,1]","support","1","yes","SYCL" "SYCL0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","0","no","SYCL" "SYCL0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","SYCL" "SYCL0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","SYCL" @@ -6593,6 +6674,30 @@ "SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=67,bs=[1,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=64,n=77,k=77,bs=[12,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=bf16,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","0","no","SYCL" +"SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q4_1,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q5_0,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q5_1,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q8_0,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=mxfp4,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q2_K,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q3_K,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q4_K,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q5_K,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=q6_K,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=iq2_xxs,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=iq2_xs,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=iq2_s,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=iq3_xxs,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=iq1_s,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=iq1_m,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=iq4_nl,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=iq3_s,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" +"SYCL0","MUL_MAT","type_a=iq4_xs,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=128,bs=[1,1],nr=[1,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=128,n=1,k=1056,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=2112,o=1","support","1","yes","SYCL" "SYCL0","MUL_MAT","type_a=bf16,type_b=f32,m=1056,n=1,k=128,bs=[1,1],nr=[1,1],per=[0,2,1,3],k_v=0,o=1","support","0","no","SYCL" @@ -8917,6 +9022,11 @@ "SYCL0","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=0,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=0.000000,inplace=0","support","1","yes","SYCL" "SYCL0","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","SYCL" "SYCL0","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","SYCL" +"SYCL0","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","SYCL" +"SYCL0","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","SYCL" +"SYCL0","SOFT_MAX","type=f32,ne=[200000,1,1,1],mask=0,sinks=0,m_prec=f32,nr23=[1,1],scale=1.000000,max_bias=0.000000,inplace=0","support","1","yes","SYCL" +"SYCL0","SOFT_MAX","type=f32,ne=[200000,4,1,1],mask=0,sinks=0,m_prec=f32,nr23=[1,1],scale=1.000000,max_bias=0.000000,inplace=0","support","1","yes","SYCL" +"SYCL0","SOFT_MAX","type=f32,ne=[643251,3,1,1],mask=0,sinks=0,m_prec=f32,nr23=[1,1],scale=1.000000,max_bias=0.000000,inplace=0","support","1","yes","SYCL" "SYCL0","SOFT_MAX_BACK","type=f32,ne=[16,16,1,1],scale=1.000000,max_bias=0.000000","support","1","yes","SYCL" "SYCL0","SOFT_MAX_BACK","type=f32,ne=[15,15,1,1],scale=1.000000,max_bias=0.000000","support","1","yes","SYCL" "SYCL0","SOFT_MAX_BACK","type=f32,ne=[16,16,2,3],scale=1.000000,max_bias=0.000000","support","1","yes","SYCL" @@ -8969,6 +9079,7 @@ "SYCL0","ROPE","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" @@ -8978,6 +9089,7 @@ "SYCL0","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" @@ -8988,11 +9100,13 @@ "SYCL0","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" @@ -9002,6 +9116,7 @@ "SYCL0","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" @@ -9012,11 +9127,13 @@ "SYCL0","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" @@ -9026,6 +9143,7 @@ "SYCL0","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" @@ -9036,11 +9154,13 @@ "SYCL0","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" @@ -9050,6 +9170,7 @@ "SYCL0","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" @@ -9060,6 +9181,7 @@ "SYCL0","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" +"SYCL0","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" "SYCL0","ROPE","type=f16,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","SYCL" @@ -9185,6 +9307,7 @@ "SYCL0","ROPE_BACK","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" @@ -9194,6 +9317,7 @@ "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" @@ -9204,11 +9328,13 @@ "SYCL0","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" @@ -9218,6 +9344,7 @@ "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" @@ -9228,11 +9355,13 @@ "SYCL0","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" @@ -9242,6 +9371,7 @@ "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" @@ -9252,11 +9382,13 @@ "SYCL0","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" @@ -9266,6 +9398,7 @@ "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" @@ -9276,6 +9409,7 @@ "SYCL0","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" +"SYCL0","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" "SYCL0","ROPE_BACK","type=f16,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","SYCL" @@ -9543,168 +9677,168 @@ "SYCL0","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","1","yes","SYCL" "SYCL0","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","1","yes","SYCL" "SYCL0","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","1","yes","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[12,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[13,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[13,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[15,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[15,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[15,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[19,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[19,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[19,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[19,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[12,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[13,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[13,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[15,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[15,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[15,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[19,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[19,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[19,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[19,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[27,1,2,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[43,1,2,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[64,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[75,1,2,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[128,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[139,1,2,1],k=100,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[256,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[267,1,2,1],k=100,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[512,1,1,1],k=500,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[523,1,2,1],k=500,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=500,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=500,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[1024,1,1,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[1035,1,2,1],k=1023,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=500,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=500,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[2048,1,1,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[2059,1,2,1],k=1023,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=500,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=500,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[4096,1,1,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[4107,1,2,1],k=1023,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=500,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=500,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[8192,1,1,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[8203,1,2,1],k=1023,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=500,ties=0","support","0","no","SYCL" @@ -9713,16 +9847,16 @@ "SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=9999,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[16395,1,2,1],k=9999,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=500,ties=0","support","0","no","SYCL" @@ -9731,16 +9865,16 @@ "SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[32768,1,1,1],k=9999,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[32779,1,2,1],k=9999,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=500,ties=0","support","0","no","SYCL" @@ -9749,16 +9883,16 @@ "SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[65536,1,1,1],k=9999,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[65547,1,2,1],k=9999,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=500,ties=0","support","0","no","SYCL" @@ -9767,16 +9901,16 @@ "SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[131072,1,1,1],k=9999,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[131083,1,2,1],k=9999,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=500,ties=0","support","0","no","SYCL" @@ -9785,16 +9919,16 @@ "SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[262144,1,1,1],k=9999,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[262155,1,2,1],k=9999,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=15,ties=0","support","1","yes","SYCL" "SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=100,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=500,ties=0","support","0","no","SYCL" @@ -9803,73 +9937,73 @@ "SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=1023,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[524288,1,1,1],k=9999,ties=0","support","0","no","SYCL" "SYCL0","TOP_K","type=f32,ne=[524299,1,2,1],k=9999,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=1,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=2,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=3,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=7,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=15,ties=0","support","0","no","SYCL" -"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=15,ties=0","support","0","no","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=1,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=2,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=3,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=7,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16,10,10,10],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[60,10,10,10],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1023,2,1,3],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1024,2,1,3],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[1025,2,1,3],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2047,2,1,3],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2048,2,1,3],k=15,ties=0","support","1","yes","SYCL" +"SYCL0","TOP_K","type=f32,ne=[2049,2,1,3],k=15,ties=0","support","1","yes","SYCL" "SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","SYCL" "SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=nearest,flags=none","support","1","yes","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest","support","1","yes","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=nearest","support","1","yes","SYCL" "SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear,transpose=0","support","0","no","SYCL" "SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear,transpose=1","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=none","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear,flags=none","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear","support","0","no","SYCL" "SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=0","support","0","no","SYCL" "SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=1","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic,flags=none","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic,flags=none","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=0","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=1","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=none","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear,flags=none","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=align_corners","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear,flags=align_corners","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear,flags=align_corners","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic,flags=align_corners","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bicubic,flags=align_corners","support","0","no","SYCL" -"SYCL0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bicubic,flags=align_corners","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear|antialias,transpose=0","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear|antialias,transpose=1","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear|antialias","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear|antialias","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear|align_corners","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear|align_corners","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear|align_corners","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic|align_corners","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bicubic|align_corners","support","0","no","SYCL" +"SYCL0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bicubic|align_corners","support","0","no","SYCL" "SYCL0","SUM","type=f32,ne=[10,5,4,3]","support","1","yes","SYCL" "SYCL0","SUM_ROWS","type=f32,ne=[10,5,4,3],permute=0,slice=0","support","1","yes","SYCL" "SYCL0","SUM","type=f32,ne=[11,5,6,3],permute=[0,2,1,3]","support","0","no","SYCL" @@ -9892,8 +10026,9 @@ "SYCL0","GROUP_NORM","type=f32,ne=[64,64,320,1],num_groups=32,eps=0.000001","support","1","yes","SYCL" "SYCL0","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","SYCL" "SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","SYCL" -"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","SYCL" -"SYCL0","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","1","yes","SYCL" +"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","SYCL" +"SYCL0","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","0","no","SYCL" +"SYCL0","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0,circular=0","support","1","yes","SYCL" "SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","1","yes","SYCL" "SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","1","yes","SYCL" "SYCL0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","SYCL" @@ -9915,28 +10050,51 @@ "SYCL0","CUMSUM","type=f32,ne=[2048,5,4,3]","support","0","no","SYCL" "SYCL0","CUMSUM","type=f32,ne=[242004,1,1,1]","support","0","no","SYCL" "SYCL0","CUMSUM","type=f32,ne=[375960,1,1,1]","support","0","no","SYCL" +"SYCL0","CUMSUM","type=f32,ne=[20481,4,1,1]","support","0","no","SYCL" "SYCL0","XIELU","type=f32,ne=[10,5,4,3]","support","0","no","SYCL" -"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","SYCL" -"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","SYCL" -"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","SYCL" -"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","SYCL" +"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","SYCL" +"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","SYCL" +"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","SYCL" +"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","SYCL" "SYCL0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","SYCL" "SYCL0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","SYCL" "SYCL0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","SYCL" "SYCL0","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","0","no","SYCL" +"SYCL0","DIAG","type=f32,ne=[10,1,4,3]","support","0","no","SYCL" +"SYCL0","DIAG","type=f32,ne=[79,1,19,13]","support","0","no","SYCL" +"SYCL0","DIAG","type=f32,ne=[256,1,8,16]","support","0","no","SYCL" "SYCL0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","SYCL" "SYCL0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","SYCL" "SYCL0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","SYCL" "SYCL0","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","SYCL" "SYCL0","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","SYCL" "SYCL0","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[64,64,2,2]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[79,79,5,3],ne_rhs=[417,79,5,3]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[80,80,2,8]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[79,80,2,8]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[81,80,2,8]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[80,80,8,8]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[79,80,8,8]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[81,80,8,8]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[84,84,4,4],ne_rhs=[32,84,4,4]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[95,95,8,8],ne_rhs=[40,95,8,8]","support","0","no","SYCL" "SYCL0","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","SYCL" "SYCL0","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[31,128,4,4]","support","0","no","SYCL" -"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[300,64,4,4]","support","0","no","SYCL" -"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0","support","1","yes","SYCL" -"SYCL0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0","support","1","yes","SYCL" -"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1","support","0","no","SYCL" -"SYCL0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[32,128,4,4]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[128,128,3,4],ne_rhs=[32,128,3,4]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,1],ne_rhs=[32,128,4,1]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[200,64,4,4]","support","0","no","SYCL" +"SYCL0","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[384,64,4,4]","support","0","no","SYCL" +"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=0","support","1","yes","SYCL" +"SYCL0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=0","support","1","yes","SYCL" +"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=1","support","0","no","SYCL" +"SYCL0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=1","support","0","no","SYCL" +"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=0","support","0","no","SYCL" +"SYCL0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=0","support","0","no","SYCL" +"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=1","support","0","no","SYCL" +"SYCL0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=1","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -14097,86 +14255,86 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -14337,46 +14495,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -14537,46 +14695,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -14737,46 +14895,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=1,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -15017,86 +15175,86 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,2,1,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,2,1,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -15257,46 +15415,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -15457,46 +15615,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -15657,46 +15815,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=1,sinks=0,max_bias=8.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -15857,46 +16015,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -16057,46 +16215,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=1,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -16257,46 +16415,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" @@ -16457,46 +16615,46 @@ "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[4,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" -"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[16,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=1,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=3,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=32,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=f32,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q8_0,permute=[0,1,2,3]","support","0","no","SYCL" +"SYCL0","FLASH_ATTN_EXT","hsk=128,hsv=128,nh=4,nr23=[12,1],kv=512,nb=35,mask=0,sinks=0,max_bias=0.000000,logit_softcap=10.000000,prec=def,type_KV=q4_0,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=192,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=192,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","SYCL" "SYCL0","FLASH_ATTN_EXT","hsk=192,hsv=128,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","SYCL" diff --git a/docs/speculative.md b/docs/speculative.md new file mode 100644 index 00000000000..03afab5b41e --- /dev/null +++ b/docs/speculative.md @@ -0,0 +1,184 @@ +# Speculative Decoding + +llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model. + +[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct. + +## Implementations + +The `llama-server` application supports several implementations of speculative decoding. An implementation with draft model can be mixed with an implementation without draft model. + +### Draft Model (`draft`) + +A much smaller model (called the _draft model_) generates drafts. +A draft model is the most used approach in speculative decoding. + +### n-gram Cache (`ngram-cache`) + +An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences. +A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy. + +See: + +- #5479, #6828, #6848 + +### n-gram Map (`ngram-simple`, `ngram-map-*`) + +These implementations search the token history for patterns and use matching sequences as draft candidates. +They require no additional model but rely on patterns that have already appeared in the generated text. +An example to use this approach can be the rewriting of source code by a LLM. + +#### n-gram Map (`ngram-simple`) + +This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead. + +``` +llama-server [...] --spec-type ngram-simple --draft-max 64 +``` + +#### n-gram Map Key (`ngram-map-k`) + +This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`, default is 1) before generating drafts. + +The number of accepted tokens is stored for each used n-gram. + +**Example:** +``` +llama-server [...] --spec-type ngram-map-k --draft-max 64 +``` + +#### n-gram Map Key-4-Values (`ngram-map-k4v`) + +This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft. + +The number of accepted tokens is stored for each used n-gram. + +**Example:** Server options to be used if there are a lot of longer repetitions. +``` +llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2 --draft-max 64 +``` + +### n-gram Mod (`ngram-mod`) + +Add basic ngram hasher for speculative decoding: + +- For each ngram, compute a hash using LCG +- For each computed hash, store the next token +- During speculation, iteratively compute the rolling hash of the last n tokens and pick the next token from the storage + +Some characteristics: + +- Lightweight (~16 MB) +- Constant memory and complexity +- Can generate variable draft lengths (i.e. m is not fixed) + +Currently, a single hash pool is shared across all server slots, so different requests can benefit from each other. + +**Sample usage:** + +``` +# notes: +# - small `n` are not recommended +# - MoEs require long drafts +# - dense models: can reduce `--draft-min` and `--draft-max` + +llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64 +``` + +Applications: + +- Iterating over a block of text/code (e.g. in llama.vim) +- Reasoning models (when they have to repeat their thinking in the final answer) +- Summarization + +Example Video: + +- See #19164 + +### Differences between ngram-simple, ngram-map and ngram-mod + +- ngram-simple looks for a previous matching n-gram and inserts the following m-gram. +- ngram-map-k looks for a previous matching n-gram and inserts the following m-gram but uses an internal hash-map of n-grams in the current context window. +- ngram-mod uses a hash pool which is shared across all server slots. The hash pool is a map from n-gram hash to the next token (not the next m-gram as in ngram-map). + +## Command-Line Options + +If a draft model is combined with a draftless decoding the draftless decoding has higher precedence. + +``` +--draft, --draft-n, --draft-max N number of tokens to draft for speculative decoding (default: 16) + (env: LLAMA_ARG_DRAFT_MAX) +--draft-min, --draft-n-min N minimum number of draft tokens to use for speculative decoding + (default: 0) + (env: LLAMA_ARG_DRAFT_MIN) +[...] +--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod] + type of speculative decoding to use when no draft model is provided + (default: none) +--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length + of lookup n-gram (default: 12) +--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length + of draft m-gram (default: 48) +--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding + (default: 1) +--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1) +``` + +### `--spec-type TYPE` + +Specifies a type of speculative decoding without draft model. + +| Type | Description | +|------|-------------| +| `none` | No speculative decoding (default) | +| `ngram-cache` | Use n-gram cache lookup | +| `ngram-simple` | Use simple n-gram pattern matching | +| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys | +| `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) | +| `ngram-mod` | Use basic ngram hasher for speculative decoding with shared pool | + +**Example:** Server-instance used to refactor source code. +```bash +./llama-server [...] --spec-type ngram-simple +``` + +### `--spec-ngram-size-n N` + +Sets the size N of the lookup n-gram for n-gram map based speculative decoding. +The n-gram size N determines how many tokens in a row to look back when searching for matching patterns. + +### `--spec-ngram-size-m M` + +Sets the size M of the draft m-gram for n-gram map based speculative decoding. +The m-gram size determines how many tokens to draft when a match is found. +Larger values can provide more speedup but may reduce acceptance rate. + +### `--spec-ngram-check-rate R` + +This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token). + +### `--spec-ngram-min-hits H` + +This option defines how often a key has to appear in the token history to be used as a draft (default is 1). + +## Statistics +Each speculative decoding implementation prints statistics. + +``` +draft acceptance rate = 0.57576 ( 171 accepted / 297 generated) +statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73 +statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98 +``` + +``` +draft acceptance rate = 0.70312 ( 90 accepted / 128 generated) +statistics ngram_mod: #calls = 810, #gen drafts = 15, #acc drafts = 15, #gen tokens = 960, #acc tokens = 730, dur(b,g,a) = 0.149, 0.347, 0.005 ms +``` + +- `#calls`: number of calls of this implementations +- `#gen drafts`: number of drafts generated by this implementation +- `#acc drafts`: number of drafts accepted (partially) by the main model +- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens) +- `#acc tokens`: number of tokens accepted by the main model +- `dur(b,g,a): durations of begin (new prompt), generation and accumulation (process acceptance). + diff --git a/examples/deprecation-warning/README.md b/examples/deprecation-warning/README.md index 59918ec2bbf..9a1b263e8e5 100644 --- a/examples/deprecation-warning/README.md +++ b/examples/deprecation-warning/README.md @@ -1,7 +1,7 @@ # Migration notice for binary filenames > [!IMPORTANT] -[2024 Jun 12] Binaries have been renamed w/ a `llama-` prefix. `main` is now `llama-cli`, `server` is `llama-server`, etc (https://github.com/ggerganov/llama.cpp/pull/7809) +[2024 Jun 12] Binaries have been renamed w/ a `llama-` prefix. `main` is now `llama-cli`, `server` is `llama-server`, etc (https://github.com/ggml-org/llama.cpp/pull/7809) This migration was important, but it is a breaking change that may not always be immediately obvious to users. diff --git a/examples/deprecation-warning/deprecation-warning.cpp b/examples/deprecation-warning/deprecation-warning.cpp index c2958ea12d9..11f5147328a 100644 --- a/examples/deprecation-warning/deprecation-warning.cpp +++ b/examples/deprecation-warning/deprecation-warning.cpp @@ -28,7 +28,7 @@ int main(int argc, char** argv) { fprintf(stdout, "\n"); fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str()); fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str()); - fprintf(stdout, " See https://github.com/ggerganov/llama.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n"); + fprintf(stdout, " See https://github.com/ggml-org/llama.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n"); fprintf(stdout, "\n"); return EXIT_FAILURE; diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 886dd3d81ec..9fc90a3c987 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -402,7 +402,7 @@ def _visit_pattern(self, pattern, name): Transforms a regular expression pattern into a GBNF rule. Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions - Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Output: https://github.com/ggml-org/llama.cpp/blob/master/grammars/README.md Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index f54cfdd77f2..aa6efa62b3b 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -50,6 +50,12 @@ int main(int argc, char ** argv) { const int N = 5; // n-gram size const int G = 15; // max verification n-grams + // lookahead requires W + G + 1 sequences for parallel Jacobi decoding + params.n_parallel = W + G + 1; + + // unified KV cache is required for coupled sequences in batch splitting + params.kv_unified = true; + // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); @@ -115,7 +121,7 @@ int main(int argc, char ** argv) { // seq_id == 0 : the current input token // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations // seq_id [W + 1, W + G] : verification n-grams - llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); + llama_batch batch = llama_batch_init(llama_n_ctx(ctx), 0, W + G + 1); // target model sampling context struct common_sampler * smpl = common_sampler_init(model, params.sampling); diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index bb94a8fe06d..f7b6ea1b190 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -32,9 +32,9 @@ int main(int argc, char ** argv){ common_ngram_cache ngram_cache; common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true); - fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str()); + fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str()); - common_ngram_cache_save(ngram_cache, params.lookup_cache_static); + common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static); return 0; } diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index 135f6fcab95..ae28b2e6e86 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -46,18 +46,18 @@ int main(int argc, char ** argv){ { const int64_t t_start_draft_us = ggml_time_us(); - if (!params.lookup_cache_static.empty()) { + if (!params.speculative.lookup_cache_static.empty()) { try { - ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static); + ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static); } catch (std::ifstream::failure const &) { - LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); + LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str()); exit(1); } } - if (!params.lookup_cache_dynamic.empty()) { + if (!params.speculative.lookup_cache_dynamic.empty()) { try { - ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic); + ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic); } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 27f159940a4..c7552ddde14 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -51,18 +51,18 @@ int main(int argc, char ** argv){ const int64_t t_start_draft_us = ggml_time_us(); common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false); - if (!params.lookup_cache_static.empty()) { + if (!params.speculative.lookup_cache_static.empty()) { try { - ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static); + ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static); } catch (std::ifstream::failure const &) { - LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); + LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str()); exit(1); } } - if (!params.lookup_cache_dynamic.empty()) { + if (!params.speculative.lookup_cache_dynamic.empty()) { try { - ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic); + ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic); } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program } @@ -106,7 +106,7 @@ int main(int argc, char ** argv){ std::vector draft; - llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); + llama_batch batch_tgt = llama_batch_init(llama_n_ctx(ctx), 0, 1); const auto t_dec_start = ggml_time_us(); @@ -210,7 +210,7 @@ int main(int argc, char ** argv){ // Update dynamic ngram cache with context ngram cache and save it to disk: common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context); - common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic); + common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic); LOG("\n\n"); diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile index 3b0505911d3..342de63bd00 100644 --- a/examples/model-conversion/Makefile +++ b/examples/model-conversion/Makefile @@ -33,11 +33,14 @@ DEVICE ?= auto causal-convert-model-bf16: OUTTYPE=bf16 causal-convert-model-bf16: causal-convert-model +causal-convert-model-debug: DEBUG=--debug +causal-convert-model-debug: causal-convert-model + causal-convert-model: $(call validate_model_path,causal-convert-model) @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \ METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ - ./scripts/causal/convert-model.sh + ./scripts/causal/convert-model.sh $(DEBUG) causal-convert-mm-model-bf16: OUTTYPE=bf16 causal-convert-mm-model-bf16: MM_OUTTYPE=f16 diff --git a/examples/model-conversion/scripts/causal/convert-model.sh b/examples/model-conversion/scripts/causal/convert-model.sh index 32ffe132e78..a5865f6acd3 100755 --- a/examples/model-conversion/scripts/causal/convert-model.sh +++ b/examples/model-conversion/scripts/causal/convert-model.sh @@ -4,12 +4,17 @@ set -e # Parse command line arguments MMPROJ="" +DEBUG="" while [[ $# -gt 0 ]]; do case $1 in --mmproj) MMPROJ="--mmproj" shift ;; + --debug) + DEBUG="1" + shift + ;; *) shift ;; @@ -28,7 +33,12 @@ echo "Data type: ${TYPE}" echo "Converted model path:: ${CONVERTED_MODEL}" echo "Metadata override: ${METADATA_OVERRIDE}" -CMD_ARGS=("python" "../../convert_hf_to_gguf.py" "--verbose") +if [[ -n "$DEBUG" ]]; then + CMD_ARGS=("python" "-m" "pdb") +else + CMD_ARGS=("python") +fi +CMD_ARGS+=("../../convert_hf_to_gguf.py" "--verbose") CMD_ARGS+=("${MODEL_PATH}") CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}") CMD_ARGS+=("--outtype" "${TYPE}") diff --git a/examples/model-conversion/scripts/utils/perplexity-gen.sh b/examples/model-conversion/scripts/utils/perplexity-gen.sh index 4885acbae24..ef4b650fdab 100755 --- a/examples/model-conversion/scripts/utils/perplexity-gen.sh +++ b/examples/model-conversion/scripts/utils/perplexity-gen.sh @@ -3,6 +3,7 @@ set -e CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +BUILD_DIR="${2:-"$BUILD_DIR"}" # Final check if we have a model path if [ -z "$CONVERTED_MODEL" ]; then @@ -25,9 +26,13 @@ mkdir -p ppl OUTPUTFILE="ppl/$(basename $CONVERTED_MODEL).kld" echo "Model: $CONVERTED_MODEL" -cmake --build ../../build --target llama-perplexity -j8 +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + +cmake --build $BUILD_DIR --target llama-perplexity -j8 -../.././build/bin/llama-perplexity -m $CONVERTED_MODEL \ +${BUILD_DIR}/bin/llama-perplexity -m $CONVERTED_MODEL \ -f ppl/wikitext-2-raw/wiki.test.raw \ --kl-divergence-base $OUTPUTFILE diff --git a/examples/model-conversion/scripts/utils/perplexity-run-simple.sh b/examples/model-conversion/scripts/utils/perplexity-run-simple.sh index a2545436a5c..20ee9653a9e 100755 --- a/examples/model-conversion/scripts/utils/perplexity-run-simple.sh +++ b/examples/model-conversion/scripts/utils/perplexity-run-simple.sh @@ -3,6 +3,7 @@ set -e QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}" +BUILD_DIR="${2:-"$BUILD_DIR"}" if [ -z "$QUANTIZED_MODEL" ]; then echo "Error: Model path must be provided either as:" >&2 @@ -20,8 +21,12 @@ if [ ! -d "ppl/wikitext-2-raw" ]; then popd fi -cmake --build ../../build --target llama-perplexity -j8 +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + +cmake --build $BUILD_DIR --target llama-perplexity -j8 -../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw +${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw diff --git a/examples/model-conversion/scripts/utils/perplexity-run.sh b/examples/model-conversion/scripts/utils/perplexity-run.sh index 68b38e66285..c11f32c65f9 100755 --- a/examples/model-conversion/scripts/utils/perplexity-run.sh +++ b/examples/model-conversion/scripts/utils/perplexity-run.sh @@ -3,7 +3,8 @@ set -e QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}" -LOGITS_FILE="${1:-"$LOGITS_FILE"}" +LOGITS_FILE="${2:-"$LOGITS_FILE"}" +BUILD_DIR="${3:-"$BUILD_DIR"}" if [ -z "$QUANTIZED_MODEL" ]; then echo "Error: Model path must be provided either as:" >&2 @@ -18,11 +19,15 @@ if [ ! -f ${LOGITS_FILE} ]; then exit 1 fi +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + echo "Model: $QUANTIZED_MODEL" echo "Data file: $LOGITS_FILE" -cmake --build ../../build --target llama-perplexity -j8 +cmake --build $BUILD_DIR --target llama-perplexity -j8 -../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL \ +${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL \ --kl-divergence-base $LOGITS_FILE \ --kl-divergence diff --git a/examples/model-conversion/scripts/utils/quantize.sh b/examples/model-conversion/scripts/utils/quantize.sh index c25c5c21f3c..4c21a1345a6 100755 --- a/examples/model-conversion/scripts/utils/quantize.sh +++ b/examples/model-conversion/scripts/utils/quantize.sh @@ -6,6 +6,7 @@ CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}" TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}" OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}" +BUILD_DIR="${5:-"$BUILD_DIR"}" QUANTIZED_MODEL=$CONVERTED_MODEL # Final check if we have a model path @@ -33,12 +34,16 @@ else exit 1 fi -cmake --build ../../build --target llama-quantize -j8 +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + +cmake --build $BUILD_DIR --target llama-quantize -j8 echo $TOKEN_EMBD_TYPE echo $OUTPUT_TYPE -CMD_ARGS=("../../build/bin/llama-quantize") +CMD_ARGS=("${BUILD_DIR}/bin/llama-quantize") [[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE") [[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE") CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE") diff --git a/examples/model-conversion/scripts/utils/run-embedding-server.sh b/examples/model-conversion/scripts/utils/run-embedding-server.sh index d30b765964b..9f5fc2cf70f 100755 --- a/examples/model-conversion/scripts/utils/run-embedding-server.sh +++ b/examples/model-conversion/scripts/utils/run-embedding-server.sh @@ -4,6 +4,7 @@ set -e # # First try command line argument, then environment variable, then file CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +BUILD_DIR="${2:-"$BUILD_DIR"}" # Final check if we have a model path if [ -z "$CONVERTED_MODEL" ]; then @@ -13,10 +14,14 @@ if [ -z "$CONVERTED_MODEL" ]; then exit 1 fi +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + echo $CONVERTED_MODEL -cmake --build ../../build --target llama-server +cmake --build $BUILD_DIR --target llama-server -../../build/bin/llama-server -m $CONVERTED_MODEL \ +${BUILD_DIR}/bin/llama-server -m $CONVERTED_MODEL \ --embedding \ --pooling none diff --git a/examples/model-conversion/scripts/utils/tensor-info.py b/examples/model-conversion/scripts/utils/tensor-info.py new file mode 100755 index 00000000000..12a3430b495 --- /dev/null +++ b/examples/model-conversion/scripts/utils/tensor-info.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +import argparse +import json +import os +import re +import sys +from pathlib import Path +from typing import Optional +from safetensors import safe_open + + +MODEL_SAFETENSORS_FILE = "model.safetensors" +MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json" + + +def get_weight_map(model_path: Path) -> Optional[dict[str, str]]: + index_file = model_path / MODEL_SAFETENSORS_INDEX + + if index_file.exists(): + with open(index_file, 'r') as f: + index = json.load(f) + return index.get("weight_map", {}) + + return None + + +def get_all_tensor_names(model_path: Path) -> list[str]: + weight_map = get_weight_map(model_path) + + if weight_map is not None: + return list(weight_map.keys()) + + single_file = model_path / MODEL_SAFETENSORS_FILE + if single_file.exists(): + try: + with safe_open(single_file, framework="pt", device="cpu") as f: + return list(f.keys()) + except Exception as e: + print(f"Error reading {single_file}: {e}") + sys.exit(1) + + print(f"Error: No safetensors files found in {model_path}") + sys.exit(1) + + +def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]: + weight_map = get_weight_map(model_path) + + if weight_map is not None: + return weight_map.get(tensor_name) + + single_file = model_path / MODEL_SAFETENSORS_FILE + if single_file.exists(): + return single_file.name + + return None + + +def normalize_tensor_name(tensor_name: str) -> str: + normalized = re.sub(r'\.\d+\.', '.#.', tensor_name) + normalized = re.sub(r'\.\d+$', '.#', normalized) + return normalized + + +def list_all_tensors(model_path: Path, unique: bool = False): + tensor_names = get_all_tensor_names(model_path) + + if unique: + seen = set() + for tensor_name in sorted(tensor_names): + normalized = normalize_tensor_name(tensor_name) + if normalized not in seen: + seen.add(normalized) + print(normalized) + else: + for tensor_name in sorted(tensor_names): + print(tensor_name) + + +def print_tensor_info(model_path: Path, tensor_name: str): + tensor_file = find_tensor_file(model_path, tensor_name) + + if tensor_file is None: + print(f"Error: Could not find tensor '{tensor_name}' in model index") + print(f"Model path: {model_path}") + sys.exit(1) + + file_path = model_path / tensor_file + + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + if tensor_name in f.keys(): + tensor_slice = f.get_slice(tensor_name) + shape = tensor_slice.get_shape() + print(f"Tensor: {tensor_name}") + print(f"File: {tensor_file}") + print(f"Shape: {shape}") + else: + print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}") + sys.exit(1) + + except FileNotFoundError: + print(f"Error: The file '{file_path}' was not found.") + sys.exit(1) + except Exception as e: + print(f"An error occurred: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description="Print tensor information from a safetensors model" + ) + parser.add_argument( + "tensor_name", + nargs="?", # optional (if --list is used for example) + help="Name of the tensor to inspect" + ) + parser.add_argument( + "-m", "--model-path", + type=Path, + help="Path to the model directory (default: MODEL_PATH environment variable)" + ) + parser.add_argument( + "-l", "--list", + action="store_true", + help="List unique tensor patterns in the model (layer numbers replaced with #)" + ) + + args = parser.parse_args() + + model_path = args.model_path + if model_path is None: + model_path_str = os.environ.get("MODEL_PATH") + if model_path_str is None: + print("Error: --model-path not provided and MODEL_PATH environment variable not set") + sys.exit(1) + model_path = Path(model_path_str) + + if not model_path.exists(): + print(f"Error: Model path does not exist: {model_path}") + sys.exit(1) + + if not model_path.is_dir(): + print(f"Error: Model path is not a directory: {model_path}") + sys.exit(1) + + if args.list: + list_all_tensors(model_path, unique=True) + else: + if args.tensor_name is None: + print("Error: tensor_name is required when not using --list") + sys.exit(1) + print_tensor_info(model_path, args.tensor_name) + + +if __name__ == "__main__": + main() diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 8141052a227..d8b1f5a480c 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -24,7 +24,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.speculative.model.path.empty()) { + if (params.speculative.mparams_dft.path.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } @@ -34,10 +34,8 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); llama_model * model_tgt = NULL; - //llama_model * model_dft = NULL; llama_context * ctx_tgt = NULL; - llama_context * ctx_dft = NULL; // load the target model auto llama_init_tgt = common_init_from_params(params); @@ -48,26 +46,38 @@ int main(int argc, char ** argv) { const llama_vocab * vocab = llama_model_get_vocab(model_tgt); // load the draft model - params.devices = params.speculative.devices; - params.model = params.speculative.model; - params.n_ctx = params.speculative.n_ctx; - params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch; - params.n_gpu_layers = params.speculative.n_gpu_layers; - - if (params.speculative.cpuparams.n_threads > 0) { - params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; - } + llama_model_ptr model_dft; + + // TODO: simplify this logic + { + const auto & params_spec = params.speculative; - params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; - params.tensor_buft_overrides = params.speculative.tensor_buft_overrides; + auto params_dft = params; - auto llama_init_dft = common_init_from_params(params); + params_dft.n_parallel = 1; + params_dft.n_ctx = params_spec.n_ctx; + params_dft.n_batch = llama_n_ctx_seq(ctx_tgt); + params_dft.devices = params_spec.devices; + params_dft.model = params_spec.mparams_dft; + params_dft.n_gpu_layers = params_spec.n_gpu_layers; + + if (params_spec.cpuparams.n_threads > 0) { + params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; + } - //model_dft = llama_init_dft->model(); - ctx_dft = llama_init_dft->context(); + params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides; + + auto mparams_dft = common_model_params_to_llama(params_dft); + + model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); + if (model_dft == nullptr) { + LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); + return 1; + } - if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { - LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str()); + params.speculative.model_dft = model_dft.get(); + params.speculative.cparams_dft = common_context_params_to_llama(params_dft); } // Tokenize the prompt @@ -92,12 +102,6 @@ int main(int argc, char ** argv) { LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); } - // how many tokens to draft each time - int n_draft = params.speculative.n_max; - int n_draft_min = params.speculative.n_min; - - float p_min = params.speculative.p_min; - int n_predict = 0; int n_drafted = 0; int n_accept = 0; @@ -127,15 +131,11 @@ int main(int argc, char ** argv) { int n_past = inp.size() - 1; // init the speculator - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft; - params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft; - params_spec.p_min = p_min; - - struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft); - for (auto &pair : params.speculative.replacements) { - common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str()); - } + const auto & params_spec = params.speculative; + + struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt); + + common_speculative_begin(spec, prompt_tgt); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); @@ -151,7 +151,7 @@ int main(int argc, char ** argv) { // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens // from a cache or lookup tables. // - llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last); + llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); @@ -162,7 +162,7 @@ int main(int argc, char ** argv) { // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { // do not waste time on small drafts - if (draft.size() < (size_t) n_draft_min) { + if (draft.size() < (size_t) params_spec.n_min) { draft.clear(); } @@ -240,7 +240,7 @@ int main(int argc, char ** argv) { LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); LOG_INF("\n"); - LOG_INF("n_draft = %d\n", n_draft); + LOG_INF("n_draft = %d\n", params_spec.n_max); LOG_INF("n_predict = %d\n", n_predict); LOG_INF("n_drafted = %d\n", n_drafted); LOG_INF("n_accept = %d\n", n_accept); @@ -249,8 +249,6 @@ int main(int argc, char ** argv) { LOG_INF("\n"); LOG_INF("draft:\n\n"); - llama_perf_context_print(ctx_dft); - LOG_INF("\n"); LOG_INF("target:\n\n"); common_perf_print(ctx_tgt, smpl); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 89d3249431e..3e5cf5f46b5 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -46,7 +46,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.speculative.model.path.empty()) { + if (params.speculative.mparams_dft.path.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } @@ -78,7 +78,7 @@ int main(int argc, char ** argv) { // load the draft model params.devices = params.speculative.devices; - params.model = params.speculative.model; + params.model = params.speculative.mparams_dft; params.n_gpu_layers = params.speculative.n_gpu_layers; if (params.speculative.cpuparams.n_threads > 0) { params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index cf23619ee04..d33f82f339b 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -18,13 +18,14 @@ CONTEXT=4096 #support malloc device memory more than 4GB. export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 +LOAD_MODE='--mmap' if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "use $GGML_SYCL_DEVICE as main GPU" #use signle GPU only - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ${LOAD_MODE} else #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} ${LOAD_MODE} fi diff --git a/examples/sycl/run-llama3.sh b/examples/sycl/run-llama3.sh deleted file mode 100755 index feee5165e92..00000000000 --- a/examples/sycl/run-llama3.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env bash - -# MIT license -# Copyright (C) 2025 Intel Corporation -# SPDX-License-Identifier: MIT - -# If you want more control, DPC++ Allows selecting a specific device through the -# following environment variable -export ONEAPI_DEVICE_SELECTOR="level_zero:0" -source /opt/intel/oneapi/setvars.sh - -#export GGML_SYCL_DEBUG=1 - -#ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer. - -INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:" -MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using. -CONTEXT=4096 - -#support malloc device memory more than 4GB. -export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 - -if [ $# -gt 0 ]; then - GGML_SYCL_DEVICE=$1 - echo "Using $GGML_SYCL_DEVICE as the main GPU" - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none -else - #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -fi diff --git a/examples/sycl/test.sh b/examples/sycl/test.sh new file mode 100755 index 00000000000..140c191466e --- /dev/null +++ b/examples/sycl/test.sh @@ -0,0 +1,130 @@ +#!/bin/bash + +# MIT license +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT + +Help() { + cat << EOF +Usage: $(basename "$0") [OPTIONS] + +This script processes files with specified options. + +Options: + -h, --help Display this help message and exit. + -c, --context Set context length. Bigger need more memory. + -p, --promote Prompt to start generation with. + -m, --model Full model file path. + -mg,--main-gpu Set main GPU ID (0 - n) for single GPU mode. + -sm,--split-mode How to split the model across multiple GPUs, one of: + - none: use one GPU only + - layer (default): split layers and KV across GPUs + - row: split rows across GPUs + -ngl,--n-gpu-layers Max. number of layers to store in VRAM (default: -1) + -lv,--log-verbosity Set the verbosity threshold. Messages with a higher verbosity will be + ignored. Values: + - 0: generic output + - 1: error + - 2: warning + - 3: info + - 4: debug + + +EOF +} + +BIN_FILE=./build/bin/llama-completion +SEED=0 +GPUS_SETTING="" + +INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:" +MODEL_FILE=models/llama-2-7b.Q4_0.gguf +NGL=99 +CONTEXT=4096 +GGML_SYCL_DEVICE=-1 +SPLIT_MODE=layer +LOG_VERBOSE=3 +while [[ $# -gt 0 ]]; do + case "$1" in + -c|--context) + CONTEXT=$2 + # Shift twice to consume both the option flag and its value + shift + shift + ;; + -p|--promote) + # Option that is a simple flag (boolean) + INPUT_PROMPT="$2" + # Shift once to consume the option flag + shift + shift + ;; + -m|--model) + MODEL_FILE="$2" + # Shift twice to consume both the option flag and its value + shift + shift + ;; + -mg|--main-gpu) + GGML_SYCL_DEVICE=$2 + SPLIT_MODE=none + # Shift twice to consume both the option flag and its value + shift + shift + ;; + -sm|--split-mode) + SPLIT_MODE=$2 + # Shift twice to consume both the option flag and its value + shift + shift + ;; + -ngl|--n-gpu-layers) + NGL=$2 + # Shift twice to consume both the option flag and its value + shift + shift + ;; + -lv|--log-verbosity) + LOG_VERBOSE=$2 + # Shift twice to consume both the option flag and its value + shift + shift + ;; + -h|--help) + Help + exit 0 + ;; + *) + # Handle unknown options or stop processing options + echo "Invalid option: $1" + # Optional: exit script or shift to treat remaining as positional args + exit 1 + ;; + esac +done + + + +source /opt/intel/oneapi/setvars.sh + +#export GGML_SYCL_DEBUG=1 + +#ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer. + +#support malloc device memory more than 4GB. +export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 +echo "UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=${UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS}" + +if [ $GGML_SYCL_DEVICE -ne -1 ]; then + echo "Use $GGML_SYCL_DEVICE as main GPU" + #use signle GPU only + GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}" + export ONEAPI_DEVICE_SELECTOR="level_zero:${$GGML_SYCL_DEVICE}" + echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}" +else + echo "Use all Intel GPUs, including iGPU & dGPU" + fi + +echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap " +ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap + diff --git a/examples/sycl/win-run-llama2.bat b/examples/sycl/win-run-llama2.bat index 32ff673ae26..1f2dab8d0a8 100644 --- a/examples/sycl/win-run-llama2.bat +++ b/examples/sycl/win-run-llama2.bat @@ -7,5 +7,5 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" :: support malloc device memory more than 4GB. set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 - -.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 +set LOAD_MODE="--mmap" +.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 %LOAD_MODE% diff --git a/examples/sycl/win-run-llama3.bat b/examples/sycl/win-test.bat similarity index 69% rename from examples/sycl/win-run-llama3.bat rename to examples/sycl/win-test.bat index ea4ae69d6c7..1f2dab8d0a8 100644 --- a/examples/sycl/win-run-llama3.bat +++ b/examples/sycl/win-test.bat @@ -7,5 +7,5 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" :: support malloc device memory more than 4GB. set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 - -.\build\bin\llama-completion.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -no-cnv -p %INPUT2% -n 400 -s 0 -e -ngl 99 +set LOAD_MODE="--mmap" +.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 %LOAD_MODE% diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0176ca1ce93..71d1a7f0e34 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. +cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. project("ggml" C CXX ASM) ### GGML Version @@ -228,6 +228,8 @@ option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU) option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON) option(GGML_ZDNN "ggml: use zDNN" OFF) +option(GGML_VIRTGPU "ggml: use the VirtGPU/Virglrenderer API Remoting frontend" OFF) +option(GGML_VIRTGPU_BACKEND "ggml: build the VirtGPU/Virglrenderer API Remoting backend" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF) @@ -320,6 +322,7 @@ set(GGML_PUBLIC_HEADERS include/ggml-opt.h include/ggml-metal.h include/ggml-rpc.h + include/ggml-virtgpu.h include/ggml-sycl.h include/ggml-vulkan.h include/ggml-webgpu.h diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h index b469e228d06..74af465337a 100644 --- a/ggml/include/ggml-cann.h +++ b/ggml/include/ggml-cann.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index 4f3b99c8d07..e3e067c916f 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -19,6 +19,9 @@ extern "C" { // abort ggml_graph_compute when true ggml_abort_callback abort_callback; void * abort_callback_data; + + // use only reference implementations + bool use_ref; }; // numa strategies @@ -132,6 +135,8 @@ extern "C" { GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); + GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref); + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); diff --git a/ggml/include/ggml-virtgpu.h b/ggml/include/ggml-virtgpu.h new file mode 100644 index 00000000000..faaba8f246d --- /dev/null +++ b/ggml/include/ggml-virtgpu.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg(); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1988d16dc42..f759e2d5883 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -6,7 +6,7 @@ // This documentation is still a work in progress. // If you wish some specific topics to be covered, feel free to drop a comment: // -// https://github.com/ggerganov/whisper.cpp/issues/40 +// https://github.com/ggml-org/whisper.cpp/issues/40 // // ## Overview // diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 6192a870466..265023733e7 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -222,6 +222,7 @@ if (GGML_SCHED_NO_REALLOC) endif() add_library(ggml + ggml-backend-dl.cpp ggml-backend-reg.cpp) add_library(ggml::ggml ALIAS ggml) @@ -451,6 +452,7 @@ ggml_add_backend(HIP) ggml_add_backend(METAL) ggml_add_backend(MUSA) ggml_add_backend(RPC) +ggml_add_backend(VirtGPU) ggml_add_backend(SYCL) ggml_add_backend(Vulkan) ggml_add_backend(WebGPU) diff --git a/ggml/src/ggml-backend-dl.cpp b/ggml/src/ggml-backend-dl.cpp new file mode 100644 index 00000000000..a65cf009055 --- /dev/null +++ b/ggml/src/ggml-backend-dl.cpp @@ -0,0 +1,48 @@ +#include "ggml-backend-dl.h" + +#ifdef _WIN32 + +dl_handle * dl_load_library(const fs::path & path) { + // suppress error dialogs for missing DLLs + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + HMODULE handle = LoadLibraryW(path.wstring().c_str()); + + SetErrorMode(old_mode); + + return handle; +} + +void * dl_get_sym(dl_handle * handle, const char * name) { + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + void * p = (void *) GetProcAddress(handle, name); + + SetErrorMode(old_mode); + + return p; +} + +const char * dl_error() { + return ""; +} + +#else + +dl_handle * dl_load_library(const fs::path & path) { + dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); + return handle; +} + +void * dl_get_sym(dl_handle * handle, const char * name) { + return dlsym(handle, name); +} + +const char * dl_error() { + const char *rslt = dlerror(); + return rslt != nullptr ? rslt : ""; +} + +#endif diff --git a/ggml/src/ggml-backend-dl.h b/ggml/src/ggml-backend-dl.h new file mode 100644 index 00000000000..f74b7c94894 --- /dev/null +++ b/ggml/src/ggml-backend-dl.h @@ -0,0 +1,45 @@ +#pragma once + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +#endif +#include + +namespace fs = std::filesystem; + +#ifdef _WIN32 + +using dl_handle = std::remove_pointer_t; + +struct dl_handle_deleter { + void operator()(HMODULE handle) { + FreeLibrary(handle); + } +}; + +#else + +using dl_handle = void; + +struct dl_handle_deleter { + void operator()(void * handle) { + dlclose(handle); + } +}; + +#endif + +using dl_handle_ptr = std::unique_ptr; + +dl_handle * dl_load_library(const fs::path & path); +void * dl_get_sym(dl_handle * handle, const char * name); +const char * dl_error(); + diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 6bee1bc4b49..8a693f84af5 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -1,5 +1,6 @@ #include "ggml-backend-impl.h" #include "ggml-backend.h" +#include "ggml-backend-dl.h" #include "ggml-impl.h" #include #include @@ -69,6 +70,10 @@ #include "ggml-rpc.h" #endif +#ifdef GGML_USE_VIRTGPU_FRONTEND +#include "ggml-virtgpu.h" +#endif + #ifdef GGML_USE_CANN #include "ggml-cann.h" #endif @@ -94,72 +99,6 @@ static std::string path_str(const fs::path & path) { } } -#ifdef _WIN32 - -using dl_handle = std::remove_pointer_t; - -struct dl_handle_deleter { - void operator()(HMODULE handle) { - FreeLibrary(handle); - } -}; - -static dl_handle * dl_load_library(const fs::path & path) { - // suppress error dialogs for missing DLLs - DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); - SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); - - HMODULE handle = LoadLibraryW(path.wstring().c_str()); - - SetErrorMode(old_mode); - - return handle; -} - -static void * dl_get_sym(dl_handle * handle, const char * name) { - DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); - SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); - - void * p = (void *) GetProcAddress(handle, name); - - SetErrorMode(old_mode); - - return p; -} - -static const char * dl_error() { - return ""; -} - -#else - -using dl_handle = void; - -struct dl_handle_deleter { - void operator()(void * handle) { - dlclose(handle); - } -}; - -static void * dl_load_library(const fs::path & path) { - dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); - - return handle; -} - -static void * dl_get_sym(dl_handle * handle, const char * name) { - return dlsym(handle, name); -} - -static const char * dl_error() { - const char *rslt = dlerror(); - return rslt != nullptr ? rslt : ""; -} - -#endif - -using dl_handle_ptr = std::unique_ptr; - struct ggml_backend_reg_entry { ggml_backend_reg_t reg; dl_handle_ptr handle; @@ -180,7 +119,12 @@ struct ggml_backend_registry { register_backend(ggml_backend_sycl_reg()); #endif #ifdef GGML_USE_VULKAN + // Add runtime disable check + if (getenv("GGML_DISABLE_VULKAN") == nullptr) { register_backend(ggml_backend_vk_reg()); + } else { + GGML_LOG_DEBUG("Vulkan backend disabled by GGML_DISABLE_VULKAN environment variable\n"); + } #endif #ifdef GGML_USE_WEBGPU register_backend(ggml_backend_webgpu_reg()); @@ -188,6 +132,10 @@ struct ggml_backend_registry { #ifdef GGML_USE_ZDNN register_backend(ggml_backend_zdnn_reg()); #endif +#ifdef GGML_USE_VIRTGPU_FRONTEND + register_backend(ggml_backend_virtgpu_reg()); +#endif + #ifdef GGML_USE_OPENCL register_backend(ggml_backend_opencl_reg()); #endif @@ -604,6 +552,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("rpc", silent, dir_path); ggml_backend_load_best("sycl", silent, dir_path); ggml_backend_load_best("vulkan", silent, dir_path); + ggml_backend_load_best("virtgpu", silent, dir_path); ggml_backend_load_best("opencl", silent, dir_path); ggml_backend_load_best("hexagon", silent, dir_path); ggml_backend_load_best("musa", silent, dir_path); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 354876574a0..22c656996cc 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -258,6 +258,7 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); if (backend->iface.set_tensor_async == NULL) { + ggml_backend_synchronize(backend); ggml_backend_tensor_set(tensor, data, offset, size); } else { backend->iface.set_tensor_async(backend, tensor, data, offset, size); @@ -271,6 +272,7 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); if (backend->iface.get_tensor_async == NULL) { + ggml_backend_synchronize(backend); ggml_backend_tensor_get(tensor, data, offset, size); } else { backend->iface.get_tensor_async(backend, tensor, data, offset, size); diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp index 7b7042a1f54..e95d3c4d88d 100644 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h index 7deac383420..4737773a4d4 100644 --- a/ggml/src/ggml-cann/acl_tensor.h +++ b/ggml/src/ggml-cann/acl_tensor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 02867e4fdb5..87ac05748e8 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index b76e4707ac7..3effa1c289c 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index fb3e7572e2c..0120f0dfd1e 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 42c6c67a40b..6b2dbdd3591 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 3f8946ac701..427c1146e46 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -1,3 +1,4 @@ + #pragma once // Rename `_generic` functions if no native implementation is available. @@ -38,9 +39,11 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -48,9 +51,11 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -70,12 +75,16 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -94,9 +103,11 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -104,9 +115,11 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -126,9 +139,11 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -136,9 +151,11 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -165,18 +182,22 @@ #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -202,9 +223,11 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -212,9 +235,11 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -242,9 +267,11 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -252,9 +279,11 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index b61220a189a..99bb70274c5 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -25,9 +25,8 @@ #define UNUSED GGML_UNUSED #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD)) -static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in, - int16x8_t * out_mins, - int8_t * out_scales) { +// Helper for decoding scales and mins of Q4_K and Q5_K block formats +static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) { constexpr uint32_t kmask1 = 0x3f3f3f3f; constexpr uint32_t kmask2 = 0x0f0f0f0f; constexpr uint32_t kmask3 = 0x03030303; @@ -561,7 +560,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -701,7 +700,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -786,6 +785,495 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q5_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[ncols_interleaved / 4]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < ncols_interleaved / 4; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d); + float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3 + float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d); + float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d); + + // 2 sb each iteration + int32x4_t acc_lo[col_pairs]; + int32x4_t acc_hi[col_pairs]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_pairs; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + int16x8_t q5sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } + + const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K; + + // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns + const int8_t * q8_base = q8_ptr[b].qs + sb * 64; + int8x16_t q8_qs[8]; + for (int i = 0; i < 8; i++) { + q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8)); + } + + // Q5s column pair loop unrolled + { + // Cols 01 + uint8x16_t qs_0 = vld1q_u8(qs_base); + uint8x16_t qs_1 = vld1q_u8(qs_base + 64); + uint8x16_t qs_2 = vld1q_u8(qs_base + 128); + uint8x16_t qs_3 = vld1q_u8(qs_base + 192); + + uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone); + uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone); + uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone); + uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3); + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3); + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3); + + qh[0][0] = vshrq_n_u8(qh[0][0], 2); + qh[0][1] = vshrq_n_u8(qh[0][1], 2); + qh[0][2] = vshrq_n_u8(qh[0][2], 2); + qh[0][3] = vshrq_n_u8(qh[0][3], 2); + + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 23 + qs_0 = vld1q_u8(qs_base + 16); + qs_1 = vld1q_u8(qs_base + 80); + qs_2 = vld1q_u8(qs_base + 144); + qs_3 = vld1q_u8(qs_base + 208); + + hbit_lo_0 = vandq_u8(qh[1][0], mone); + hbit_lo_1 = vandq_u8(qh[1][1], mone); + hbit_lo_2 = vandq_u8(qh[1][2], mone); + hbit_lo_3 = vandq_u8(qh[1][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3); + + qh[1][0] = vshrq_n_u8(qh[1][0], 2); + qh[1][1] = vshrq_n_u8(qh[1][1], 2); + qh[1][2] = vshrq_n_u8(qh[1][2], 2); + qh[1][3] = vshrq_n_u8(qh[1][3], 2); + + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 45 + qs_0 = vld1q_u8(qs_base + 32); + qs_1 = vld1q_u8(qs_base + 96); + qs_2 = vld1q_u8(qs_base + 160); + qs_3 = vld1q_u8(qs_base + 224); + + hbit_lo_0 = vandq_u8(qh[2][0], mone); + hbit_lo_1 = vandq_u8(qh[2][1], mone); + hbit_lo_2 = vandq_u8(qh[2][2], mone); + hbit_lo_3 = vandq_u8(qh[2][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3); + + qh[2][0] = vshrq_n_u8(qh[2][0], 2); + qh[2][1] = vshrq_n_u8(qh[2][1], 2); + qh[2][2] = vshrq_n_u8(qh[2][2], 2); + qh[2][3] = vshrq_n_u8(qh[2][3], 2); + + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 45 + qs_0 = vld1q_u8(qs_base + 48); + qs_1 = vld1q_u8(qs_base + 112); + qs_2 = vld1q_u8(qs_base + 176); + qs_3 = vld1q_u8(qs_base + 240); + + hbit_lo_0 = vandq_u8(qh[3][0], mone); + hbit_lo_1 = vandq_u8(qh[3][1], mone); + hbit_lo_2 = vandq_u8(qh[3][2], mone); + hbit_lo_3 = vandq_u8(qh[3][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3); + + qh[3][0] = vshrq_n_u8(qh[3][0], 2); + qh[3][1] = vshrq_n_u8(qh[3][1], 2); + qh[3][2] = vshrq_n_u8(qh[3][2], 2); + qh[3][3] = vshrq_n_u8(qh[3][3], 2); + + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + } + + // Prepare bsum vectors for bias computation + // Each pair of subblocks share the same bsums + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + // Iterates over a pair of column pairs (4 columns) to use a single 128 register + // p = 0 -> 0123 p2 -> 4567 + for (int i = 0, p = 0; p < col_pairs; i++, p += 2) { + int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]); + int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]); + int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]); + int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]); + float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1; + float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1; + + // 0123 or 4567 + float32x4_t sumf_0 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0); + + float32x4_t sumf_1 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1); + + // FUSED BIAS: Compute and subtract bias immediately + // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min + int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo); + bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi); + float32x4_t bias_f32 = vcvtq_f32_s32(bias); + acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32); + } + } // for sb + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q6_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[2]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + acc_f32[0] = vdupq_n_f32(0); + acc_f32[1] = vdupq_n_f32(0); + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d); + + int32x2_t acc[col_pairs]; + for (int i = 0; i < col_pairs; i++) { + acc[i] = vdup_n_s32(0); + } + + // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift + int32x4_t bias_lo = vdupq_n_s32(0); + int32x4_t bias_hi = vdupq_n_s32(0); + + // Load bsums in chunks of 4 to process with vectorized operations + for (int i = 0; i < 16; i += 4) { + int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i); + int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8); + int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4); + int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8); + int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4); + int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8); + int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4); + int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8); + int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4); + + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3); + } + bias_lo = vshlq_n_s32(bias_lo, 5); + bias_hi = vshlq_n_s32(bias_hi, 5); + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16; + const int8_t * q8_base_h = q8_base_l + 64; + + // Load and duplicate q8 values (each register covers two interleaved columns of q6) + int8x16_t q8_l[2]; + int8x16_t q8_h[2]; + for (int i = 0; i < 2; i++) { + q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8)); + q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8)); + } + + // TODO: Test other qh repack patterns to reduce loads + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes + + // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) + ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base); + ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64); + ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base); + ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64); + + // Adjust qh for subblocks 2 and 3 (shift right by 2) + if (sb > 1) { + q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2); + q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2); + q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2); + q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2); + q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2); + q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2); + q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2); + q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2); + } + + // Process column pairs (0-1, 2-3, 4-5, 6-7) + for (int cp = 0; cp < col_pairs; cp++) { + const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp]; + const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp]; + const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp]; + const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi); + const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi); + + // q6 = (low4 | high2<<4), without -32 bias (handled via bsums) + const int8x16_t q6_l0 = vreinterpretq_s8_u8( + vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)); + const int8x16_t q6_l1 = vreinterpretq_s8_u8( + vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)); + const int8x16_t q6_h0 = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)); + const int8x16_t q6_h1 = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)); + + int32x4_t sb_acc_l = vdupq_n_s32(0); + sb_acc_l = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]); + sb_acc_l = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]); + + int32x4_t sb_acc_h = vdupq_n_s32(0); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]); + + // Pairwise add to get per-column sums: [col0, col1] + int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l)); + int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h)); + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + // Access scales using array indexing (scales are interleaved by column) + const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2], + (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] }; + const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2], + (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] }; + + // Accumulate scaled results + acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l); + acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h); + } + } + } // for half + + // Bias correction + acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo)); + acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo)); + acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi)); + acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi)); + + // Apply superblock scale (no mins for q6_K) + // acc[cp] has [c0, c1] + float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0)); + float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0)); + float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1)); + float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1)); + + acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23)); + acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67)); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, @@ -2431,7 +2919,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -2595,7 +3083,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later for (int i = 0; i < 2; i++) { const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); } // q8_ptr[b].qs has interleaved Q8 rows (01, 23) @@ -2660,16 +3148,17 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, // Scales[i] corresponds to column i const int scale_offset = cp * 2; - for (int blk = 0; blk < 2; blk++) { - const int32x4_t block_scale = { - (int32_t) q4sb_scales[blk][scale_offset], - (int32_t) q4sb_scales[blk][scale_offset], - (int32_t) q4sb_scales[blk][scale_offset + 1], - (int32_t) q4sb_scales[blk][scale_offset + 1], - }; - acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale); - acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale); - } + const int32_t scale_00 = q4sb_scales[0][scale_offset]; + const int32_t scale_01 = q4sb_scales[0][scale_offset + 1]; + const int32_t scale_10 = q4sb_scales[1][scale_offset]; + const int32_t scale_11 = q4sb_scales[1][scale_offset + 1]; + const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01)); + const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11)); + + acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1); } // Multiply Acc bsum + mins @@ -2738,6 +3227,469 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q5_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[4][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7] + int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ... + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + bias_acc[i] = vdupq_n_s32(0); + } + + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int8_t q5sb_scales[2][8]; + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + for (int i = 0; i < 2; i++) { + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]); + } + + // q8_ptr[b].qs has interleaved Q8 rows (01, 23) + const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + + int8x16_t q8_qs_01[8]; + int8x16_t q8_qs_23[8]; + + // Load 32-byte per row pair, 1 subblock each time + for (int i = 0; i < 8; i++) { + const int offset = i * 32; // 16 for row 01, 16 for row 23 + q8_qs_01[i] = vld1q_s8(q8_base + offset); + q8_qs_23[i] = vld1q_s8(q8_base + offset + 16); + } + + const int8x16_t q8s[2][8] = { + { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], + q8_qs_01[7] }, + { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], + q8_qs_23[7] }, + }; + + // Q5s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < col_pairs; cp++) { + for (int i = 0; i < 4; i++) { + sb_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39 + uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47 + uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55 + uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 + + // This is the only part of the algorithm that differs with Q4_K + // Extract High bits and pack into 5 bit weights + uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3); + qh[cp][0] = vshrq_n_u8(qh[cp][0], 2); + // Same as Q4_K, i8mm to dequantize the weights. + const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4)); + int32x4_t acc_0 = sb_acc[0]; + acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]); + int32x4_t acc_2 = sb_acc[2]; + acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]); + const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0)); + int32x4_t acc_1 = sb_acc[1]; + acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]); + int32x4_t acc_3 = sb_acc[3]; + acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]); + + // Repeat for the other 3 columns (8..15, 16..23, 24..31) + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3); + uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone); + qh[cp][1] = vshrq_n_u8(qh[cp][1], 2); + const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]); + const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]); + + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3); + uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone); + qh[cp][2] = vshrq_n_u8(qh[cp][2], 2); + const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]); + const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]); + + uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3); + qh[cp][3] = vshrq_n_u8(qh[cp][3], 2); + const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]); + sb_acc[0] = acc_0; + acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]); + sb_acc[2] = acc_2; + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + const int32_t s0 = q5sb_scales[0][scale_offset]; + const int32_t s1 = q5sb_scales[0][scale_offset + 1]; + const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale); + + const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]); + sb_acc[1] = acc_1; + acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]); + sb_acc[3] = acc_3; + + const int32_t s2 = q5sb_scales[1][scale_offset]; + const int32_t s3 = q5sb_scales[1][scale_offset + 1]; + const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2); + } + + // Multiply Acc bsum + mins + for (int q8_row = 0; q8_row < 4; q8_row++) { + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]); + + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } + } // for sb + + // Reorder of i8mm output with bias and output layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4))); + const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d); + + float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q5_d, q8_d); + + acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins); + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q6_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + const int8x16_t m32s = vdupq_n_s8(32); + + // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7) + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7] + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + } + + // Q6_K has simple 8-bit scales, 16 per block (one per 16 values) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; ++i) { + int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, s16); + } + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + // Q6_K weight index increasing by 64 instead of 32 requires + // loading various q8 memory regions + const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64; + const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64; + + int8x16_t q8_l_01[2]; + int8x16_t q8_l_23[2]; + for (int i = 0; i < 2; i++) { + const int offset = i * 32; + q8_l_01[i] = vld1q_s8(q8_base_l + offset); // 0..7 & 8..15 (r01) + q8_l_23[i] = vld1q_s8(q8_base_l + offset + 16); // 0..7 & 8..15 (r23) + } + + int8x16_t q8_h_01[2]; + int8x16_t q8_h_23[2]; + for (int i = 0; i < 2; i++) { + const int offset = i * 32; + q8_h_01[i] = vld1q_s8(q8_base_h + offset); + q8_h_23[i] = vld1q_s8(q8_base_h + offset + 16); + } + + const int ql_off_base = sb * QK_K / 2; + + uint8x16_t q6_ql_0[4]; + uint8x16_t q6_ql_1[4]; + for (int k = 0; k < 4; k++) { + q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k); + q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k); + } + + const int qh_off_base = (sb * QK_K / 2) & 255; // wrap after 256 bytes + uint8x16_t q6_qh_0[4]; + uint8x16_t q6_qh_1[4]; + for (int k = 0; k < 4; k++) { + q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k); + q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k); + } + + // Adjust for the proper high bits (Sb 2 and 3) + if (sb > 1) { + for (int k = 0; k < 4; k++) { + q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2); + q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2); + } + } + + // Process column pairs (0-1, 2-3, 4-5, 6-7) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp]; + const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp]; + const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp]; + const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi); + const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi); + + // q6 = (low4 | high2<<4) - 32 + // Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K) + const int8x16_t q6_l0 = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)), + m32s); + const int8x16_t q6_l1 = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)), + m32s); + const int8x16_t q6_h0 = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s); + const int8x16_t q6_h1 = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s); + + // row pair 0, base_l + int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]); + sb_acc_0l = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]); + // row pair 0, base_h + int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]); + sb_acc_0h = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]); + // row pair 1, base_l + int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]); + sb_acc_1l = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]); + // row pair 1, base_h + int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]); + sb_acc_1h = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]); + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + const int32x4_t scale_vec_l = { + q6_scales[scale_idx_l * 8 + cp * 2 + 0], + q6_scales[scale_idx_l * 8 + cp * 2 + 0], + q6_scales[scale_idx_l * 8 + cp * 2 + 1], + q6_scales[scale_idx_l * 8 + cp * 2 + 1], + }; + const int32x4_t scale_vec_h = { + q6_scales[scale_idx_h * 8 + cp * 2 + 0], + q6_scales[scale_idx_h * 8 + cp * 2 + 0], + q6_scales[scale_idx_h * 8 + cp * 2 + 1], + q6_scales[scale_idx_h * 8 + cp * 2 + 1], + }; + + acc[cp] = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l); + acc[cp] = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h); + } + } + } // for half + + // Reorder i8mm output to match memory layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + // Apply superblock scale (no mins for q6_K) + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q6_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q6_d, q8_d); + + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // Store results + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index cb49320a67f..74d699f633d 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -268,9 +268,9 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0))); } -static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) { - return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), - _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); +static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) { + return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), + _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); } #endif #elif defined(__SSSE3__) @@ -782,6 +782,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo __m256 accum1 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); @@ -795,10 +796,10 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); - accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)), - _mm256_cvtepi32_ps(p_1), accum1); - accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)), - _mm256_cvtepi32_ps(p_2), accum2); + const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e)); + const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e)); + accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2); } sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); @@ -830,7 +831,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif for (; ib < nb; ++ib) { - const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e); int sumi1 = 0; int sumi2 = 0; for (int j = 0; j < QK_MXFP4/2; ++j) { @@ -3817,4 +3818,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h index 6adca5437f8..1057b5bb152 100644 --- a/ggml/src/ggml-cpu/common.h +++ b/ggml/src/ggml-cpu/common.h @@ -6,6 +6,9 @@ #include "ggml-impl.h" #include "simd-mappings.h" +#define GGML_FA_TILE_Q 32 +#define GGML_FA_TILE_KV 16 + #ifdef __cplusplus #include @@ -84,4 +87,9 @@ static std::pair get_thread_range(const struct ggml_compute_pa return {ir0, ir1}; } +struct ggml_fa_tile_config { + static constexpr size_t Q = GGML_FA_TILE_Q; + static constexpr size_t KV = GGML_FA_TILE_KV; +}; + #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 0e8dd0ae053..88a9c9ec057 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -24,6 +24,9 @@ struct ggml_compute_params { void * wdata; struct ggml_threadpool * threadpool; + + // use reference implementation + bool use_ref; }; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 4c7a75e768a..b003fe13fd9 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -5,7 +5,6 @@ #include "ggml-backend.h" #include "traits.h" #include "ggml-cpu-impl.h" -#include "ggml-cpu.h" #include "ggml-impl.h" #include "quants.h" #include "ggml-threading.h" @@ -14,6 +13,7 @@ #include "vec.h" #include "ops.h" #include "ggml.h" +#include "common.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -75,6 +75,9 @@ // precomputed f32 table for f16 (256 KB) (simd-mappings.h) float ggml_table_f32_f16[1 << 16]; +// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h) +float ggml_table_f32_e8m0_half[1 << 8]; + #if defined(__ARM_ARCH) struct ggml_arm_arch_features_type { int sve_cnt; @@ -2866,10 +2869,20 @@ struct ggml_cplan ggml_graph_plan( } break; case GGML_OP_FLASH_ATTN_EXT: { - const int64_t ne10 = node->src[1]->ne[0]; // DK - const int64_t ne20 = node->src[2]->ne[0]; // DV + const int64_t neq2 = node->src[0]->ne[2]; // number of query heads + const int64_t DK = node->src[1]->ne[0]; + const int64_t DV = node->src[2]->ne[0]; + + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding + size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks; + + // Decode path: n_kv_chunks = n_tasks (one chunk per thread) + // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ + size_t n_chunks = n_tasks; + size_t decode = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV)); - cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread) + cur += MAX(prefill, decode); } break; case GGML_OP_FLASH_ATTN_BACK: { @@ -2926,11 +2939,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { set_numa_thread_affinity(state->ith); struct ggml_compute_params params = { - /*.ith =*/ state->ith, - /*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - /*.threadpool=*/ tp, + /*.ith =*/ state->ith, + /*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + /*.threadpool =*/ tp, + /*.use_ref =*/ cplan->use_ref, }; GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); @@ -3670,6 +3684,11 @@ void ggml_cpu_init(void) { ggml_table_gelu_quick_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f)); } + // initialize E8M0 half table (256 entries) + for (int i = 0; i < (1 << 8); ++i) { + ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i); + } + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index f4713a42185..ddf1737a317 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -105,6 +105,8 @@ struct ggml_backend_cpu_context { ggml_abort_callback abort_callback; void * abort_callback_data; + + bool use_ref; // use reference implementation }; static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) { @@ -143,6 +145,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; + cpu_plan->cplan.use_ref = cpu_ctx->use_ref; return cpu_plan; } @@ -182,6 +185,7 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s cplan.abort_callback = cpu_ctx->abort_callback; cplan.abort_callback_data = cpu_ctx->abort_callback_data; + cplan.use_ref = cpu_ctx->use_ref; return ggml_graph_compute(cgraph, &cplan); } @@ -223,6 +227,7 @@ ggml_backend_t ggml_backend_cpu_init(void) { ctx->work_size = 0; ctx->abort_callback = NULL; ctx->abort_callback_data = NULL; + ctx->use_ref = false; ggml_backend_t cpu_backend = new ggml_backend { /* .guid = */ ggml_backend_cpu_guid(), @@ -270,6 +275,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_ ctx->abort_callback_data = abort_callback_data; } +void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->use_ref = use_ref; +} + // CPU backend - device struct ggml_backend_cpu_device_context { @@ -646,6 +658,9 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) { return (void *)ggml_is_numa; } + if (strcmp(name, "ggml_backend_cpu_set_use_ref") == 0) { + return (void *)ggml_backend_cpu_set_use_ref; + } // threadpool - TODO: move to ggml-base if (strcmp(name, "ggml_threadpool_new") == 0) { diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 7dc36d4f8ad..8f980c16b96 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -1797,10 +1797,27 @@ class tinyBLAS_Q0_AVX { } \ } \ +template +struct mma_instr; + +template<> +struct mma_instr { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvbf16ger2pp(acc, a, b); + } +}; + +template<> +struct mma_instr { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvf16ger2pp(acc, a, b); + } +}; + template -class tinyBLAS_BF16_PPC { +class tinyBLAS_HP16_PPC { public: - tinyBLAS_BF16_PPC(int64_t k, + tinyBLAS_HP16_PPC(int64_t k, const TA *A, int64_t lda, const TB *B, int64_t ldb, TC *C, int64_t ldc, @@ -2118,8 +2135,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2135,8 +2152,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2155,10 +2172,10 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]); - __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_2, vec_A[x+4], vec_B[x]); + mma_instr::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]); } } @@ -2189,7 +2206,7 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B); for (int x = 0; x<2; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -2224,8 +2241,8 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B); for (int x = 0; x<4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -3418,16 +3435,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return tb.matmul(m, n); } #elif defined(__MMA__) - if ((k % 8)) - return false; - if(Btype == GGML_TYPE_BF16) { - tinyBLAS_BF16_PPC tb{ k, - (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc, - params->ith, params->nth}; - tb.matmul(m, n); - return true; + if (k % 8) { + return false; + } + + if (Btype == GGML_TYPE_BF16) { + tinyBLAS_HP16_PPC tb{ k, + (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; } #elif defined(__riscv_zvfbfwma) #if LMUL == 1 @@ -3516,6 +3536,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 #endif return tb.matmul(m, n); } +#elif defined(__MMA__) + if (k % 8) { + return false; + } + + if (Btype == GGML_TYPE_F16) { + tinyBLAS_HP16_PPC tb{ k, + (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; + } #endif return false; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 387e2fe42c3..ce15b18ce0e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8042,12 +8042,14 @@ void ggml_compute_forward_top_k( } } -// ggml_compute_forward_flash_attn_ext - static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const ggml_compute_params * params, ggml_tensor * dst, - int ir0, int ir1) { + int ir0, int ir1, + int64_t ic_start, int64_t ic_end, + float * partials, int64_t partial_stride) { + + const bool write_partials = (partials != nullptr); const ggml_tensor * q = dst->src[0]; const ggml_tensor * k = dst->src[1]; const ggml_tensor * v = dst->src[2]; @@ -8124,7 +8126,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; - // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); @@ -8164,7 +8165,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // online softmax / attention // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf - for (int64_t ic = 0; ic < nek1; ++ic) { + + for (int64_t ic = ic_start; ic < ic_end; ++ic) { const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; @@ -8237,8 +8239,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } } - // sinks - if (sinks) { + // sinks - apply only on the first kv-chunk + if (sinks && ic_start == 0) { const float s = ((float *)((char *) sinks->data))[h]; float ms = 1.0f; @@ -8246,6 +8248,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( if (s > M) { ms = expf(M - s); + M = s; ggml_vec_scale_f32(DV, VKQ32, ms); } else { vs = expf(s - M); @@ -8254,20 +8257,372 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( S = S*ms + vs; } - // V /= S - const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; - ggml_vec_scale_f32(DV, VKQ32, S_inv); + if (write_partials) { + // Write M, S, VKQ to partials for later reduction + // partials layout: [M, S, VKQ[DV]] per query head + float * partial = partials + ir * partial_stride; + partial[0] = M; + partial[1] = S; + memcpy(partial + 2, VKQ32, DV * sizeof(float)); + } else { + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; + ggml_vec_scale_f32(DV, VKQ32, S_inv); - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + } + } +} + +static void ggml_compute_forward_flash_attn_ext_tiled( + const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, int ir1) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type); + const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float; + const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot; + const size_t kv_type_size = ggml_type_size(kv_type); - // original - //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); - // permute(0, 2, 1, 3) - memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ"); + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float S[Q_TILE_SZ]; + float M[Q_TILE_SZ]; + + for (int i = 0 ; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + // Per-thread scratch layout: + // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type) + // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float) + // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float) + // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator) + // V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion) + float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32); + + void * Q_q = base; + float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile + + memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float)); + memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + for (int tq = 0; tq < tile_rows; tq++) { + const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); + kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK); + } + // Zero-pad remaining rows + for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { + memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size); + } + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + for (int tq = 0; tq < tile_rows; tq++) { + const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]); + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]); + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + } + + if (can_skip) { + continue; + } + } + + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + const void * q_row = (const char *)Q_q + tq * DK * kv_type_size; + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3); + float s; + kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1); + KQ[tq * KV_TILE_SZ + tk] = s * scale; + } + } + + if (logit_softcap != 0.0f) { + ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ); + ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap); + } + + if (mask) { + ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + float tile_max; + ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms); + S[tq] *= ms; + } + M[tq] = Mnew; + + + S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew); + } + + // Convert V tile to F32 first (if F16), then do MAD + // On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster. + // TODO: on ARM, native f16 should be faster + if (kv_type == GGML_TYPE_F16) { + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); + ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV); + } + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) continue; + float * vkq_row = VKQ32 + tq * DV; + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const float p = KQ[tq * KV_TILE_SZ + tk]; + ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p); + } + } + } else { + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) continue; + float * vkq_row = VKQ32 + tq * DV; + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const float p = KQ[tq * KV_TILE_SZ + tk]; + const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); + ggml_vec_mad_f32(DV, vkq_row, v_row, p); + } + } + } + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms); + } else { + vs = expf(s - M[tq]); + } + + S[tq] = S[tq] * ms + vs; + } + } + + for (int tq = 0; tq < tile_rows; tq++) { + // V /= S + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv); + + // dst indices + const int i1 = iq1 + tq; + const int i2 = iq2; + const int i3 = iq3; + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1); + } + + ir += tile_rows; + } +} + +// Reduction function: combines partial results across KV chunks +// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV] +static void ggml_flash_attn_ext_reduce_partials( + const ggml_compute_params * params, + ggml_tensor * dst, + const int64_t n_chunks, + const int64_t chunk_size) { + + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + + const int64_t DK = k->ne[0]; + const int64_t DV = v->ne[0]; + const int64_t nek1 = k->ne[1]; + const int64_t n_q_heads = q->ne[2]; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32; + float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread; + + const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32); + const int64_t partial_size = 2 + DV; + const float * partials_base = (const float *) params->wdata + partials_offset; + + // Output layout + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const size_t nb1 = dst->nb[1]; + + // Each thread reduces a subset of query heads + for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) { + float M_final = -INFINITY; + float S_final = 0.0f; + float * VKQ_final = thread_wdata; + memset(VKQ_final, 0, DV * sizeof(float)); + + // Combine partials from all chunks + for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) { + const int64_t ic_start = chunk_idx * chunk_size; + if (ic_start >= nek1) continue; + + const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size; + const float M_chunk = partial[0]; + const float S_chunk = partial[1]; + const float * VKQ_chunk = partial + 2; + + if (S_chunk == 0.0f) continue; + + const float M_new = fmaxf(M_final, M_chunk); + const float scale_old = expf(M_final - M_new); + const float scale_new = expf(M_chunk - M_new); + + for (int64_t d = 0; d < DV; ++d) { + VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new; + } + S_final = S_final * scale_old + S_chunk * scale_new; + M_final = M_new; + } + + // Normalize and write to output + if (S_final != 0.0f) { + const float S_inv = 1.0f / S_final; + ggml_vec_scale_f32(DV, VKQ_final, S_inv); + } + // iq1=0, iq3=0 for decode + memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1); } } @@ -8292,6 +8647,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t DV = nev0; const int64_t N = neq1; + GGML_ASSERT(ne0 == DV); GGML_ASSERT(ne2 == N); @@ -8312,47 +8668,92 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int64_t nr = neq1*neq2*neq3; - - // rows per thread const int ith = params->ith; const int nth = params->nth; - // disable for NUMA - const bool disable_chunking = ggml_is_numa(); + // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking) + const bool use_ref = params->use_ref; - // 4x chunks per thread - int nth_scaled = nth * 4; - int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; - int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16); + const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512; - if (nth == 1 || nchunk < nth || disable_chunking) { - nchunk = nth; - } + if (use_split_kv_path) { + const int64_t chunk_size = (nek1 + nth - 1) / nth; - if (ith == 0) { - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - ggml_threadpool_chunk_set(params->threadpool, nth); - } + // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ] + const int64_t partial_size = 2 + DV; + float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32); - ggml_barrier(params->threadpool); + const int64_t ic_start = ith * chunk_size; + const int64_t ic_end = std::min(ic_start + chunk_size, nek1); + + const int64_t partial_stride = nth * partial_size; + float * chunk_partials = partials_base + ith * partial_size; + + if (ic_start < nek1) { + for (int64_t q_head = 0; q_head < neq2; q_head++) { + ggml_compute_forward_flash_attn_ext_f16_one_chunk( + params, dst, q_head, q_head + 1, ic_start, ic_end, + chunk_partials, partial_stride); + } + } else { + for (int64_t q_head = 0; q_head < neq2; q_head++) { + float * q_partials = chunk_partials + q_head * partial_stride; + q_partials[0] = -INFINITY; // M + q_partials[1] = 0.0f; // S + } + } - // The number of elements in each chunk - const int64_t dr = (nr + nchunk - 1) / nchunk; + ggml_barrier(params->threadpool); + ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size); + } else { + + // total rows in q + const int64_t nr = neq1*neq2*neq3; - // The first chunk comes from our thread_id, the rest will get auto-assigned. - int current_chunk = ith; + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); - while (current_chunk < nchunk) { - const int64_t ir0 = dr * current_chunk; - const int64_t ir1 = MIN(ir0 + dr, nr); + // 4x chunks per thread + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; - ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1); + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; + } - current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + if (ith == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth); + } + + ggml_barrier(params->threadpool); + + const int64_t dr = (nr + nchunk - 1) / nchunk; + + static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool use_tiled = !use_ref && + (q->type == GGML_TYPE_F32 && + kv_is_f32_or_f16 && + k->type == v->type && + nek1 % KV_TILE_SZ == 0 && + neq1 >= Q_TILE_SZ); + + int current_chunk = ith; + + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); + + if (use_tiled) { + ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1); + } else { + ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0); + } + + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } } } diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index fbf7ed9432a..24e8ab46182 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert (n % qk == 0); assert (nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); float sumf[8]; float sum_minf[8]; @@ -616,6 +609,191 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; + + const int qh_shift = (k / 4) * 2; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * 8 + i) % 32; + const int qh_chunk = qh_idx / 8; + const int qh_pos = qh_idx % 8; + const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + + +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + + + for (int k = 0; k < 16; k++) { + // k = 0.. 7 weights 0-63 low, 64-127 high + // k = 8..15 weights 128-191 low, 192-255 high + const int base_l = (k / 8) * 128 + (k % 8) * 8; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + // qh_half: offset to the correct 32-byte half (0 or 32) + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + // Interleaved scales + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * 64 + j * 8 + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + // qh indexing with 8-byte interleaving (like q5_K) + const int qh_byte_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_byte_l / 8; + const int qh_pos_l = qh_byte_l % 8; + const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_byte_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_byte_h / 8; + const int qh_pos_h = qh_byte_h % 8; + const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1046,15 +1224,7 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert (nr % 4 == 0); assert (nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); float sumf[4][8]; float sum_minf[4][8]; @@ -1212,6 +1382,213 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; + + const int qh_shift = (k / 4) * 2; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * 8 + i) % 32; + const int qh_chunk = qh_idx / 8; + const int qh_pos = qh_idx % 8; + const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + + float sumf[4][8]; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0f; + } + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < 16; k++) { + // k = 0.. 7 weights 0-63 low, 64-127 high + // k = 8..15 weights 128-191 low, 192-255 high + const int base_l = (k / 8) * 128 + (k % 8) * 8; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + // qh_half: offset to the correct 32-byte half (0 or 32) + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + // Activation base indices for q8_Kx4 interleaved format + // Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32 + const int q8_base = (k / 8) * 512 + (k % 8) * 32; + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + // Interleaved scales + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * 64 + j * 8 + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / 8; + const int qh_pos_l = qh_idx_l % 8; + const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / 8; + const int qh_pos_h = qh_idx_h % 8; + const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i]; + const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256]; + + sumi_l += q_l * q8_l; + sumi_h += q_h * q8_h; + } + + sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * + a_ptr[l].d[m]; + } + } + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -1612,8 +1989,7 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in // Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure // For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures - for(int i = 0; i < 128; i++){ - + for (int i = 0; i < 128; i++) { // Index for selecting which q2k super block int src1 = (i % 16) / 2; // Index for selecting scale @@ -1622,7 +1998,141 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in out.scales[i] = in[src1].scales[src2]; } return out; +} + +static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) { + block_q5_Kx8 out; + //Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 8; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 4 / blck_size_interleave; + + // Interleave Q5_K quants by taking 8 bytes at a time + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + // Repeat for low bits 8 bytes at a time as well, since + // the high bits are interleaved in Q5_K and the index is + // qh_idx = (qs_idx % 32); + // qh_val = qh[qh_idx] >> (qs_idx / 32); + for (int i = 0; i < end / 4; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t)); + memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t)); + } + + // The below logic is copied over from Q4_K + // The point is to unpack all the scales and mins for each sub block every time we load 12 bytes. + // Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q5_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures + uint8_t s[8], m[8]; + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; + } + + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + } + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4); + } + + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + } + + return out; +} + +static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) { + block_q6_Kx8 out; + constexpr int n_blocks = 8; // Kx8 + for (int i = 0; i < n_blocks; i++) { + out.d[i] = in[i].d; + } + + const int end_ls = QK_K * 4 / blck_size_interleave; + // Interleave Q6_K quants by taking 8 bytes at a time + for (int i = 0; i < end_ls; ++i) { + int src_id = i % n_blocks; + int src_offset = (i / n_blocks) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elem_ls; + memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t)); + memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t)); + } + + // Interleave high bits using same 8-byte pattern as low bits + const int end_hs = end_ls / 2; + for (int i = 0; i < end_hs; ++i) { + int src_id = i % n_blocks; + int src_offset = (i / n_blocks) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elem_hs; + memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t)); + memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t)); + } + + // The below logic is designed so as to unpack and rearrange scales in Q6_K + // The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants + // Q6_K structure has an 8-bit scale per 16 elements -> 16 scales + // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 15 bl7] (bl = block) + constexpr int n_scales = QK_K / 16; + for (int i = 0; i < n_blocks; i++) { + for (int j = 0; j < n_scales; j++) { + out.scales[j * n_blocks + i] = in[i].scales[j]; + } + } + + return out; } static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { @@ -1706,7 +2216,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++ ) { + for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block); @@ -1718,6 +2228,67 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + block_q5_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q6_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 8); @@ -1936,6 +2507,14 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); } @@ -1973,6 +2552,17 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> +void gemv(int n, + float * s, + size_t bs, + const void * vx, + const void * vy, + int nr, + int nc) { + ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1981,8 +2571,12 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -2013,20 +2607,35 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +template <> +void gemm(int n, + float * s, + size_t bs, + const void * vx, + const void * vy, + int nr, + int nc) { + ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -2393,20 +3002,19 @@ template (ne00, - (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - src0_cur + src0_cur_start * nb01, - src1_col, 1, src0_cur_end - src0_cur_start); + gemv( + ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); } } #undef MMID_MATRIX_ROW @@ -2422,7 +3030,6 @@ template q4_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0; @@ -2432,6 +3039,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; + // instance for Q5_K + static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; + + // instance for Q6_K + static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K; + // instance for Q2 static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K; @@ -2482,6 +3095,18 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q2_K_8x8_q8_K; } } + } else if (cur->type == GGML_TYPE_Q5_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q5_K_8x8_q8_K; + } + } + } else if (cur->type == GGML_TYPE_Q6_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q6_K_8x8_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index af98e703442..855320eeeb6 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -44,6 +44,7 @@ struct block_q4_Kx8 { }; static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); + struct block_q2_Kx8 { ggml_half d[8]; // super-block scale for quantized scales ggml_half dmin[8]; // super-block scale for quantized mins @@ -52,6 +53,28 @@ struct block_q2_Kx8 { }; static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); + +struct block_q5_Kx8 { + ggml_half d[8]; // super-block scale for quantized scales + ggml_half dmin[8]; // super-block scale for quantized mins + uint8_t scales[96]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants + uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4) +}; + +static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, + "wrong q5_K block size/padding"); + +struct block_q6_Kx8 { + ggml_half d[8]; + int8_t scales[QK_K / 16 * 8]; + uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2) + uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4) +}; + +static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, + "wrong q6_K block size/padding"); + struct block_q8_Kx4 { float d[4]; // delta int8_t qs[QK_K * 4]; // quants @@ -85,17 +108,21 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -111,17 +138,21 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index e367f110b46..630e506542b 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -116,6 +116,17 @@ extern "C" { // defined in ggml-cpu.c, initialized in ggml_cpu_init() extern float ggml_table_f32_f16[1 << 16]; +// precomputed f32 table for e8m0 half (1 KB) +// defined in ggml-cpu.c, initialized in ggml_cpu_init() +extern float ggml_table_f32_e8m0_half[1 << 8]; + +// Use lookup table for E8M0 on x86 (faster than bit manipulation) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)] +#else +#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x) +#endif + // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 179522d8355..a3256d59dd0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -53,6 +53,7 @@ // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms #define GGML_CUDA_CC_BLACKWELL 1200 +#define GGML_CUDA_CC_DGX_SPARK 1210 #define GGML_CUDA_CC_RUBIN 1300 #define GGML_CUDA_CC_OFFSET_AMD 0x1000000 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 @@ -1121,15 +1122,18 @@ struct ggml_tensor_extra_gpu { #endif struct ggml_cuda_graph_node_properties { - void * node_address; + void * node_data; ggml_op node_op; + enum ggml_type node_type; int32_t flags; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; - void * src_address[GGML_MAX_SRC]; + void * src_data[GGML_MAX_SRC]; int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; }; +static_assert(std::is_trivial::value, "ggml_cuda_graph_node_properties must be trivial"); + struct ggml_cuda_graph { #ifdef USE_CUDA_GRAPH ~ggml_cuda_graph() { @@ -1149,6 +1153,12 @@ struct ggml_cuda_graph { int number_consecutive_updates = 0; std::vector props; + // these are extra tensors (inputs) that participate in the ggml graph but are not nodes + // they properties also have to match in order to be able to safely reuse a CUDA graph + // ref: https://github.com/ggml-org/llama.cpp/pull/18583 + // ref: https://github.com/ggml-org/llama.cpp/pull/19165 + std::vector extra; + void record_update(bool use_graph, bool update_required) { if (use_graph && update_required) { number_consecutive_updates++; @@ -1327,10 +1337,44 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; - std::unique_ptr cuda_graph; - int curr_stream_no = 0; +#ifdef USE_CUDA_GRAPH + // Map from first_node_ptr to cuda_graph - allows multiple graphs per context + // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) + std::unordered_map> cuda_graphs; + + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + auto it = cuda_graphs.find(first_node_ptr); + if (it == cuda_graphs.end()) { + cuda_graphs[first_node_ptr] = std::make_unique(); + return cuda_graphs[first_node_ptr].get(); + } + return it->second.get(); + } + + // Check if any CUDA graph is enabled for this context (used by kernels that need to know + // if graphs are in use without having access to the specific graph key) + bool any_cuda_graph_enabled() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->is_enabled()) { + return true; + } + } + return false; + } + + // Check if any CUDA graph has an instance for this context + bool any_cuda_graph_has_instance() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->instance != nullptr) { + return true; + } + } + return false; + } +#endif // USE_CUDA_GRAPH + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 8468ba8488d..b6a7460da83 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, - const int nbatch_fa) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, + const int ne11, const int ne12, const int nbatch_fa) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -641,11 +641,14 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; + + const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -654,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup( return; } - const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2)); - const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); - const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - if (jt*ncols1 + j >= ne01) { + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { return; } - dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; @@ -681,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -778,13 +785,11 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; - const bool is_mla = DV == 512; // TODO better parameterization - const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - GGML_ASSERT(V || is_mla); + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); const ggml_tensor * mask = dst->src[3]; const ggml_tensor * sinks = dst->src[4]; @@ -794,9 +799,9 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); - GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); - GGML_ASSERT( K->nb[0] == ggml_element_size(K)); - GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT(K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); @@ -817,10 +822,10 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = V ? (const char *) V->data : nullptr; - size_t nb21 = V ? V->nb[1] : nb11; - size_t nb22 = V ? V->nb[2] : nb12; - size_t nb23 = V ? V->nb[3] : nb13; + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; if (need_f16_K && K->type != GGML_TYPE_F16) { const size_t bs = ggml_blck_size(K->type); @@ -849,36 +854,45 @@ void launch_fattn( K_data = (char *) K_f16.ptr; } - if (V && need_f16_V && V->type != GGML_TYPE_F16) { - const size_t bs = ggml_blck_size(V->type); - const size_t ts = ggml_type_size(V->type); - - V_f16.alloc(ggml_nelements(V)); - if (ggml_is_contiguously_allocated(V)) { - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; - - nb21 = nb21*bs*sizeof(half)/ts; - nb22 = nb22*bs*sizeof(half)/ts; - nb23 = nb23*bs*sizeof(half)/ts; + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + V_data = K_data; + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; } else { - GGML_ASSERT(V->nb[0] == ts); - to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); - const int64_t s01 = nb21 / ts; - const int64_t s02 = nb22 / ts; - const int64_t s03 = nb23 / ts; - to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); - - nb21 = V->ne[0] * sizeof(half); - nb22 = V->ne[1] * nb21; - nb23 = V->ne[2] * nb22; + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; } - V_data = (char *) V_f16.ptr; } - const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); - const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); + const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3]; // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or @@ -953,7 +967,7 @@ void launch_fattn( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3]; + blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -1007,7 +1021,7 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e53bbc0502c..0b8ef90794c 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, @@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); @@ -442,8 +442,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; const int k_VKQ_0 = kb0 * nbatch_fa; #if defined(TURING_MMA_AVAILABLE) @@ -456,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if constexpr (nstages > 1) { static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); - static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -471,8 +470,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } + // For MLA K and V have the same data. + // Therefore, iterate over K in reverse and later re-use the data if possible. #pragma unroll - for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) { const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; @@ -510,7 +511,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } else { - static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); #pragma unroll for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); @@ -522,14 +522,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - // Wide version of KQ_C is column-major + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); + } else { + // Wide version of KQ_C is column-major #if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. - mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); + // RDNA matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); #else - // swap A and B for CUDA. - mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); + // swap A and B for CUDA. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); #endif // defined(AMD_WMMA_AVAILABLE) + } } } } @@ -773,6 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } if constexpr (nstages > 1) { + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -788,10 +793,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } - // For MLA K and V have the same data. - // Therefore, iterate over V in reverse and re-use the data if possible. - static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); - constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) T_A_VKQ A_identity; make_identity_mat(A_identity); @@ -799,12 +800,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll - for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { - const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; - const int i0_diff = i0_stop - i0_start; + for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { + static_assert(DV % (2*nbatch_V2) == 0, "bad loop size"); + const int i0_stop = i0_start + 2*nbatch_V2; + const int i0_diff = i0_stop - i0_start; if constexpr (nstages <= 1) { - if (i0_start < reusable_cutoff) { + if (!V_is_K_view || i0_stop > 2*nbatch_K2) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); @@ -814,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( __syncthreads(); } } - const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; + const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; @@ -917,7 +919,7 @@ template struct mma_tile_sizes { }; #endif // defined(TURING_MMA_AVAILABLE) -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -931,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float logit_softcap, const uint3 ne01, const int ne02, + const int gqa_ratio, const int ne11, const int stride_Q1, const int stride_Q2, @@ -938,6 +941,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int stride_V, const int stride_mask, const int jt, + const int zt_gqa, const int kb0_start, const int kb0_stop) { #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) @@ -953,7 +957,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); @@ -971,8 +975,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -1021,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < int(ne01.z)) { + if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -1076,7 +1079,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1085,7 +1088,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; const int k_VKQ_sup = ne11 - kb0*nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1096,7 +1099,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1105,7 +1108,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1407,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) { + if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) { continue; } @@ -1446,14 +1449,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #else GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, - scale, slope, logit_softcap, ne01, ne02, + scale, slope, logit_softcap, ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) } -template +template __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1484,6 +1487,13 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } +#ifdef VOLTA_MMA_AVAILABLE + if (ncols1*ncols2 < 32) { + NO_DEVICE_CODE; + return; + } +#endif // VOLTA_MMA_AVAILABLE + #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING if (ncols1*ncols2 > 32) { NO_DEVICE_CODE; @@ -1498,8 +1508,6 @@ static __global__ void flash_attn_ext_f16( } #endif // defined(AMD_WMMA_AVAILABLE) - static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); - constexpr int ncols = ncols1 * ncols2; constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); @@ -1512,14 +1520,15 @@ static __global__ void flash_attn_ext_f16( const int stride_K = nb11 / sizeof(half2); const int stride_mask = nb31 / sizeof(half); - const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2); - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; // kbc == k block continuous, current index in continuous ijk space. - int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1530,22 +1539,24 @@ static __global__ void flash_attn_ext_f16( int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - const int head0 = zt * ncols2; + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half * mask_h = ncols2 == 1 && !mask ? nullptr : (const half *) (mask + nb33*(sequence % ne33)); - float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); - const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; if (KV_max) { kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); @@ -1553,14 +1564,14 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); } kbc += iter_k; @@ -1574,22 +1585,24 @@ static __global__ void flash_attn_ext_f16( return; } - const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index. + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - const int head0 = zt * ncols2; + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half * mask_h = ncols2 == 1 && !mask ? nullptr : (const half *) (mask + nb33*(sequence % ne33)); - float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); - const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; if (KV_max) { kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); @@ -1597,9 +1610,9 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1633,7 +1646,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); const int nwarps = nthreads / WARP_SIZE; - constexpr bool mla = DKQ == 576; + constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); @@ -1658,7 +1671,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1669,7 +1682,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1728,3 +1741,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); + +// For GLM 4.7 Flash +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index f055da8e2be..b6db5822818 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) return 0; @@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) return 0; @@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) @@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) @@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm launch_fattn_tile_switch_ncols1(ctx, dst); return; } + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } } if constexpr (DV <= 256) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 598cda7daa0..721edd99944 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -18,9 +18,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } } - if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; + if constexpr (ncols2 <= 16) { + if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } } if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) { @@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -46,7 +49,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con // are put into the template specialization without GQA optimizations. bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; for (const ggml_tensor * t : {Q, K, V, mask}) { - if (t == nullptr) { + if (t == nullptr || ggml_is_quantized(t->type)) { continue; } for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { @@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - if (use_gqa_opt && gqa_ratio % 8 == 0) { + // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute: + if (cc == GGML_CUDA_CC_VOLTA) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio > 4) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 4 == 0) { + if (use_gqa_opt && gqa_ratio > 2) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { + if (use_gqa_opt && gqa_ratio > 1) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } @@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con } static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -121,8 +146,50 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - GGML_ASSERT(gqa_ratio % 16 == 0); - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + if (gqa_ratio == 20) { // GLM 4.7 Flash + if (cc >= GGML_CUDA_CC_DGX_SPARK) { + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_BLACKWELL) { + if (Q->ne[1] <= 4 && K->ne[1] >= 65536) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + if (Q->ne[1] <= 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_TURING) { + if (Q->ne[1] <= 4) { + if (K->ne[1] <= 16384) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + // Volta: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } else if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } } break; default: GGML_ABORT("fatal error"); @@ -230,9 +297,9 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // The effective batch size for the kernel can be increased by gqa_ratio. // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, - bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; for (const ggml_tensor * t : {Q, K, V, mask}) { - if (t == nullptr) { + if (t == nullptr || ggml_is_quantized(t->type)) { continue; } for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { @@ -262,7 +329,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!gqa_opt_applies || gqa_ratio % 16 != 0) { + if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } break; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index cda422defbe..9e77c231c85 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -70,17 +70,18 @@ #include #include #include -#include +#include #include #include #include #include #include -#include -#include -#include +#include +#include +#include #include #include +#include static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); @@ -2278,13 +2279,19 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - if (ne2 == 1) { + static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); + if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + if (ne2 <= 4) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + return; + } } else { - ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + if (GGML_CUDA_CC_IS_AMD(cc)) { + ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + return; + } } - return; } if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) { @@ -2916,22 +2923,27 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { } static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { - props->node_address = node->data; + memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); + props->node_data = node->data; props->node_op = node->op; + props->node_type = node->type; props->flags = node->flags; for (int i = 0; i < GGML_MAX_DIMS; i++) { props->ne[i] = node->ne[i]; props->nb[i] = node->nb[i]; } for (int i = 0; i < GGML_MAX_SRC; i++) { - props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; + if (!node->src[i]) { + continue; + } + + props->src_data[i] = node->src[i]->data; } memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); } static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { - if (node->data != props->node_address && - node->op != GGML_OP_VIEW) { + if (node->data != props->node_data && node->op != GGML_OP_VIEW) { return false; } @@ -2939,6 +2951,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ return false; } + if (node->type != props->node_type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != props->ne[i]) { return false; @@ -2948,17 +2964,22 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ } } - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (node->src[i] && - node->src[i]->data != props->src_address[i] && - node->op != GGML_OP_VIEW - ) { - return false; + if (node->op != GGML_OP_VIEW) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (!node->src[i]) { + if (props->src_data[i] != nullptr) { + return false; + } + continue; + } + + if (node->src[i]->data != props->src_data[i]) { + return false; + } } } - if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && - memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { return false; } @@ -2969,56 +2990,82 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ return true; } -static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { +static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { + return cgraph->nodes[0]; +} +static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { bool res = false; - if (cuda_ctx->cuda_graph->instance == nullptr) { + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + + if (graph->instance == nullptr) { res = true; } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { + if (graph->props.size() != (size_t)cgraph->n_nodes) { res = true; - cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); + graph->props.resize(cgraph->n_nodes); } // Loop over nodes in GGML graph to determine if CUDA graph update is required // and store properties to allow this comparison for the next token + std::unordered_set seen_node; + std::vector srcs_extra; for (int i = 0; i < cgraph->n_nodes; i++) { bool props_match = true; + + seen_node.insert(cgraph->nodes[i]); + if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]); + props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); } if (!props_match) { res = true; } - ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]); + ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + ggml_tensor * src = cgraph->nodes[i]->src[src_idx]; + if (src && seen_node.find(src) == seen_node.end()) { + srcs_extra.push_back(src); + } + } + } + + if (graph->extra.size() != (size_t) srcs_extra.size()) { + res = true; + graph->extra.resize(srcs_extra.size()); } - for (int i = 0; i < cgraph->n_leafs; i++) { - bool props_match= true; + for (size_t i = 0; i < srcs_extra.size(); ++i) { + bool props_match = true; + if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]); + props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]); } + if (!props_match) { res = true; } - ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); + ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]); } return res; } -static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) { +static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info); #else cudaGraphNode_t errorNode; cudaGraphExecUpdateResult result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info); #endif // CUDART_VERSION >= 12000 if (stat == cudaErrorGraphExecUpdateFailure) { @@ -3029,14 +3076,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate (void)cudaGetLastError(); - CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); - cuda_ctx->cuda_graph->instance = nullptr; - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + CUDA_CHECK(cudaGraphExecDestroy(graph->instance)); + graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } else { GGML_ASSERT(stat == cudaSuccess); } } -#endif +#endif // USE_CUDA_GRAPH static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, const ggml_tensor * view, @@ -3072,63 +3119,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, return true; } -static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops, std::initializer_list unary_ops) { -#ifndef NDEBUG - const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY); - GGML_ASSERT(unary_ops.size() == num_unary); -#endif - - //TODO: remove special case once ggml_can_fuse can handle empty nodes - std::initializer_list topk_moe_ops = - ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false); - std::initializer_list topk_moe_ops_with_norm = - ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false); - std::initializer_list topk_moe_ops_delayed_softmax = - ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); +static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) { + args.sigmoid = false; + args.softmax = false; + args.delayed_softmax = false; + args.prob_bias = false; + args.norm = false; - const auto is_equal = [](const std::initializer_list & list1, - const std::initializer_list & list2) { - return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); - }; + const int n_nodes = cgraph->n_nodes; + ggml_tensor ** nodes = cgraph->nodes; - if (is_equal(topk_moe_ops_with_norm, ops) && - ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { - ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx + 9]; - ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; - ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; - int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) { + args.softmax = true; + } - if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { - return true; + if (nodes[node_idx]->op == GGML_OP_UNARY) { + if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) { + return false; } + args.sigmoid = true; } - if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { - ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx + 4]; - ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; - ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; - int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + if (nodes[node_idx]->op == GGML_OP_ARGSORT) { + args.delayed_softmax = true; + } - if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { - return true; + node_idx++; + + if (args.sigmoid || args.softmax) { + // SOFTMAX -> RESHAPE + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + ggml_tensor * probs_reshaped = nodes[node_idx]; + node_idx++; + + if (node_idx >= n_nodes) { + return false; + } + + // src of bias add is the unreshaped probs (-2 instead of -1) + if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) { + args.prob_bias = true; + node_idx++; + } + // RESHAPE/ADD -> ARGSORT + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) { + return false; + } + + if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) { + return false; + } + + node_idx++; + + // ARGSORT-> VIEW + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) { + return false; + } + + // GET_ROWS + if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + } else if (args.delayed_softmax) { + if (node_idx - 2 < 0) { + return false; + } + ggml_tensor * probs_reshaped = nodes[node_idx - 2]; + + // VIEW->ARGSORT + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + + // GET_ROWS + if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] || + nodes[node_idx]->src[0] != probs_reshaped) { + return false; + } + node_idx++; + + static const std::vector remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; + + for (const ggml_op op : remaining_ops) { + if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; } } - if (is_equal(topk_moe_ops_delayed_softmax, ops) && - ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { - ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; - ggml_tensor * weights = cgraph->nodes[node_idx + 5]; - ggml_tensor * get_rows = cgraph->nodes[node_idx + 2]; - ggml_tensor * argsort = cgraph->nodes[node_idx + 0]; - int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + // At this point we can check for norm + scale. Everything is now at least valid till the norm + if (node_idx >= n_nodes) { + return true; + } + + if (nodes[node_idx]->op == GGML_OP_RESHAPE) { + //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE + static const std::vector norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP }; - if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { + args.norm = true; + for (const ggml_op op : norm_ops) { + if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) { + node_idx++; + } else { + args.norm = false; + return true; + } + } + + // DIV <- CLAMP, RESHAPE + if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] || + nodes[node_idx]->src[0] != nodes[node_idx - 3]) { + args.norm = false; return true; } + node_idx++; + + if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + args.norm = false; + return true; + } + + node_idx++; + } + + if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) { + args.scale = true; } + return true; +} + +static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, + int node_idx, + std::initializer_list ops, + std::initializer_list unary_ops) { +#ifndef NDEBUG + const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY); + GGML_ASSERT(unary_ops.size() == num_unary); +#endif + + const auto is_equal = [](const std::initializer_list & list1, + const std::initializer_list & list2) { + return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); + }; + std::initializer_list mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU }; std::initializer_list mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU }; @@ -3241,7 +3391,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } -static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) { +static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; // flag used to determine whether it is an integrated_gpu @@ -3390,35 +3540,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { + ggml_cuda_topk_moe_args args; + + if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || + cgraph->nodes[i]->op == GGML_OP_ARGSORT) { + const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); + + std::vector ops; + + if (can_fuse) { + const ggml_tensor * logits = node->src[0]; + ggml_tensor * weights = nullptr; + ggml_tensor * ids = nullptr; + const ggml_tensor * bias = nullptr; + const ggml_tensor * clamp = nullptr; + const ggml_tensor * scale = nullptr; + + if (!args.delayed_softmax) { + ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; + int out_nodes[2]; // nodes which can't be elided + + if (args.prob_bias) { + bias = cgraph->nodes[i + 2]->src[1]; + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS }); + out_nodes[0] = i + 4; + ids = cgraph->nodes[i + 4]; + } else { + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS }); + out_nodes[0] = i + 3; + ids = cgraph->nodes[i + 3]; + } - if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i + 9]; - ggml_tensor * selected_experts = cgraph->nodes[i + 3]; - ggml_tensor * clamp = cgraph->nodes[i + 7]; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true, - /*delayed softmax*/ false, clamp); - i += 9; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { - ggml_tensor * weights = cgraph->nodes[i + 4]; - ggml_tensor * selected_experts = cgraph->nodes[i + 3]; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false, - /*delayed softmax*/ false); - i += 4; - continue; - } + if (args.norm) { + ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, + GGML_OP_DIV, GGML_OP_RESHAPE }); + clamp = cgraph->nodes[i + ops.size() - 3]; + } + if (args.scale) { + ops.insert(ops.end(), { GGML_OP_SCALE }); + scale = cgraph->nodes[i + ops.size() - 1]; + } - if (ggml_cuda_can_fuse(cgraph, i, - ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i + 5]; - ggml_tensor * ids = cgraph->nodes[i + 1]; + weights = cgraph->nodes[i + ops.size() - 1]; + out_nodes[1] = i + ops.size() - 1; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false, - /*delayed_softmax*/ true); - i += 5; - continue; + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + i += ops.size() - 1; + continue; + } + } else if (!args.norm && !args.prob_bias) { + //special case gpt-oss, no norm, no bias. + ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, + GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); + weights = cgraph->nodes[i + 5]; + ids = cgraph->nodes[i + 1]; + const ggml_tensor * softmax = cgraph->nodes[i + 4]; + + int out_nodes[2] = { i + 1, i + 5 }; + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + i += ops.size() - 1; + continue; + } + } + } } if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { @@ -3695,13 +3885,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud } #ifdef USE_CUDA_GRAPH + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture - if (cuda_ctx->cuda_graph->graph != nullptr) { - CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); - cuda_ctx->cuda_graph->graph = nullptr; + if (graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph->graph)); + graph->graph = nullptr; } - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph)); graph_evaluated_or_captured = true; // CUDA graph has been captured std::lock_guard lock(ggml_cuda_lock); @@ -3714,43 +3905,38 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud } if (use_cuda_graph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } if (cuda_graph_update_required) { // Update graph executable - ggml_cuda_graph_update_executable(cuda_ctx); + ggml_cuda_graph_update_executable(cuda_ctx, graph_key); } // Launch graph - CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); + CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream())); #else + GGML_UNUSED(graph_key); graph_evaluated_or_captured = true; #endif // USE_CUDA_GRAPH } } -static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) { - #ifdef USE_CUDA_GRAPH +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - if (cuda_ctx->cuda_graph == nullptr) { - cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); - } - - if (cuda_ctx->cuda_graph->graph == nullptr) { + if (graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { - if (!cuda_ctx->cuda_graph->disable_due_to_gpu_arch) { + if (!graph->disable_due_to_gpu_arch) { GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); } - cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; + graph->disable_due_to_gpu_arch = true; } } - return cuda_ctx->cuda_graph->is_enabled(); -#else - GGML_UNUSED(cuda_ctx); - return false; -#endif // USE_CUDA_GRAPH + return graph->is_enabled(); } +#endif // USE_CUDA_GRAPH static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; @@ -3759,15 +3945,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool use_cuda_graph = false; bool cuda_graph_update_required = false; + const void * graph_key = nullptr; #ifdef USE_CUDA_GRAPH - use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); + graph_key = ggml_cuda_graph_get_key(cgraph); + + use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); - if (cuda_ctx->cuda_graph->is_enabled()) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->is_enabled()) { cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); - cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required); + graph->record_update(use_cuda_graph, cuda_graph_update_required); } #endif // USE_CUDA_GRAPH @@ -3781,7 +3971,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } - ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required); + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key); return GGML_STATUS_SUCCESS; } @@ -3814,7 +4004,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; - const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); +#ifdef USE_CUDA_GRAPH + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); +#else + const bool use_cuda_graph = false; + GGML_UNUSED(cuda_ctx); + GGML_UNUSED(cgraph); +#endif static bool enable_graph_optimization = [] { const char * env = getenv("GGML_CUDA_GRAPH_OPT"); diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 60542fc19dd..49af5389957 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -31,14 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { #endif // USE_CUDA_GRAPH if ((nrows == 1) && #ifdef USE_CUDA_GRAPH - // CUDA_GRAPHS_DISABLED - ((ncols > 65536) && - ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->is_enabled())) || - // CUDA_GRAPHS ENABLED - ((ncols > 32768) && - !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->is_enabled()))) { + // Determine if CUDA graphs are effectively disabled for this context + // (no graph instance exists and we're not capturing, OR graphs are explicitly enabled) + (((ncols > 65536) && + (((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())) || + // CUDA graphs are enabled - use lower threshold + ((ncols > 32768) && + !(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 42085d10027..dd45d6c78fd 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -333,7 +333,33 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return 4 * (threadIdx.x / 16) + l; + return ne * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = I * J / 64; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return ne * (threadIdx.x / 16) + l; } else { NO_DEVICE_CODE; return -1; @@ -391,7 +417,22 @@ namespace ggml_cuda_mma { static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_WMMA_AVAILABLE) - static constexpr int ne = I * J / 32; + static constexpr int ne = tile::ne; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_j(l); + } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = tile::ne; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { @@ -945,6 +986,32 @@ namespace ggml_cuda_mma { #endif // AMPERE_MMA_AVAILABLE } + template + static __device__ __forceinline__ void mma( + tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) { +#ifdef AMD_MFMA_AVAILABLE + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast(D.x[0]); +#if defined(CDNA3) + using floatx2_t = __attribute__((ext_vector_type(2))) float; + const floatx2_t& a_frag = reinterpret_cast(A.x[0]); + const floatx2_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) +#pragma unroll + for (int i = 0; i < 2; ++i) { + acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0); + } +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(CDNA3) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } + static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B, @@ -1054,6 +1121,13 @@ namespace ggml_cuda_mma { GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // RDNA4 +#elif defined(AMD_MFMA_AVAILABLE) + using halfx4_t = __attribute__((ext_vector_type(4))) _Float16; + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast(D.x[0]); + const halfx4_t& a_frag = reinterpret_cast(A.x[0]); + const halfx4_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0); #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1081,11 +1155,31 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // RDNA4 +#endif // defined(RDNA4) +#elif defined(AMD_MFMA_AVAILABLE) + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast(D.x[0]); +#if defined(CDNA3) || defined(CDNA2) + using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; + const bf16x4_t& a_frag = reinterpret_cast(A.x[0]); + const bf16x4_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0); +#elif defined(CDNA1) +#pragma unroll + for (int i = 0; i < 2; ++i) { + using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16; + const bf16x2_t& a_frag = reinterpret_cast(A.x[i]); + const bf16x2_t& b_frag = reinterpret_cast(B.x[i]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0); + } #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // AMPERE_MMA_AVAILABLE +#endif // defined(CDNA3) || defined(CDNA2) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(AMD_WMMA_AVAILABLE) } template diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 6643f243b12..aad4c34aa66 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -2,6 +2,13 @@ #include "mmf.cuh" #include "mmid.cuh" +static __forceinline__ int mmf_get_rows_per_block(const int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return MMF_ROWS_PER_BLOCK_CDNA; + } else { + return MMF_ROWS_PER_BLOCK; + } +} void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); @@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr ids_info_ptr = &ids_info; } + const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; + const int rows_per_block = mmf_get_rows_per_block(cc); + switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; constexpr int vals_per_T = 1; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + mul_mat_f_switch_rows_per_block( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; constexpr int vals_per_T = 2; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + mul_mat_f_switch_rows_per_block( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; constexpr int vals_per_T = 2; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + mul_mat_f_switch_rows_per_block( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; @@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } - if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { + if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) { + return false; + } + + if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) { return false; } @@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const } else { if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) { return false; + } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) { + //TODO: truse CDNA2 as CDNA1, tune the perf when CDNA2 is available. + return false; + } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) { + return false; } else if (src1_ncols > 16) { return false; } @@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const switch (type) { case GGML_TYPE_F32: - return ampere_mma_available(cc); + return ampere_mma_available(cc) || amd_mfma_available(cc); case GGML_TYPE_F16: - return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); case GGML_TYPE_BF16: - return ampere_mma_available(cc) || amd_wmma_available(cc); + return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); default: return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index e36730948ff..c2a8d54c95a 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -7,6 +7,31 @@ using namespace ggml_cuda_mma; #define MMF_ROWS_PER_BLOCK 32 +#define MMF_ROWS_PER_BLOCK_CDNA 64 + +static __forceinline__ int64_t mmf_get_max_block_size(int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return 512; + } else { + return 256; + } +} + +static __forceinline__ int mmf_get_padding(int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return 2; + } else { + return 4; + } +} + +static constexpr __device__ int mmf_get_padding() { +#if defined(AMD_MFMA_AVAILABLE) + return 2; +#else + return 4; +#endif // defined(AMD_MFMA_AVAILABLE) +} struct mmf_ids_data { const int32_t * ids_src_compact = nullptr; @@ -29,23 +54,25 @@ static __global__ void mul_mat_f( const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) - // Special case for tf32, just dummy mma layout as wmma doesn't support it. - constexpr bool is_tf32 = std::is_same_v; - constexpr int tile_B_I = is_tf32 ? 8 : 16; - constexpr int tile_C_J = is_tf32 ? 8 : 16; - constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); - typedef tile<16, 8, T, ab_layout> tile_A; - typedef tile tile_B; - typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; + if constexpr (!(std::is_same_v || std::is_same_v) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, get_input_data_layout()> tile_A; + typedef tile<16, 8, T, get_input_data_layout()> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE - if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + if constexpr (!std::is_same_v || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; #else + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<16, 8, T> tile_A; typedef tile<8, 8, T> tile_B; typedef tile<16, 8, float> tile_C; @@ -57,7 +84,7 @@ static __global__ void mul_mat_f( } constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; + constexpr int tile_k_padded = warp_size + mmf_get_padding(); constexpr int ntA = rows_per_block / tile_A::I; constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; @@ -198,7 +225,7 @@ static __global__ void mul_mat_f( } float * buf_iw = (float *) compute_base; - constexpr int kiw = nwarps*rows_per_block + 4; + constexpr int kiw = nwarps*rows_per_block + mmf_get_padding(); if (nwarps > 1) { __syncthreads(); @@ -228,27 +255,34 @@ static __global__ void mul_mat_f( return; } - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); #pragma unroll for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; - sum += buf_iw[j*kiw + i]; + sum[i1] += buf_iw[j*kiw + i]; + } } if constexpr (!has_ids) { - dst[j*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } else { const int slot = (j < cols_per_block) ? slot_map[j] : -1; if (slot >= 0 && (col_base + j) < ncols_dst_total) { - dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } } } -#ifdef VOLTA_MMA_AVAILABLE } -#endif //VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, @@ -256,7 +290,7 @@ static __global__ void mul_mat_f( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); NO_DEVICE_CODE; -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } //This kernel is for larger batch sizes of mul_mat_id @@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids( const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sis1_fd, const uint3 nch_fd) { // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) - // Special case for tf32, just dummy mma layout as wmma doesn't support it. - constexpr bool is_tf32 = std::is_same_v; - constexpr int tile_B_I = is_tf32 ? 8 : 16; - constexpr int tile_C_J = is_tf32 ? 8 : 16; - constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); - typedef tile<16, 8, T, ab_layout> tile_A; - typedef tile tile_B; - typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; + if constexpr (!(std::is_same_v || std::is_same_v) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, get_input_data_layout()> tile_A; + typedef tile<16, 8, T, get_input_data_layout()> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE - if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + if constexpr (!std::is_same_v || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; #else + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<16, 8, T> tile_A; typedef tile<8, 8, T> tile_B; typedef tile<16, 8, float> tile_C; @@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids( constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; + constexpr int tile_k_padded = warp_size + mmf_get_padding(); constexpr int ntA = rows_per_block / tile_A::I; constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; @@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids( } float * buf_iw = (float *) compute_base; - constexpr int kiw = nwarps*rows_per_block + 4; + constexpr int kiw = nwarps*rows_per_block + mmf_get_padding(); if (nwarps > 1) { __syncthreads(); @@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids( return; } - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); #pragma unroll for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; - sum += buf_iw[j*kiw + i]; + sum[i1] += buf_iw[j * kiw + i]; + } } const int global_j = col_base + j; @@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids( const int token = (int) qrm.x; if (token < ncols_dst_total) { const int slot = (int) qrm.y; - dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } } } -#ifdef VOLTA_MMA_AVAILABLE } -#endif // VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); NO_DEVICE_CODE; -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } -template +template static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst, @@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids( const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); - mul_mat_f_ids<<>> + mul_mat_f_ids<<>> (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, @@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids( dim3 block_nums_ids = block_nums; block_nums_ids.y *= col_tiles; - mul_mat_f<<>> + mul_mat_f<<>> (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } else { - mul_mat_f<<>> + mul_mat_f<<>> (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } } -template +template void mul_mat_f_cuda( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, @@ -605,7 +645,7 @@ void mul_mat_f_cuda( int64_t nwarps_best = 1; int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); - int64_t max_block_size = 256; + int64_t max_block_size = mmf_get_max_block_size(cc); for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); if (niter < niter_best) { @@ -614,10 +654,9 @@ void mul_mat_f_cuda( } } - constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; - const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; - const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4; + const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4; const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; @@ -628,56 +667,56 @@ void mul_mat_f_cuda( switch (nwarps_best) { case 1: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 2: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 3: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 4: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 5: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 6: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 7: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 8: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, @@ -691,7 +730,7 @@ void mul_mat_f_cuda( GGML_UNUSED_VARS(nchannels_y); } -template +template static void mul_mat_f_switch_cols_per_block( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, @@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block( switch (ncols_case) { case 1: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 2: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 3: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 4: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 5: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 6: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 7: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 8: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 9: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 10: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 11: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 12: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 13: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 14: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 15: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 16: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; @@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block( } } -#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ - template void mul_mat_f_cuda( \ +template +static void mul_mat_f_switch_rows_per_block( + const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream, const mmf_ids_data * ids_data) { + switch (rows_per_block) { + case MMF_ROWS_PER_BLOCK: { + mul_mat_f_switch_cols_per_block( + x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case MMF_ROWS_PER_BLOCK_CDNA: { + mul_mat_f_switch_cols_per_block( + x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + default: + GGML_ABORT("unsupported rows_per_block: %i", rows_per_block); + } +} + +#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \ + template void mul_mat_f_cuda( \ const T * x, const float * y, const int32_t * ids, float * dst, \ const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ const int64_t stride_col_id, const int64_t stride_row_id, \ @@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ cudaStream_t stream, const mmf_ids_data * ids_data); -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_MUSA) #define DECL_MMF_CASE_EXTERN(ncols_dst) \ - extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ - extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ - extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) #define DECL_MMF_CASE(ncols_dst) \ - DECL_MMF_CASE_HELPER(float, ncols_dst) \ - DECL_MMF_CASE_HELPER(half2, ncols_dst) \ - DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) DECL_MMF_CASE_EXTERN(1); DECL_MMF_CASE_EXTERN(2); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index a382e6a6979..f80f98cda2c 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3697,13 +3697,20 @@ static __global__ void mul_mat_q( tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } - template -static __global__ void mul_mat_q_stream_k_fixup( - const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, - const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, - const int ncols_max) { +static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, + const int32_t * expert_bounds, + float * __restrict__ dst, + const float * __restrict__ tmp_last_tile, + const int ncols_x, + const int nrows_x, + const int ncols_dst, + const size_t stride_col_dst, + const int nchannels_y, + const size_t stride_channel_dst, + const int nsamples_y, + const size_t stride_sample_dst, + const int ncols_max) { constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int ITER_K = get_iter_k(type); diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 32948e4d7a1..d9147202429 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -4,26 +4,48 @@ #include "mmvf.cuh" #include "convert.cuh" -template +template static __global__ void mul_mat_vec_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, - const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, + const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ids_stride) { const int row = blockIdx.x; + // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); - const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; + const int tid = threadIdx.x; + + int token_idx; + int channel_x; + int channel_y; + int sample_dst; + + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + token_idx = ids ? blockIdx.z : 0; + channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio); + channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst; + sample_dst = ids ? 0 : blockIdx.z; + } + const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); const int sample_y = sample_dst; - const int tid = threadIdx.x; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y2*2; + dst += token_idx*stride_col_dst; + } bool use_gate = false; bool use_bias = false; @@ -56,8 +78,10 @@ static __global__ void mul_mat_vec_f( if (use_gate) { gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; } + + const int channel_bias = ids ? channel_x : channel_dst; + if constexpr (has_fusion) { - const int channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; } @@ -349,36 +373,36 @@ static __global__ void mul_mat_vec_f( } } -template +template static void mul_mat_vec_f_switch_fusion( const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, - const int64_t ncols, const int64_t nrows, + const int64_t ncols, const uint3 nchannels_y, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) { + const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) { const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; } } GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } -template +template void launch_mul_mat_vec_f_cuda( const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, @@ -386,12 +410,13 @@ void launch_mul_mat_vec_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); @@ -415,56 +440,56 @@ void launch_mul_mat_vec_f_cuda( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0); - const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); + const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 64: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 96: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 128: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 160: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 192: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 224: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 256: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; default: { GGML_ABORT("fatal error"); @@ -480,55 +505,88 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + const int64_t ids_stride, cudaStream_t stream) { + + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + + if (has_ids) { + // Single-token MUL_MAT_ID path + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + switch (ncols_dst) { case 1: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 2: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 3: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 4: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 5: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 6: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 7: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 8: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; default: GGML_ABORT("fatal error"); @@ -544,21 +602,21 @@ static void mul_mat_vec_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - enum ggml_prec prec, cudaStream_t stream) { + const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) { if constexpr(std::is_same_v) { if (prec == GGML_PREC_DEFAULT) { mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); return; } } mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); } void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, @@ -573,7 +631,7 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const size_t ts_src1 = ggml_type_size(src1->type); const size_t ts_dst = ggml_type_size(dst->type); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE); GGML_ASSERT(ne13 == ne3); GGML_ASSERT( nb00 == ts_src0); @@ -626,29 +684,31 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t nchannels_y = ids ? ne11 : ne12; const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; - GGML_ASSERT(!ids || ncols_dst == 1); + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -695,19 +755,19 @@ void ggml_cuda_op_mul_mat_vec_f( const float * src0_d = (const float *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh index a09fbdc7202..a50f7c02180 100644 --- a/ggml/src/ggml-cuda/mmvf.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -1,5 +1,7 @@ #include "common.cuh" +#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels. + void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index d671551c171..ce25ccf427c 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -137,15 +137,15 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -// tell the compiler to use as many registers as it wants, see nwarps definition below -template +template __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, - const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const uint32_t ids_stride) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; @@ -162,11 +162,25 @@ static __global__ void mul_mat_vec_q( const int blocks_per_row_x = ncols_x / qk; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; - // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. const uint32_t channel_dst = blockIdx.y; - const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - const uint32_t sample_dst = blockIdx.z; + + uint32_t token_idx = 0; + uint32_t channel_x; + uint32_t channel_y; + uint32_t sample_dst; + + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; + } + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; @@ -188,11 +202,11 @@ static __global__ void mul_mat_vec_q( active_glu = fusion.glu_op; } - const uint32_t channel_bias = ids ? channel_x : channel_dst; float x_biases[ncols_dst] = { 0.0f }; float gate_biases[ncols_dst] = { 0.0f }; if constexpr (has_fusion) { + const uint32_t channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; // 1. Hide latency by prefetching bias and gate here @@ -222,6 +236,9 @@ static __global__ void mul_mat_vec_q( float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y; + } const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -275,6 +292,10 @@ static __global__ void mul_mat_vec_q( dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; + if constexpr (is_multi_token_id) { + dst += token_idx*stride_col_dst; + } + // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_dst; ++j) { @@ -335,40 +356,41 @@ static __global__ void mul_mat_vec_q( } static std::pair calc_launch_params( - const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, + const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, const int warp_size, const mmvq_parameter_table_id table_id) { const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); - const dim3 block_nums(nblocks, nchannels_y, nsamples_y); + const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, - const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) { + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, + const uint32_t ids_stride, cudaStream_t stream) { const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; } } GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } template @@ -379,7 +401,7 @@ static void mul_mat_vec_q_switch_ncols_dst( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); @@ -393,8 +415,19 @@ static void mul_mat_vec_q_switch_ncols_dst( const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + return; + } - GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; @@ -402,7 +435,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 2: { constexpr int c_ncols_dst = 2; @@ -410,7 +443,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 3: { constexpr int c_ncols_dst = 3; @@ -418,7 +451,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 4: { constexpr int c_ncols_dst = 4; @@ -426,7 +459,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 5: { constexpr int c_ncols_dst = 5; @@ -434,7 +467,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 6: { constexpr int c_ncols_dst = 6; @@ -442,7 +475,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 7: { constexpr int c_ncols_dst = 7; @@ -450,7 +483,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 8: { constexpr int c_ncols_dst = 8; @@ -458,7 +491,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; default: GGML_ABORT("fatal error"); @@ -474,127 +507,127 @@ static void mul_mat_vec_q_switch_type( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { switch (type_x) { case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_MXFP4: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q3_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q6_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_M: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_NL: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_XS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; default: GGML_ABORT("fatal error"); @@ -622,7 +655,7 @@ void ggml_cuda_mul_mat_vec_q( GGML_ASSERT( nb0 == ts_dst); GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE); const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; @@ -693,11 +726,13 @@ void ggml_cuda_mul_mat_vec_q( const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + mul_mat_vec_q_switch_type( src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, stream); + ne03, ne3, s03, s13, s3, ids_stride, stream); } void ggml_cuda_op_mul_mat_vec_q( @@ -726,7 +761,7 @@ void ggml_cuda_op_mul_mat_vec_q( ggml_cuda_mm_fusion_args_device fusion_local{}; mul_mat_vec_q_switch_type( src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream); GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); } diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 88ed79111a1..45a49a5dc2a 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -43,10 +43,15 @@ static __device__ void rope_yarn( template static __global__ void rope_norm(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x, const int set_rows_stride) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - - int idst = row_dst * ne0 + i0; - const int ix = channel_x*s2 + row_x*s1 + i0; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0; + idst += row_indices[i2] * set_rows_stride; } const auto & store_coaelsced = [&](float x0, float x1) { @@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x, template static __global__ void rope_neox(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -126,23 +136,24 @@ static __global__ void rope_neox(const T * x, const int set_rows_stride) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - int idst = row_dst * ne0 + i0 / 2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0 / 2; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0 / 2; + idst += row_indices[i2] * set_rows_stride; } if (i0 >= n_dims) { @@ -152,7 +163,7 @@ static __global__ void rope_neox(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -168,24 +179,42 @@ static __global__ void rope_neox(const T * x, dst[idst + n_dims / 2] = ggml_cuda_cast(x0 * sin_theta + x1 * cos_theta); } -template -static __global__ void rope_multi( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, - const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { - const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - - if (i0 >= ne0) { +template +static __global__ void rope_multi(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope) { + const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y); + + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; @@ -200,27 +229,24 @@ static __global__ void rope_multi( float theta_base = 0.0; if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); } else { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } else { if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } @@ -238,37 +264,53 @@ static __global__ void rope_multi( dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template -static __global__ void rope_vision( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, - const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * freq_factors, const mrope_sections sections) { +template +static __global__ void rope_vision(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; const int sect_dims = sections.v[0] + sections.v[1]; - const int sec_w = sections.v[1] + sections.v[0]; - const int sector = (i0 / 2) % sect_dims; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; if (sector < sections.v[0]) { const int p = sector; - theta_base = pos[channel_x]*powf(theta_scale, p); - } - else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2] * powf(theta_scale, p); + } else if (sector >= sections.v[0] && sector < sec_w) { const int p = sector - sections.v[0]; - theta_base = pos[channel_x + ne2]*powf(theta_scale, p); + theta_base = pos[i2 + ne02] * powf(theta_scale, p); } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -288,10 +330,15 @@ static __global__ void rope_vision( template static void rope_norm_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -304,31 +351,36 @@ static void rope_norm_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } template static void rope_neox_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -341,55 +393,92 @@ static void rope_neox_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_neox<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_neox<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } -template -static void rope_multi_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template +static void rope_multi_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } } -template -static void rope_vision_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template +static void rope_vision_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq) // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE); @@ -398,11 +487,11 @@ static void rope_vision_cuda( if (freq_factors == nullptr) { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } } @@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t s03 = src0->nb[3] / ggml_type_size(src0->type); + + const size_t s1 = dst->nb[1] / ggml_type_size(dst->type); + const size_t s2 = dst->nb[2] / ggml_type_size(dst->type); + const size_t s3 = dst->nb[3] / ggml_type_size(dst->type); //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, // compute if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } } else if (is_mrope && !is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_multi_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_multi_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else { GGML_ABORT("fatal error"); } } else if (is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_vision_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_vision_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu new file mode 100644 index 00000000000..1f554d81e5e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 2074e954a32..517993cb068 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu new file mode 100644 index 00000000000..264751d65ec --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 24c64cf000f..97b19c67ade 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 1ada657f194..989626dfa5e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 86d4ffae27c..173de7aac7d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index a5602da02bb..e382df1ae20 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -71,7 +71,7 @@ def get_short_name(long_quant_name): f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v)) for ncols in [8, 16, 32, 64]: - for ncols2 in [1, 2, 4, 8, 16]: + for ncols2 in [1, 2, 4, 8, 16, 32]: if ncols2 > ncols: continue ncols1 = ncols // ncols2 @@ -83,9 +83,9 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue - if head_size_kq != 576 and ncols2 == 16: + if head_size_kq != 576 and ncols2 in (16, 32): continue - if head_size_kq == 576 and ncols2 != 16: + if head_size_kq == 576 and ncols2 not in (4, 16, 32): continue head_size_v = head_size_kq if head_size_kq != 576 else 512 f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 48e569efa0d..08a88990dde 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -5,6 +5,13 @@ #include #include +// Kernel config struct - passed by value to CUDA kernel +struct topk_moe_config { + bool use_sigmoid; + bool with_norm; + bool delayed_softmax; +}; + // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. template __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { @@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in } } +template +__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY; + } +} + /* This kernel does the following: 1. optionally softmax over the logits per token [n_experts, n_tokens] @@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models */ -template -__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, - float * weights, - int32_t * ids, - const int n_rows, - const int n_expert_used, - const float clamp_val) { +template +__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, + float * weights, + int32_t * ids, + float * bias, + const int n_rows, + const int n_expert_used, + const float clamp_val, + const float scale_val, + const topk_moe_config config) { const int row = blockIdx.x * blockDim.y + threadIdx.y; if (row >= n_rows) { return; @@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float wt[experts_per_thread]; + // Initialize all slots to -INFINITY +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + wt[i] = -INFINITY; + } + #pragma unroll for (int i = 0; i < n_experts; i += WARP_SIZE) { const int expert = i + threadIdx.x; wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY; } - if constexpr (!delayed_softmax) { - softmax_warp_inplace(wt, n_experts, threadIdx.x); + if (!config.delayed_softmax) { + if (config.use_sigmoid) { + sigmoid_warp_inplace(wt, n_experts, threadIdx.x); + } else { + softmax_warp_inplace(wt, n_experts, threadIdx.x); + } + } + + // selection_wt is only needed when bias is present (selection uses wt + bias) + // when no bias, we use wt directly for both selection and weight values + float selection_wt[has_bias ? experts_per_thread : 1]; + + if constexpr (has_bias) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + selection_wt[i] = -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_experts; i += WARP_SIZE) { + const int expert = i + threadIdx.x; + selection_wt[i / WARP_SIZE] = + (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY; + } } //at this point, each thread holds either a portion of the softmax distribution @@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float max_val = wt[0]; int max_expert = threadIdx.x; + if constexpr (has_bias) { + float max_val_s = selection_wt[0]; + #pragma unroll - for (int i = 1; i < experts_per_thread; i++) { - const int expert = threadIdx.x + i * WARP_SIZE; - if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { - max_val = wt[i]; - max_expert = expert; + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) { + max_val = wt[i]; + max_val_s = selection_wt[i]; + max_expert = expert; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) { + max_val = val; + max_val_s = val_s; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + selection_wt[max_expert / WARP_SIZE] = -INFINITY; + } + } else { +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { + max_val = wt[i]; + max_expert = expert; + } } - } #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { - const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); - const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); - if (val > max_val || (val == max_val && expert < max_expert)) { - max_val = val; - max_expert = expert; + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val > max_val || (val == max_val && expert < max_expert)) { + max_val = val; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + wt[max_expert / WARP_SIZE] = -INFINITY; } } @@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { - wt[max_expert / WARP_SIZE] = -INFINITY; - ids[k] = max_expert; - if constexpr (with_norm) { + if (config.with_norm) { wt_sum += max_val; } } } - if constexpr (with_norm) { + if (config.with_norm) { wt_sum = warp_reduce_sum(wt_sum); wt_sum = max(wt_sum, clamp_val); const float inv_sum = 1.0f / wt_sum; @@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } } - if constexpr (delayed_softmax) { + if (config.delayed_softmax) { softmax_warp_inplace(output_weights, n_expert_used, threadIdx.x); } @@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * for (int i = 0; i < experts_per_thread; i++) { const int idx = i * WARP_SIZE + threadIdx.x; if (idx < n_expert_used) { - weights[idx] = output_weights[i]; + weights[idx] = output_weights[i] * scale_val; } } - - if (!with_norm) { - GGML_UNUSED(clamp_val); - } } -template +template static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, const float * logits, float * weights, int32_t * ids, + float * bias, const int n_rows, const int n_expert, const int n_expert_used, - const float clamp_val) { - static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization"); + const float clamp_val, + const float scale_val, + const topk_moe_config config) { + GGML_ASSERT(!(config.with_norm && config.delayed_softmax) && + "delayed softmax is not supported with weight normalization"); const int rows_per_block = 4; dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1); @@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, switch (n_expert) { case 1: - topk_moe_cuda<1, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<1, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 2: - topk_moe_cuda<2, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<2, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 4: - topk_moe_cuda<4, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<4, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 8: - topk_moe_cuda<8, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<8, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 16: - topk_moe_cuda<16, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<16, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 32: - topk_moe_cuda<32, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<32, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 64: - topk_moe_cuda<64, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<64, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 128: - topk_moe_cuda<128, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<128, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 256: - topk_moe_cuda<256, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<256, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 512: - topk_moe_cuda<512, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<512, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 576: + topk_moe_cuda<576, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; default: GGML_ASSERT(false && "fatal error"); @@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, } } -void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, - const ggml_tensor * logits, - ggml_tensor * weights, - ggml_tensor * ids, - const bool with_norm, - const bool delayed_softmax, - ggml_tensor * clamp) { +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const ggml_tensor * clamp, + const ggml_tensor * scale, + const ggml_tensor * bias, + const ggml_cuda_topk_moe_args & args) { GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const float * logits_d = (const float *) logits->data; float * weights_d = (float *) weights->data; int32_t * ids_d = (int32_t *) ids->data; + float * bias_d = bias ? (float *) bias->data : nullptr; + + float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f; GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); const int n_expert_used = weights->ne[1]; + const bool with_norm = clamp != nullptr; + float clamp_val = -INFINITY; - if (with_norm) { - if (clamp) { - clamp_val = ggml_get_op_params_f32(clamp, 0); - } - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val); + if (clamp) { + clamp_val = ggml_get_op_params_f32(clamp, 0); + } + + topk_moe_config config; + config.use_sigmoid = args.sigmoid; + config.with_norm = with_norm; + config.delayed_softmax = args.delayed_softmax; + + if (bias) { + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); } else { - GGML_ASSERT(clamp == nullptr); - if (delayed_softmax) { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, - clamp_val); - } else { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, - clamp_val); - } + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); } } -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op, const ggml_tensor * weights, - const ggml_tensor * get_rows, - const ggml_tensor * argsort, - const ggml_tensor * clamp, - int n_expert) { - ggml_tensor * probs = get_rows->src[0]; - if (probs->op != GGML_OP_RESHAPE) { + const ggml_tensor * logits, + const ggml_tensor * ids) { + const int n_expert = ids->nb[1] / ids->nb[0]; + if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) { return false; } - probs = probs->src[0]; - ggml_tensor * selection_probs = argsort->src[0]; - if (probs != selection_probs) { + if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) { return false; } - float scale = 1.0f; - float max_bias = 0.0f; + if (gating_op->op == GGML_OP_SOFT_MAX) { + const ggml_tensor * softmax = gating_op; + float scale = 1.0f; + float max_bias = 0.0f; - memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); - memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); + memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); - if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { - return false; - } - - if (scale != 1.0f || max_bias != 0.0f) { - return false; - } - - // don't fuse when masks or sinks are present - if (softmax->src[1] || softmax->src[2]) { - return false; - } + if (!ggml_is_contiguous(softmax->src[0])) { + return false; + } - // n_expert must be a power of 2 - if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) { - return false; - } + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } - if (clamp) { - if (clamp->op != GGML_OP_CLAMP) { + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { return false; } - float max_val = ggml_get_op_params_f32(clamp, 1); + } else if (gating_op->op == GGML_OP_UNARY) { + ggml_unary_op op = ggml_get_unary_op(gating_op); - if (max_val != INFINITY) { + if (op != GGML_UNARY_OP_SIGMOID) { return false; } } - return true; } - -std::initializer_list ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) { - static std::initializer_list norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, - GGML_OP_RESHAPE }; - - static std::initializer_list no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS }; - - static std::initializer_list delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW, - GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; - - GGML_ASSERT(!norm || !delayed_softmax); - - if (delayed_softmax) { - return delayed_softmax_ops; - } - - if (norm) { - return norm_ops; - } - - return no_norm_ops; -} diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh index 6b6c13c5870..243dc2f1c41 100644 --- a/ggml/src/ggml-cuda/topk-moe.cuh +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -3,19 +3,25 @@ #include -void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, - const ggml_tensor * logits, - ggml_tensor * weights, - ggml_tensor * ids, - const bool with_norm, - const bool delayed_softmax = false, - ggml_tensor * weight_clamp = nullptr); +struct ggml_cuda_topk_moe_args { + bool sigmoid{}; + bool softmax{}; + bool delayed_softmax{}; + bool prob_bias{}; + bool norm{}; + bool scale{}; +}; -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, - const ggml_tensor * weights, - const ggml_tensor * get_rows, - const ggml_tensor * argsort, - const ggml_tensor * clamp, - int n_expert); +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const ggml_tensor * clamp, + const ggml_tensor * scale, + const ggml_tensor * bias, + const ggml_cuda_topk_moe_args & args); -std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false); +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op, + const ggml_tensor * weights, + const ggml_tensor * logits, + const ggml_tensor * ids); diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index d58e2878237..f3a583543c6 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -1,7 +1,29 @@ +file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) +file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) + +if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}") + message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.") +endif() + +if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") + message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json") + file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH) + string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path") + message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}") + set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}") + file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) + if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") + message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.") + endif() +endif() + +message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels") + include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include(ExternalProject) option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") add_library(htp_iface OBJECT @@ -25,56 +47,71 @@ else() target_link_options(htp_iface PUBLIC -ldl) endif() -link_custom_library(htp_iface cdsprpc) -link_custom_library(htp_iface rpcmem) - set(TARGET_NAME ggml-hexagon) ggml_add_backend_library(${TARGET_NAME} - ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h) + ggml-hexagon.cpp + htp-drv.cpp + htp-drv.h + libdl.h + ../../include/ggml-hexagon.h) target_link_libraries(${TARGET_NAME} PRIVATE htp_iface) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR}) -# Build HTP bits -set(HTP_CMAKE_ARGS - -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake - -DCMAKE_BUILD_TYPE=Release - -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} - -DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT} - -DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT} - -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} - -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) - -ExternalProject_Add(htp-v68 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68") - -ExternalProject_Add(htp-v69 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69") - -ExternalProject_Add(htp-v73 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73") - -ExternalProject_Add(htp-v75 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75") - -ExternalProject_Add(htp-v79 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79") - -ExternalProject_Add(htp-v81 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81") +# Build HTP skels +set(HTP_SKELS) +function(build_htp_skel V) + ExternalProject_Add(htp-${V} + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so + CMAKE_ARGS + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake + -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} + -DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT} + -DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT} + -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} + -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE} + -DDSP_VERSION=${V} + -DPREBUILT_LIB_DIR="toolv19_${V}") + list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so) + set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE) +endfunction() + +build_htp_skel(v68) +build_htp_skel(v69) +build_htp_skel(v73) +build_htp_skel(v75) +build_htp_skel(v79) +build_htp_skel(v81) # Install Hexagon skels required at runtime -install(FILES - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so - TYPE LIB) +install(FILES ${HTP_SKELS} TYPE LIB) + +if (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT) + file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/arm64" WINSDK_BIN0_ARM64) + file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/x86" WINSDK_BIN0_X86) + file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/arm64" WINSDK_BIN1_ARM64) + file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/x86" WINSDK_BIN1_X86) + + set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86}) + + find_program(INF2CAT NAMES inf2cat.exe PATHS ${WINSDK_PATHS} REQUIRED) + find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED) + + message(STATUS "hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels") + + set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat) + add_custom_target(libggml-htp-cat + BYPRODUCTS ${LIBGGML_HTP_CAT} + DEPENDS libggml-htp.inf ${HTP_SKELS} + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64 + COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT} + COMMENT "generating and signing libggml-htp.cat file" + VERBATIM + ) + + add_dependencies(${TARGET_NAME} libggml-htp-cat) + install(FILES ${LIBGGML_HTP_CAT} TYPE LIB) +endif() diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 5b835c11c72..4f0a1620fbf 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -14,9 +14,6 @@ #ifdef _WIN32 # include -# ifndef _WINDOWS -# define _WINDOWS -# endif #else # include # include @@ -25,8 +22,6 @@ #pragma clang diagnostic ignored "-Wnested-anon-types" #pragma clang diagnostic ignored "-Wgnu-anonymous-struct" -#include "htp-utils.h" - #include #include #include @@ -40,6 +35,7 @@ #include "op-desc.h" #include "htp-msg.h" #include "htp_iface.h" +#include "htp-drv.h" static size_t opt_ndev = 1; static size_t opt_nhvx = 0; // use all @@ -150,9 +146,9 @@ void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_ 0, // flags - the framework will autoset this n_bufs, // number of buffers bufs, // buffer references - sizeof(req), + sizeof(req), // Message length (const uint8_t *) &req, // Message - 1000000 // Timeout + DSPQUEUE_TIMEOUT // Timeout ); if (err != 0) { @@ -182,13 +178,13 @@ void ggml_hexagon_session::flush() { // Read response packet from queue int err = dspqueue_read(q, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(rsp), // Max message length - &rsp_size, // Message length - (uint8_t *) &rsp, - 1000000); // Timeout + HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references + &n_bufs, // Number of buffer references + bufs, // Buffer references + sizeof(rsp), // Max message length + &rsp_size, // Message length + (uint8_t *) &rsp, // Message + DSPQUEUE_TIMEOUT); // Timeout if (err == AEE_EEXPIRED) { // TODO: might need to bail out if the HTP is stuck on something @@ -269,13 +265,7 @@ struct ggml_backend_hexagon_buffer_context { ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) { size += 4 * 1024; // extra page for padding - if (rpcmem_alloc2) { - this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); - } else { - GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str()); - this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); - } - + this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); if (!this->base) { GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size); throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)"); @@ -2461,12 +2451,12 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) { } static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { - return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type)); + return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type)); } static inline bool is_compute_op(ggml_tensor *node) { - return !(ggml_op_is_empty(node->op) || ggml_is_empty(node)); + return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); } // scan the graph and figure out last compute op index @@ -2488,7 +2478,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg const int last = last_compute_op(graph); - const struct ggml_tensor * prev_quant_op = nullptr; // prev executed op with quantizer + const struct ggml_tensor * prev_op = nullptr; // prev executed op for (int i = 0; i < graph->n_nodes; ++i) { ggml_tensor * node = graph->nodes[i]; @@ -2497,17 +2487,15 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg continue; } - if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { - continue; - } - uint32_t flags = 0; // skip quantizer if src1 is reused - if (op_reuse_src1(node, prev_quant_op)) { + if (op_reuse_src1(node, prev_op)) { flags |= HTP_OPFLAGS_SKIP_QUANTIZE; } + prev_op = node; + // ask for early notification for the last Op if (i == last) { flags |= HTP_OPFLAGS_EARLY_WAKEUP; @@ -2520,7 +2508,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg } else { ggml_hexagon_dispatch_op>(sess, node, flags); } - prev_quant_op = node; break; case GGML_OP_MUL_MAT_ID: if (ggml_is_quantized(node->src[0]->type)) { @@ -2528,7 +2515,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg } else { ggml_hexagon_dispatch_op>(sess, node, flags); } - prev_quant_op = node; break; case GGML_OP_MUL: case GGML_OP_ADD: @@ -2670,7 +2656,7 @@ static std::vector ggml_hexagon_graph_optimize_reorder(const std::vectorcontext = new ggml_hexagon_registry(reg); HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req), @@ -3180,6 +3170,11 @@ ggml_backend_reg_t ggml_backend_hexagon_reg(void) { static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { + auto nErr = htpdrv_init(); + if (nErr != AEE_SUCCESS) { + return NULL; + } + ggml_hexagon_init(®); } diff --git a/ggml/src/ggml-hexagon/htp-drv.cpp b/ggml/src/ggml-hexagon/htp-drv.cpp new file mode 100644 index 00000000000..2530bb06d6c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-drv.cpp @@ -0,0 +1,418 @@ +// sample drv interface + +#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wsign-compare" + +#include +#include +#include +#include +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +#endif +#include "ggml-impl.h" +#include "htp-drv.h" +#include "libdl.h" + +#include + +// +// Driver API types +// + +typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size); +typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size); +typedef void (*rpcmem_free_pfn_t)(void * po); +typedef int (*rpcmem_to_fd_pfn_t)(void * po); + +typedef AEEResult (*dspqueue_create_pfn_t)(int domain, + uint32_t flags, + uint32_t req_queue_size, + uint32_t resp_queue_size, + dspqueue_callback_t packet_callback, + dspqueue_callback_t error_callback, + void * callback_context, + dspqueue_t * queue); +typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue); +typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id); +typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags, + uint32_t num_buffers, + struct dspqueue_buffer *buffers, + uint32_t message_length, + const uint8_t *message, + uint32_t timeout_us); +typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags, + uint32_t max_buffers, uint32_t *num_buffers, + struct dspqueue_buffer *buffers, + uint32_t max_message_length, + uint32_t *message_length, uint8_t *message, + uint32_t timeout_us); + +typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags); +typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length); + +typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph); +typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra); +typedef int (*remote_handle64_close_pfn_t)(remote_handle h); +typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen); +typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen); +typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen); + +// +// Driver API pfns +// + +rpcmem_alloc_pfn_t rpcmem_alloc_pfn = nullptr; +rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr; +rpcmem_free_pfn_t rpcmem_free_pfn = nullptr; +rpcmem_to_fd_pfn_t rpcmem_to_fd_pfn = nullptr; + +fastrpc_mmap_pfn_t fastrpc_mmap_pfn = nullptr; +fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr; + +dspqueue_create_pfn_t dspqueue_create_pfn = nullptr; +dspqueue_close_pfn_t dspqueue_close_pfn = nullptr; +dspqueue_export_pfn_t dspqueue_export_pfn = nullptr; +dspqueue_write_pfn_t dspqueue_write_pfn = nullptr; +dspqueue_read_pfn_t dspqueue_read_pfn = nullptr; + +remote_handle64_open_pfn_t remote_handle64_open_pfn = nullptr; +remote_handle64_invoke_pfn_t remote_handle64_invoke_pfn = nullptr; +remote_handle64_close_pfn_t remote_handle64_close_pfn = nullptr; +remote_handle_control_pfn_t remote_handle_control_pfn = nullptr; +remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr; +remote_session_control_pfn_t remote_session_control_pfn = nullptr; + +// +// Driver API +// + +void * rpcmem_alloc(int heapid, uint32_t flags, int size) { + return rpcmem_alloc_pfn(heapid, flags, size); +} + +void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) { + if (rpcmem_alloc2_pfn) { + return rpcmem_alloc2_pfn(heapid, flags, size); + } else { + GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n"); + return rpcmem_alloc_pfn(heapid, flags, size); + } +} + +void rpcmem_free(void * po) { + return rpcmem_free_pfn(po); +} + +int rpcmem_to_fd(void * po) { + return rpcmem_to_fd_pfn(po); +} + +HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) { + return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags); +} + +HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) { + return fastrpc_munmap_pfn(domain, fd, addr, length); +} + +AEEResult dspqueue_create(int domain, + uint32_t flags, + uint32_t req_queue_size, + uint32_t resp_queue_size, + dspqueue_callback_t packet_callback, + dspqueue_callback_t error_callback, + void * callback_context, + dspqueue_t * queue) { + return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback, + callback_context, queue); +} + +AEEResult dspqueue_close(dspqueue_t queue) { + return dspqueue_close_pfn(queue); +} + +AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) { + return dspqueue_export_pfn(queue, queue_id); +} + +AEEResult dspqueue_write(dspqueue_t queue, + uint32_t flags, + uint32_t num_buffers, + struct dspqueue_buffer * buffers, + uint32_t message_length, + const uint8_t * message, + uint32_t timeout_us) { + return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us); +} + +AEEResult dspqueue_read(dspqueue_t queue, + uint32_t * flags, + uint32_t max_buffers, + uint32_t * num_buffers, + struct dspqueue_buffer * buffers, + uint32_t max_message_length, + uint32_t * message_length, + uint8_t * message, + uint32_t timeout_us) { + return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length, + message, timeout_us); +} + +HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) { + return remote_handle64_open_pfn(name, ph); +} + +HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) { + return remote_handle64_invoke_pfn(h, dwScalars, pra); +} + +HTPDRV_API int remote_handle64_close(remote_handle64 h) { + return remote_handle64_close_pfn(h); +} + +HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) { + return remote_handle_control_pfn(req, data, datalen); +} + +HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) { + return remote_handle64_control_pfn(h, req, data, datalen); +} + +HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) { + return remote_session_control_pfn(req, data, datalen); +} + +#ifdef _WIN32 + +static std::string wstr_to_str(std::wstring_view wstr) { + std::string result; + if (wstr.empty()) { + return result; + } + auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, + wstr.data(), (int) wstr.size(), + nullptr, 0, nullptr, nullptr); + if (bytes_needed == 0) { + GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError()); + throw std::runtime_error("Invalid wstring input"); + } + + result.resize(bytes_needed, '\0'); + int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, + wstr.data(), (int) wstr.size(), + result.data(), bytes_needed, + nullptr, nullptr); + if (bytes_written == 0) { + GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError()); + throw std::runtime_error("Wstring conversion failed"); + } + return result; +} + +static std::string get_driver_path() { + std::wstring serviceName = L"qcnspmcdm"; + std::string result; + + // Get a handle to the SCM database. + SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ); + if (nullptr == schSCManager) { + GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError()); + return result; + } + + // Get a handle to the service. + SC_HANDLE schService = OpenServiceW(schSCManager, // SCM database + serviceName.c_str(), // name of service + SERVICE_QUERY_CONFIG); // need query config access + + if (nullptr == schService) { + GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError()); + CloseServiceHandle(schSCManager); + return result; + } + + // Store the size of buffer used as an output. + DWORD bufferSize; + if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) && + (GetLastError() != ERROR_INSUFFICIENT_BUFFER)) { + GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError()); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return result; + } + // Get the configuration of the service. + LPQUERY_SERVICE_CONFIGW serviceConfig = + static_cast(LocalAlloc(LMEM_FIXED, bufferSize)); + if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) { + fprintf(stderr, "ggml-hex: Failed to query service config. Error: %lu\n", GetLastError()); + LocalFree(serviceConfig); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return result; + } + + // Read the driver file path get its parent directory + std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName); + driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\")); + + // Clean up resources + LocalFree(serviceConfig); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + + // Driver path would contain invalid path string, like: + // \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37 + // "\SystemRoot" should be replace with a correct one (e.g. C:\Windows) + const std::wstring systemRootPlaceholder = L"\\SystemRoot"; + if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) { + GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n"); + return result; + } + + // Replace \SystemRoot with an absolute path from system ENV windir + const std::wstring systemRootEnv = L"windir"; + + // Query the number of wide charactors this variable requires + DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0); + if (numWords == 0) { + GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n"); + return result; + } + + // Query the actual system root name from environment variable + std::vector systemRoot(numWords + 1); + numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1); + if (numWords == 0) { + GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n"); + return result; + } + driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data())); + + return wstr_to_str(driverPath); +} + +#endif + +using dl_handle_ptr = std::unique_ptr; + +int htpdrv_init() { + static dl_handle_ptr lib_cdsp_rpc_handle = nullptr; + static bool initialized = false; +#ifdef _WIN32 + std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll"; +#else + std::string drv_path = "libcdsprpc.so"; +#endif + if (initialized) { + GGML_LOG_INFO("ggml-hex: Driver already loaded\n"); + return AEE_SUCCESS; + } + GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str()); + + fs::path path{ drv_path.c_str() }; + dl_handle_ptr handle { dl_load_library(path) }; + if (!handle) { + GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error()); + return AEE_EUNABLETOLOAD; + } + +#define dlsym(drv, type, pfn, symbol, ignore) \ + do { \ + pfn = (type) dl_get_sym(drv, #symbol); \ + if (!ignore && nullptr == pfn) { \ + GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol); \ + return AEE_EUNABLETOLOAD; \ + } \ + } while (0) + + dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false); + dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true); + dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false); + dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false); + dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false); + dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false); + dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false); + dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false); + dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false); + dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false); + dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false); + dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false); + dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false); + dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false); + dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false); + dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false); + dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false); + + lib_cdsp_rpc_handle = std::move(handle); + initialized = true; + + return AEE_SUCCESS; +} + +domain * get_domain(int domain_id) { + int i = 0; + int size = sizeof(supported_domains) / sizeof(domain); + + for (i = 0; i < size; i++) { + if (supported_domains[i].id == domain_id) { + return &supported_domains[i]; + } + } + + return NULL; +} + +int get_hex_arch_ver(int domain, int * arch) { + if (!remote_handle_control_pfn) { + GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n"); + return AEE_EUNSUPPORTEDAPI; + } + + struct remote_dsp_capability arch_ver; + arch_ver.domain = (uint32_t) domain; + arch_ver.attribute_ID = ARCH_VER; + arch_ver.capability = (uint32_t) 0; + + int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver)); + if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) { + GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n"); + return AEE_EUNSUPPORTEDAPI; + } + + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err); + return err; + } + + switch (arch_ver.capability & 0xff) { + case 0x68: + *arch = 68; + return 0; + case 0x69: + *arch = 69; + return 0; + case 0x73: + *arch = 73; + return 0; + case 0x75: + *arch = 75; + return 0; + case 0x79: + *arch = 79; + return 0; + case 0x81: + *arch = 81; + return 0; + } + return -1; +} diff --git a/ggml/src/ggml-hexagon/htp-drv.h b/ggml/src/ggml-hexagon/htp-drv.h new file mode 100644 index 00000000000..6eba7ba17d8 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-drv.h @@ -0,0 +1,121 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef _WIN32 +# pragma clang diagnostic ignored "-Wignored-attributes" +#endif + +#include +#include +#include +#include + +#if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BACKEND_BUILD +# define HTPDRV_API __declspec(dllexport) extern +# else +# define HTPDRV_API __declspec(dllimport) extern +# endif +#else +# define HTPDRV_API __attribute__ ((visibility ("default"))) extern +#endif + +/* Offset to differentiate HLOS and Hexagon error codes. + Stores the value of AEE_EOFFSET for Hexagon. */ +#ifndef DSP_OFFSET +# define DSP_OFFSET 0x80000400 +#endif + +/* Errno for connection reset by peer. */ +#ifndef ECONNRESET +# ifdef __hexagon__ +# define ECONNRESET 104 +# endif +#endif + +/* Abstraction of different OS specific sleep APIs. + SLEEP accepts input in seconds. */ +#ifndef SLEEP +# ifdef __hexagon__ +# define SLEEP(x) \ + { /* Do nothing for simulator. */ \ + } +# else +# ifdef _WIN32 +# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */ +# else +# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */ +# endif +# endif +#endif + +/* Include windows specific header files. */ +#ifdef _WIN32 +# include +# include +# define _CRT_SECURE_NO_WARNINGS 1 +# define _WINSOCK_DEPRECATED_NO_WARNINGS 1 +#endif + +/* Includes and defines for all HLOS except windows */ +#if !defined(__hexagon__) && !defined(_WIN32) +# include "unistd.h" + +# include +#endif + +/* Includes and defines for Hexagon and all HLOS except Windows. */ +#if !defined(_WIN32) +/* Weak reference to remote symbol for compilation. */ +# pragma weak remote_session_control +# pragma weak remote_handle_control +# pragma weak remote_handle64_control +# pragma weak fastrpc_mmap +# pragma weak fastrpc_munmap +# pragma weak rpcmem_alloc2 +#endif + +#if !defined(_WIN32) +# pragma weak remote_system_request +#endif + +#ifdef _WIN32 +# define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE +#else +# define DSPQUEUE_TIMEOUT 1000000 +#endif + +/** + * htpdrv_init API: driver interface entry point + * + * @return Return AEE error codes as defined in Hexagon SDK. + */ +HTPDRV_API int htpdrv_init(void); + +/** + * get_domain API: get domain struct from domain value. + * + * @param[in] domain value of a domain + * @return Returns domain struct of the domain if it is supported or else + * returns NULL. + * + */ +HTPDRV_API domain * get_domain(int domain_id); + +/** + * get_hex_arch_ver API: query the Hexagon processor architecture version information + * + * @param[in] domain_id value of a domain + * @param[out] Arch version (73, 75, ...) + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + * + */ +HTPDRV_API int get_hex_arch_ver(int domain, int * arch); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-hexagon/htp-utils.c b/ggml/src/ggml-hexagon/htp-utils.c deleted file mode 100644 index 3f335bf71c0..00000000000 --- a/ggml/src/ggml-hexagon/htp-utils.c +++ /dev/null @@ -1,454 +0,0 @@ - -#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" -#pragma clang diagnostic ignored "-Wmissing-prototypes" -#pragma clang diagnostic ignored "-Wsign-compare" - -#define GGML_COMMON_IMPL_C -#include "ggml-backend-impl.h" -#include "ggml-common.h" -#include "ggml-hexagon.h" -#include "ggml-impl.h" - -#include "htp-utils.h" - -#include -#include -#include -#include -#include -#include -#include - -domain * get_domain(int domain_id) { - int i = 0; - int size = sizeof(supported_domains) / sizeof(domain); - - for (i = 0; i < size; i++) { - if (supported_domains[i].id == domain_id) { - return &supported_domains[i]; - } - } - - return NULL; -} - -bool is_valid_domain_id(int domain_id, int compute_only) { - int i = 0; - int size = sizeof(supported_domains) / sizeof(domain); - - if (compute_only) { - return is_CDSP(domain_id); - } - - for (i = 0; i < size; i++) { - if (supported_domains[i].id == domain_id) { - return true; - } - } - - return false; -} - -int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) { - int nErr = AEE_SUCCESS; - int ss_info = 0; - if (domain_type != NULL) { - if (strcmp(domain_type, "LPASS") == 0) { - ss_info = FASTRPC_LPASS; - } else if (strcmp(domain_type, "HPASS") == 0) { - ss_info = FASTRPC_HPASS; - } else { - ss_info = FASTRPC_NSP; - } - } - system_req_payload req = { 0 }; - req.id = FASTRPC_GET_DOMAINS; - req.sys.domains = NULL; - fastrpc_domain * domain = NULL; - if (ss_info != 0) { - req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info); - } else { - req.sys.flags = 0; - } -#ifdef _WIN32 - nErr = AEE_EUNSUPPORTED; - goto bail; -#endif - if (remote_system_request) { - nErr = remote_system_request(&req); - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr); - goto bail; - } - // Allocate memory for domain-info array - req.sys.max_domains = req.sys.num_domains; - if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) { - nErr = AEE_ENOMEMORY; - GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains"); - goto bail; - } - - nErr = remote_system_request(&req); - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr); - goto bail; - } - - for (int i = 0; i < req.sys.num_domains; i++) { - // Verify that only requested type domains were returned - domain = &req.sys.domains[i]; - if (domain->type != ss_info && domain_type != NULL) { - nErr = -1; - GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n"); - goto bail; - } - } - *domains_info = req.sys.domains; - *num_domains = req.sys.num_domains; - } else { - nErr = AEE_EUNSUPPORTED; - goto bail; - } -bail: - if (nErr && !req.sys.domains) { - free(req.sys.domains); - } - return nErr; -} - -int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) { - int err = 0; - remote_rpc_effective_domain_id_t sess = { 0 }; - - sess.domain_name = domain_name; - sess.domain_name_len = strlen(domain_name); - sess.session_id = session_id; - - err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess)); - if (err) { - GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name, - session_id); - return err; - } - - *effec_domain_id = sess.effective_domain_id; - return err; -} - -int get_dsp_support(int * domain) { - int nErr = AEE_SUCCESS; - *domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID - - if (remote_handle_control) { - struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 }; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - goto bail; - } - - if (dsp_capability_domain.capability == 0) { - dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support. - dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT; - dsp_capability_domain.capability = 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, - sizeof(struct remote_dsp_capability)); - if (dsp_capability_domain.capability) { - *domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID - } - } - - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} - -int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) { - int nErr = AEE_SUCCESS; - *capability = 0; - - if (attr == VTCM_PAGE || attr == VTCM_COUNT) { - } else { - nErr = AEE_EBADPARM; - GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n"); - goto bail; - } - if (remote_handle_control) { - if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for VTCM information - * Since the ADSP does not have a dedicated VTCM, we expect the output to be 0 - */ - struct remote_dsp_capability dsp_capability_vtcm_dsp; - dsp_capability_vtcm_dsp.domain = (uint32_t) domain; - dsp_capability_vtcm_dsp.attribute_ID = attr; - dsp_capability_vtcm_dsp.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (nErr == AEE_SUCCESS) { - *capability = dsp_capability_vtcm_dsp.capability; - } else { - GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("Unsupported domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} - -bool is_unsignedpd_supported(int domain_id) { - int nErr = AEE_SUCCESS; - if (remote_handle_control) { - struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 }; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n"); - return false; - } - if (nErr) { - GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr); - return false; - } - if (dsp_capability_domain.capability == 1) { - return true; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n"); - return false; - } - return false; -} - -bool get_unsignedpd_support(void) { - return is_unsignedpd_supported(CDSP_DOMAIN_ID); -} - -bool is_async_fastrpc_supported(int domain) { - int nErr = AEE_SUCCESS; - if (remote_handle_control) { - if (domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for ASYNC_FASTRPC_SUPPORT information - * Async fastrpc is supported only on CDSP - */ - struct remote_dsp_capability dsp_capability_async_support; - dsp_capability_async_support.domain = (uint32_t) domain; - dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT; - dsp_capability_async_support.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (dsp_capability_async_support.capability == 1) { - return true; - } - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return false; -} - -bool is_status_notification_supported(int domain) { - int nErr = AEE_SUCCESS; - - if (remote_handle_control) { - /* - * Query the DSP for STATUS_NOTIFICATION_SUPPORT information - * DSP User PD status notification Support - */ - struct remote_dsp_capability dsp_capability_status_notification_support; - dsp_capability_status_notification_support.domain = (uint32_t) domain; - dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT; - dsp_capability_status_notification_support.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (dsp_capability_status_notification_support.capability == 1) { - return true; - } - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return false; -} - -int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) { - int nErr = AEE_SUCCESS; - *capability = 0; - - if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) { - nErr = AEE_EBADPARM; - GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n"); - goto bail; - } - if (remote_handle_control) { - if (domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for HMX SUPPORT information - * HMX is supported on CDSP only - */ - struct remote_dsp_capability dsp_capability_hmx_dsp; - dsp_capability_hmx_dsp.domain = (uint32_t) domain; - dsp_capability_hmx_dsp.attribute_ID = attr; - dsp_capability_hmx_dsp.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (nErr == AEE_SUCCESS) { - *capability = dsp_capability_hmx_dsp.capability; - } else { - GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} - -int get_hex_arch_ver(int domain, int * arch) { - if (!remote_handle_control) { - GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n"); - return AEE_EUNSUPPORTEDAPI; - } - - struct remote_dsp_capability arch_ver; - arch_ver.domain = (uint32_t) domain; - arch_ver.attribute_ID = ARCH_VER; - arch_ver.capability = (uint32_t) 0; - - int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver)); - if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) { - GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n"); - return AEE_EUNSUPPORTEDAPI; - } - - if (err != AEE_SUCCESS) { - GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err); - return err; - } - - switch (arch_ver.capability & 0xff) { - case 0x68: - *arch = 68; - return 0; - case 0x69: - *arch = 69; - return 0; - case 0x73: - *arch = 73; - return 0; - case 0x75: - *arch = 75; - return 0; - case 0x79: - *arch = 79; - return 0; - case 0x81: - *arch = 81; - return 0; - } - return -1; -} - -int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) { - int nErr = AEE_SUCCESS; - *capability = 0; - - if (remote_handle_control) { - if (domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for HVX SUPPORT information - * HVX is supported on CDSP only - */ - struct remote_dsp_capability dsp_capability_hvx_dsp; - dsp_capability_hvx_dsp.domain = (uint32_t) domain; - dsp_capability_hvx_dsp.attribute_ID = attr; - dsp_capability_hvx_dsp.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (nErr == AEE_SUCCESS) { - *capability = dsp_capability_hvx_dsp.capability; - } else { - GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} diff --git a/ggml/src/ggml-hexagon/htp-utils.h b/ggml/src/ggml-hexagon/htp-utils.h deleted file mode 100644 index 7bbae3a0b73..00000000000 --- a/ggml/src/ggml-hexagon/htp-utils.h +++ /dev/null @@ -1,221 +0,0 @@ -#ifndef HTP_UTILS_H -#define HTP_UTILS_H - -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include -#include -#include -#include - -/* Offset to differentiate HLOS and Hexagon error codes. - Stores the value of AEE_EOFFSET for Hexagon. */ -#ifndef DSP_OFFSET -# define DSP_OFFSET 0x80000400 -#endif - -/* Errno for connection reset by peer. */ -#ifndef ECONNRESET -# ifdef __hexagon__ -# define ECONNRESET 104 -# endif -#endif - -/* Abstraction of different OS specific sleep APIs. - SLEEP accepts input in seconds. */ -#ifndef SLEEP -# ifdef __hexagon__ -# define SLEEP(x) \ - { /* Do nothing for simulator. */ \ - } -# else -# ifdef _WINDOWS -# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */ -# else -# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */ -# endif -# endif -#endif - -/* Include windows specific header files. */ -#ifdef _WINDOWS -# include -# include -# define _CRT_SECURE_NO_WARNINGS 1 -# define _WINSOCK_DEPRECATED_NO_WARNINGS 1 -/* Including this file for custom implementation of getopt function. */ -# include "getopt_custom.h" -#endif - -/* Includes and defines for all HLOS except windows */ -#if !defined(__hexagon__) && !defined(_WINDOWS) -# include "unistd.h" - -# include -#endif - -/* Includes and defines for Hexagon and all HLOS except Windows. */ -#if !defined(_WINDOWS) -/* Weak reference to remote symbol for compilation. */ -# pragma weak remote_session_control -# pragma weak remote_handle_control -# pragma weak remote_handle64_control -# pragma weak fastrpc_mmap -# pragma weak fastrpc_munmap -# pragma weak rpcmem_alloc2 -#endif - -#if !defined(_WINDOWS) -# pragma weak remote_system_request -#endif -/** - * Wrapper for FastRPC Capability API: query DSP support. - * - * @param[out] domain pointer to supported domain. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - */ -int get_dsp_support(int * domain); - -/** - * Wrapper for FastRPC Capability API: query VTCM information. - * - * @param[in] domain value of domain in the queried. - * @param[out] capability capability value of the attribute queried. - * @param[in] attr value of the attribute to the queried. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - */ -int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr); - -/** - * Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain. - * - * @return true if unsigned pd is supported. - * false if unsigned pd is not supported, capability query failed. - */ - -bool get_unsignedpd_support(void); - -/** - * Wrapper for FastRPC Capability API: query unsigned pd support. - * - * @param[in] domain value of domain in the queried. - * @return true if unsigned pd is supported. - * false if unsigned pd is not supported, capability query failed. - */ - -bool is_unsignedpd_supported(int domain_id); - -/** - * is_valid_domain_id API: query a domain id is valid. - * - * @param[in] domain value of domain in the queried. - * @param[in] compute_only value of domain is only compared with CDSP domains supported by the target when enabled. - * @return true if value of domain is valid. - * false if value of domain is not valid. - */ - -bool is_valid_domain_id(int domain_id, int compute_only); - -/** - * get_domain API: get domain struct from domain value. - * - * @param[in] domain value of a domain - * @return Returns domain struct of the domain if it is supported or else - * returns NULL. - * - */ - -domain * get_domain(int domain_id); - -/** - * get_domains_info API: get information for all the domains available on the device - * - * @param[in] domain_type pointer to domain type - * @param[in] num_domains pointer to number of domains - * @param[in] domains_info pointer to save discovered domains information. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - * It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application. - * - */ - -int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info); - -/** - * get_effective_domain_id API: get effective domain id for given session id - * - * @param[in] domain_name pointer to domain name - * @param[in] session_id - * @param[in] effec_domain_id pointer to save obtained effective domain id. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ - -int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id); - -/** - * is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not - * - * @param[in] domain_id value of a domain - * @return Returns true or false stating support of Async FastRPC - * - */ - -bool is_async_fastrpc_supported(int domain_id); - -/** - * is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information - * - * @param[in] domain_id value of a domain - * @return Returns true or false stating status notification support information - * - */ -bool is_status_notification_supported(int domain_id); - -/** - * get_hmx_support_info API: query the DSP for HMX SUPPORT information - * - * @param[in] domain_id value of a domain - * @param[out] capability capability value of the attribute queried. - * @param[in] attr value of the attribute to the queried. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ -int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr); - -/** - * get_hex_arch_ver API: query the Hexagon processor architecture version information - * - * @param[in] domain_id value of a domain - * @param[out] Arch version (73, 75, ...) - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ -int get_hex_arch_ver(int domain, int * arch); - -/** - * get_hvx_support_info API: query the DSP for HVX SUPPORT information - * - * @param[in] domain_id value of a domain - * @param[out] capability capability value of the attribute queried. - * @param[in] attr value of the attribute to the queried. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ -int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr); - -#ifdef __cplusplus -} -#endif - -#endif //DSP_CAPABILITIES_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 1de47d0f3d4..c1846374437 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -2,9 +2,9 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" +#include #include #include - #include #include @@ -17,6 +17,12 @@ #include "htp-msg.h" #include "htp-ops.h" +static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) { + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements + return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); +} + // Dot product of FP32 and FP16 vectors, accumulating to float static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 @@ -33,23 +39,19 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict #pragma unroll(4) for (i = 0; i < nvec; i++) { // Load y (fp32) and convert into fp16 - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements - HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); // Load x (fp16) HVX_Vector x_hf = vx[i]; HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); } if (nloe) { // Load y (fp32) and convert into fp16 - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements - HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); // Load x (fp16) HVX_Vector x_hf = vx[i]; @@ -62,13 +64,72 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); + hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); +} + +// Dot product of FP32 and FP16 vectors, accumulating to float +static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r, + const void * restrict y, + const void * restrict x0, + const void * restrict x1, + unsigned int n, + float s) { + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + // Load x (fp16) + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); } - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + + // Load x (fp16) + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; - hvx_vec_store_u(r, 4, rsum); + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x0_hf = Q6_V_vand_QV(bmask, x0_hf); + x1_hf = Q6_V_vand_QV(bmask, x1_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); + hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } // Dot product of two F16 vectors, accumulating to float @@ -91,7 +152,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); } if (nloe) { @@ -103,15 +164,65 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); + hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); +} + +static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, + const void * restrict y, + const void * restrict x0, + const void * restrict x1, + unsigned int n, + float s) { + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + if (nloe) { + HVX_Vector y_hf = vy[i]; + + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); } - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); - hvx_vec_store_u(r, 4, rsum); + HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); + hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } -// MAD: y (F32) += x (F16) * v (float) +// MAD: y (F32) += x (F16) * s (float) static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; HVX_Vector * restrict ptr_y = (HVX_Vector *) y; @@ -317,17 +428,22 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // Inner loop processing the block from VTCM uint32_t ic = 0; + const bool is_q_fp32 = (q->type == HTP_TYPE_F32); + // Process in blocks of 32 (VLEN_FP32) - for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) { + static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); + HVX_Vector_x4 scores_x4; + HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); + for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; - for (int j = 0; j < VLEN_FP32; ++j) { + for (int j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic + j; const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; - if (q->type == HTP_TYPE_F32) { - hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + if (is_q_fp32) { + hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); } else { - hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); } } @@ -356,36 +472,43 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in scores = Q6_Vsf_equals_Vqf32(scores); } + scores_x4.v[iv] = scores; + v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max); + } + + { // 4. Online Softmax Update - HVX_Vector v_max = hvx_vec_reduce_max_f32(scores); + v_max = hvx_vec_reduce_max_f32(v_max); float m_block = hvx_vec_get_f32(v_max); - float M_old = M; float M_new = (m_block > M) ? m_block : M; M = M_new; - float ms = expf(M_old - M_new); - + const float ms = expf(M_old - M_new); hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); - S = S * ms; HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new); - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); - HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); - - HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P); - float p_sum = hvx_vec_get_f32(p_sum_vec); - S += p_sum; - - // 5. Accumulate V - float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector*)p_arr = P; - - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic + j; - const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); + for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { + HVX_Vector scores = scores_x4.v[iv]; + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); + + p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); + + // 5. Accumulate V + float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; + *(HVX_Vector*)p_arr = P; + + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic2 + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + } } + + p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); + S = S * ms + hvx_vec_get_f32(p_sum_vec); } // Leftover @@ -393,7 +516,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in float s_val; const uint8_t * k_ptr = k_base + ic * size_k_row_padded; - if (q->type == HTP_TYPE_F32) { + if (is_q_fp32) { hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); } else { hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); diff --git a/ggml/src/ggml-hexagon/htp/hvx-dump.h b/ggml/src/ggml-hexagon/htp/hvx-dump.h index e882227893e..85201fc3453 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-dump.h +++ b/ggml/src/ggml-hexagon/htp/hvx-dump.h @@ -28,19 +28,16 @@ static void hvx_vec_dump_f16(char * pref, HVX_Vector v) { } static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) { - union { - HVX_Vector v; - float d[32]; - } u = { .v = v }; + HVX_VectorAlias u = { .v = v }; const uint32_t n0 = n / 16; const uint32_t n1 = n % 16; int i = 0; for (; i < n0; i++) { - hex_dump_f32_line(pref, u.d + (16 * i), 16); + hex_dump_f32_line(pref, u.fp32 + (16 * i), 16); } if (n1) { - hex_dump_f32_line(pref, u.d + (16 * i), n1); + hex_dump_f32_line(pref, u.fp32 + (16 * i), n1); } } diff --git a/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/ggml/src/ggml-hexagon/htp/hvx-reduce.h index 8845fe73ea1..1ca7c05d983 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-reduce.h +++ b/ggml/src/ggml-hexagon/htp/hvx-reduce.h @@ -44,6 +44,45 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) { return hvx_vec_reduce_sum_n_qf32(in, 32); } +#if __HVX_ARCH__ > 75 + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16)); + return sum_sf; +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum + width = width << 1; + } + return sum; +} + +#else + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16)); + return Q6_Vsf_equals_Vqf32(sum_qf); +} + static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { unsigned int total = n * 4; // total vec nbytes unsigned int width = 4; // fp32 nbytes @@ -57,6 +96,8 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) return sum; } +#endif + static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) { return hvx_vec_reduce_sum_n_f32(in, 32); } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 1603ff2b3b6..d251eeed33a 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -11,6 +11,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hvx-dump.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -320,7 +321,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. @@ -344,7 +345,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -362,14 +363,14 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * // Zero out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -402,7 +403,7 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r1_sum = Q6_V_vsplat_R(0); @@ -432,8 +433,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -456,20 +457,18 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -493,7 +492,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. @@ -517,7 +516,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -535,14 +534,14 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * // Zero out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -605,8 +604,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -629,20 +628,18 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_mxfp4x4x2_q8x4x2(const int n, @@ -669,7 +666,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. @@ -708,7 +705,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers @@ -741,14 +738,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, // Zero-out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -781,13 +778,13 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r1_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + // Apply scale to acc and accumulate into the row sum (f32). const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) @@ -829,8 +826,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers @@ -867,24 +864,22 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - // Zero-out unused scales + // Zero-out unused values HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -913,7 +908,7 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } @@ -957,11 +952,8 @@ static void vec_dot_f16_f16_aa_rx2(const int n, rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); } - rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum1)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -990,7 +982,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } @@ -1042,7 +1034,8 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + // Convert into fp32 and reduce + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 1b6b2eba4ae..e91a16d947f 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -154,8 +154,8 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, v_pad[i] = v3; } - v = hvx_vec_reduce_sum_qf32(sum_vec); - sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v)); + v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); + sum_vec = hvx_vec_repl4(v); HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index be8be8c4e64..1a27cb6e63e 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -57,8 +57,8 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v); - sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum)); + HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + sum_v = hvx_vec_repl4(reduced_sum); HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); diff --git a/ggml/src/ggml-hexagon/libdl.h b/ggml/src/ggml-hexagon/libdl.h new file mode 100644 index 00000000000..8ca5016f039 --- /dev/null +++ b/ggml/src/ggml-hexagon/libdl.h @@ -0,0 +1,79 @@ +#pragma once + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +#endif +#include + +namespace fs = std::filesystem; + +#ifdef _WIN32 + +using dl_handle = std::remove_pointer_t; + +struct dl_handle_deleter { + void operator()(HMODULE handle) { + FreeLibrary(handle); + } +}; + +static inline dl_handle * dl_load_library(const fs::path & path) { + // suppress error dialogs for missing DLLs + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + HMODULE handle = LoadLibraryW(path.wstring().c_str()); + + SetErrorMode(old_mode); + + return handle; +} + +static inline void * dl_get_sym(dl_handle * handle, const char * name) { + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + void * p = (void *) GetProcAddress(handle, name); + + SetErrorMode(old_mode); + + return p; +} + +static inline const char * dl_error() { + return ""; +} + +#else + +using dl_handle = void; + +struct dl_handle_deleter { + void operator()(void * handle) { + dlclose(handle); + } +}; + +static inline dl_handle * dl_load_library(const fs::path & path) { + dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); + return handle; +} + +static inline void * dl_get_sym(dl_handle * handle, const char * name) { + return dlsym(handle, name); +} + +static inline const char * dl_error() { + const char *rslt = dlerror(); + return rslt != nullptr ? rslt : ""; +} + +#endif diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf new file mode 100644 index 00000000000..656d2d9ab26 --- /dev/null +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -0,0 +1,38 @@ +[Version] +Signature = "$WINDOWS NT$" +Class = ComputeAccelerator +ClassGuid = {F01A9D53-3FF6-48D2-9F97-C8A7004BE10C} +Provider = %GGML% +DriverVer = 01/01/2026,1.0.0.0 +CatalogFile = libggml-htp.cat +PnpLockDown = 1 + +[DestinationDirs] +Drivers_Dir = 6 + +[SourceDisksNames] +1 = %DiskId% + +[SourceDisksFiles] +libggml-htp-v68.so = 1 +libggml-htp-v69.so = 1 +libggml-htp-v73.so = 1 +libggml-htp-v75.so = 1 +libggml-htp-v81.so = 1 + +[ControlFlags] +ExcludeFromSelect = * + +[DefaultInstall.NTarm64] +CopyFiles=Drivers_Dir + +[Drivers_Dir] +libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE + +[Strings] +GGML = 'GGML' +DiskId = 'GGML HTP library' diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 23b6889919f..80037d24361 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -62,6 +62,8 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) +file(GLOB SRCS "../ggml-cuda/template-instances/mmf*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu") diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index 9c0b3db8599..42054d841aa 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -71,7 +71,7 @@ else() # disabling fast math is needed in order to pass tests/test-backend-ops # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1 # note: unfortunately, we have to call it default.metallib instead of ggml.metallib - # ref: https://github.com/ggerganov/whisper.cpp/issues/1720 + # ref: https://github.com/ggml-org/whisper.cpp/issues/1720 # note: adding -g causes segmentation fault during compile #set(XC_FLAGS -fno-fast-math -fno-inline -g) set(XC_FLAGS -fno-fast-math -fno-inline) diff --git a/ggml/src/ggml-metal/ggml-metal-context.h b/ggml/src/ggml-metal/ggml-metal-context.h index ec2b686b733..abf4b06ed2a 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.h +++ b/ggml/src/ggml-metal/ggml-metal-context.h @@ -15,14 +15,22 @@ typedef struct ggml_metal * ggml_metal_t; ggml_metal_t ggml_metal_init(ggml_metal_device_t dev); void ggml_metal_free(ggml_metal_t ctx); +const char * ggml_metal_get_name(ggml_metal_t ctx); + void ggml_metal_synchronize(ggml_metal_t ctx); void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); +bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf); void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf); +void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev); +void ggml_metal_event_wait (ggml_metal_t ctx, ggml_metal_event_t ev); + +ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx); + void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb); void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data); bool ggml_metal_supports_family (ggml_metal_t ctx, int family); diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 42a35736eea..5d3a8ce412a 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -24,9 +24,13 @@ }; struct ggml_metal { + char name[128]; + ggml_metal_device_t dev; ggml_metal_library_t lib; + ggml_metal_event_t ev_cpy; // for async copies + dispatch_queue_t d_queue; // additional, inference-time compiled pipelines @@ -117,7 +121,11 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { } } - //const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + res->ev_cpy = ggml_metal_device_event_init(dev); + + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + snprintf(res->name, sizeof(res->name), "%s", props_dev->name); res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); @@ -206,9 +214,15 @@ void ggml_metal_free(ggml_metal_t ctx) { dispatch_release(ctx->d_queue); + ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy); + free(ctx); } +const char * ggml_metal_get_name(ggml_metal_t ctx) { + return ctx->name; +} + void ggml_metal_synchronize(ggml_metal_t ctx) { // wait for any backend operations to finish if (ctx->cmd_buf_last) { @@ -273,8 +287,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, // wrap the source data into a Metal buffer id device = ggml_metal_device_get_obj(ctx->dev); id buf_src = [device newBufferWithBytes:data - length:size - options:MTLResourceStorageModeShared]; + length:size + options:MTLResourceStorageModeShared]; GGML_ASSERT(buf_src); @@ -316,9 +330,9 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te @autoreleasepool { id device = ggml_metal_device_get_obj(ctx->dev); id buf_dst = [device newBufferWithBytesNoCopy:data - length:size - options:MTLResourceStorageModeShared - deallocator:nil]; + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; GGML_ASSERT(buf_dst); @@ -356,9 +370,52 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te } } +bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { + @autoreleasepool { + struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src); + struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst); + + if (bid_src.metal == nil || bid_dst.metal == nil) { + return false; + } + + // queue the copy operation into the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id queue = ggml_metal_device_get_queue(ctx_src->dev); + id cmd_buf = [queue commandBuffer]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:ggml_nbytes(src)]; + + [encoder endEncoding]; + + ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src); + ggml_metal_event_encode_signal(ev_cpy, cmd_buf); + + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx_src->cmd_bufs_ext addObject:cmd_buf]; + ctx_src->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + + ggml_metal_event_wait(ctx_dst, ev_cpy); + + return true; + } +} + enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) { // number of nodes encoded by the main thread (empirically determined) - const int n_main = 64; + const int n_main = MAX(64, 0.1*gf->n_nodes); // number of threads in addition to the main thread const int n_cb = ctx->n_cb; @@ -530,6 +587,42 @@ void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) { //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0); } +void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) { + @autoreleasepool { + id queue = ggml_metal_device_get_queue(ctx->dev); + id cmd_buf = [queue commandBuffer]; + + ggml_metal_event_encode_signal(ev, cmd_buf); + + [cmd_buf commit]; + + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) { + @autoreleasepool { + id queue = ggml_metal_device_get_queue(ctx->dev); + id cmd_buf = [queue commandBuffer]; + + ggml_metal_event_encode_wait(ev, cmd_buf); + + [cmd_buf commit]; + + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) { + return ctx->ev_cpy; +} + void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { if (ctx->n_cb != n_cb) { ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 04c6137c5a7..4c4c3ce36c4 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -17,10 +17,12 @@ struct ggml_metal_device_deleter { typedef std::unique_ptr ggml_metal_device_ptr; -ggml_metal_device_t ggml_metal_device_get(void) { - static ggml_metal_device_ptr ctx { ggml_metal_device_init() }; +ggml_metal_device_t ggml_metal_device_get(int device) { + static std::vector devs; - return ctx.get(); + devs.emplace_back(ggml_metal_device_init(device)); + + return devs.back().get(); } struct ggml_metal_pipelines { @@ -174,6 +176,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_me return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const int n = op->src[0]->ne[0]; + + snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_n=%d", base, n); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + res.nsg = 1; + res.smem = 0; + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; @@ -532,6 +554,36 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_ return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const int nsg = 8; + const int n = op->src[1]->ne[1]; + const int k = op->src[1]->ne[0]; + + snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0); + ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1); + ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.nsg = nsg; + res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16); + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; @@ -1340,34 +1392,78 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v GGML_UNUSED(op); } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( - ggml_metal_library_t lib, - ggml_op op, - int32_t n_fuse, - bool row) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { char base[256]; char name[256]; - const char * op_str = "undefined"; - switch (op) { - case GGML_OP_ADD: op_str = "add"; break; - case GGML_OP_SUB: op_str = "sub"; break; - case GGML_OP_MUL: op_str = "mul"; break; - case GGML_OP_DIV: op_str = "div"; break; + int op_num = -1; + + switch (op->op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; default: GGML_ABORT("fatal error"); }; - if (row) { - snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); - } else { - snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t1_str = ggml_type_name(op->src[1]->type); + const char * t_str = ggml_type_name(op->type); + + const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + + const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - snprintf(name, 256, "%s", base); + res.c4 = is_c4; + res.cnt = is_rb; + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) { + char base[256]; + char name[256]; + + int op_num = -1; + + switch (op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32"); + snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, false, FC_BIN + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } return res; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 3d01c56fb81..93d7f6a216f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params { int nr1; size_t smem; + + bool c4; + bool cnt; }; int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); @@ -108,6 +111,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); @@ -121,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); @@ -132,7 +137,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse ); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); @@ -205,7 +211,9 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets); // struct ggml_metal_device_props { + int device; char name[128]; + char desc[128]; size_t max_buffer_size; size_t max_working_set_size; @@ -224,11 +232,15 @@ struct ggml_metal_device_props { int op_offload_min_batch_size; }; -ggml_metal_device_t ggml_metal_device_init(void); +typedef struct ggml_metal_event * ggml_metal_event_t; + +void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf); +void ggml_metal_event_encode_wait (ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf); + +ggml_metal_device_t ggml_metal_device_init(int device); void ggml_metal_device_free(ggml_metal_device_t dev); -// return a singleton that is automatically destroyed when the program exits -ggml_metal_device_t ggml_metal_device_get(void); +ggml_metal_device_t ggml_metal_device_get(int device); void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id @@ -240,6 +252,10 @@ void ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev); +ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev); +void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev); +void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev); + void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total); bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index eb4e2c209ce..891d70c85a4 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -24,9 +24,6 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; static const NSInteger MTLGPUFamilyMetal4_GGML = 5002; -// virtual address for GPU memory allocations -static atomic_uintptr_t g_addr_device = 0x000000400ULL; - #if !GGML_METAL_EMBED_LIBRARY // Here to assist with NSBundle Path Hack @interface GGMLMetalClass : NSObject @@ -349,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); @@ -365,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; [lib->lock lock]; @@ -523,6 +524,9 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { ggml_metal_library_t library; struct ggml_metal_device_props props; + + // virtual address for GPU memory allocations + atomic_uintptr_t addr_virt; }; // @@ -618,7 +622,7 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) { free(rsets); } -ggml_metal_device_t ggml_metal_device_init(void) { +ggml_metal_device_t ggml_metal_device_init(int device) { ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); assert(dev != NULL); @@ -632,6 +636,9 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); } + dev->addr_virt = 0x000000400ULL; + + dev->props.device = device; dev->props.has_simdgroup_reduction = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -785,10 +792,15 @@ ggml_metal_device_t ggml_metal_device_init(void) { dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; - dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength; + if (@available(macOS 10.12, iOS 16.0, *)) { + dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; + } else { + dev->props.max_working_set_size = dev->mtl_device.maxBufferLength; + } - strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1); + snprintf(dev->props.name, sizeof(dev->props.name), "%s%d", "MTL", device); + snprintf(dev->props.desc, sizeof(dev->props.desc), "%s", [[dev->mtl_device name] UTF8String]); dev->library = ggml_metal_library_init(dev); if (!dev->library) { @@ -918,6 +930,59 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed); } +struct ggml_metal_event { + void * obj; // id + + atomic_int value; +}; + +void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { + id event = (id)ev->obj; + + id cmd_buf = (id) cmd_buf_raw; + + [cmd_buf encodeSignalEvent:event value:atomic_fetch_add_explicit(&ev->value, 1, memory_order_relaxed) + 1]; +} + +void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { + id event = (id)ev->obj; + + id cmd_buf = (id) cmd_buf_raw; + + [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; +} + +ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { + id event = [dev->mtl_device newEvent]; + + ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event)); + + ev->obj = (__bridge void *)event; + ev->value = 0; + + return ev; +} + +void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) { + id event = ev->obj; + [event release]; + + free(ev); + + GGML_UNUSED(dev); +} + +void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) { + @autoreleasepool { + id event = ev->obj; + + id cmd_buf = [dev->mtl_queue commandBuffer]; + [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { if (@available(macOS 10.12, iOS 16.0, *)) { *total = dev->mtl_device.recommendedMaxWorkingSetSize; @@ -993,7 +1058,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: case GGML_OP_REPEAT: case GGML_OP_SCALE: @@ -1092,6 +1157,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return has_simdgroup_reduction; @@ -1173,6 +1239,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return false; }; } + case GGML_OP_DIAG: + return true; case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return has_simdgroup_reduction; @@ -1340,8 +1408,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, res->all_data = ggml_metal_host_malloc(size_aligned); res->is_shared = true; } else { - // use virtual address from g_addr_device counter - res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed); + // use virtual address + res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed); res->is_shared = false; } res->all_size = size_aligned; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 59d88b01a55..77bb403c15d 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -78,13 +78,15 @@ #define FC_MUL_MM 700 #define FC_ROPE 800 #define FC_SSM_CONV 900 -#define FC_COUNT_EQUAL 1000 +#define FC_SOLVE_TRI 1000 +#define FC_COUNT_EQUAL 1100 +#define FC_BIN 1200 // op-specific constants -#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NQPSG 8 #define OP_FLASH_ATTN_EXT_NCPSG 64 -#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 // kernel argument structs @@ -733,6 +735,33 @@ typedef struct { uint64_t nb0; } ggml_metal_kargs_ssm_scan; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_solve_tri; + typedef struct { int32_t ne00t; int32_t ne00; @@ -764,6 +793,25 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_set_rows; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_diag; + typedef struct { int64_t ne00; int64_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7f4cfbba226..dbf25433c25 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -341,6 +341,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_rwkv(ctx, idx); } break; + case GGML_OP_SOLVE_TRI: + { + n_fuse = ggml_metal_op_solve_tri(ctx, idx); + } break; case GGML_OP_MUL_MAT: { n_fuse = ggml_metal_op_mul_mat(ctx, idx); @@ -357,6 +361,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_set_rows(ctx, idx); } break; + case GGML_OP_DIAG: + { + n_fuse = ggml_metal_op_diag(ctx, idx); + } break; case GGML_OP_L2_NORM: { n_fuse = ggml_metal_op_l2_norm(ctx, idx); @@ -699,7 +707,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -1255,6 +1263,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_diag args = { + /*.ne00 =*/ne00, + /*.ne01 =*/ne01, + /*.ne02 =*/ne02, + /*.ne03 =*/ne03, + /*.nb00 =*/nb00, + /*.nb01 =*/nb01, + /*.nb02 =*/nb02, + /*.nb03 =*/nb03, + /*.ne0 =*/ne0, + /*.ne1 =*/ne1, + /*.ne2 =*/ne2, + /*.ne3 =*/ne3, + /*.nb0 =*/nb0, + /*.nb1 =*/nb1, + /*.nb2 =*/nb2, + /*.nb3 =*/nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1); + + return 1; +} + int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -1557,6 +1607,63 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_solve_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nsg = pipeline.nsg; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1); + + return 1; +} + int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -2295,7 +2402,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { // return res; //} - const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG; const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; @@ -2411,7 +2518,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel - const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup + const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup GGML_ASSERT(nqptg <= 32); @@ -2578,9 +2685,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { #undef FATTN_SMEM } else { // half4x4 kernel - const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! - const int nkpsg = 1*ncpsg; + const int nhptg = 1; // heads per threadgroup GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); @@ -2632,6 +2739,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); } + // note: for simplicity assume the K is larger or equal than V + GGML_ASSERT(ne10 >= ne20); + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask @@ -2639,28 +2749,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // ne20*(nsg) // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16)) - - int64_t nsgmax = 2; - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes - if (smem > props_dev->max_theadgroup_memory_size/2) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); - const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32))); +#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16)) int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; // workgroups // each workgroup handles nsg*nkpsg cache values @@ -2673,7 +2764,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { } else { nwg = 32; nsg = 1; - while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) { + while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) { nsg *= 2; } } @@ -2739,7 +2830,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); } else { // sanity checks assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); @@ -2752,7 +2843,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); // sync the 2 kernels ggml_metal_op_concurrency_reset(ctx); @@ -2804,8 +2895,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); - bool bcast_row = false; - ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); @@ -2899,18 +2988,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { struct ggml_metal_pipeline_with_params pipeline; - if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true); - - bcast_row = true; - } else { - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false); - } + pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse); if (n_fuse > 1) { bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); @@ -2924,20 +3002,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } } + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne10 = ne10/4; + args.ne0 = ne0/4; + } + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, bid_src0, 1); ggml_metal_encoder_set_buffer (enc, bid_src1, 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 3); - if (bcast_row) { - const int64_t n = ggml_nelements(op)/4; + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); } else { - int nth = 32; + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + int nth = 1; - while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (2*nth < args.ne0 && nth < nth_max) { nth *= 2; } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 10686a334e0..3c64e4f6007 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -56,10 +56,12 @@ int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx); int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 56b59f0afdf..1c705362fb7 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -7,11 +7,15 @@ #include "ggml-metal-context.h" #include "ggml-metal-ops.h" -// globals +#include +#include -// initialized in ggml_backend_metal_reg -static ggml_backend_reg g_ggml_metal_reg; -static ggml_backend_device g_ggml_metal_device; +#define GGML_METAL_NAME "MTL" +#define GGML_METAL_MAX_DEVICES 16 + +// number of Metal devices +// note: can be overriden with GGML_METAL_DEVICES env to simulate virtual devices +static int g_devices = 1; //////////////////////////////////////////////////////////////////////////////// // backend interface @@ -165,10 +169,28 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { /* .reset = */ NULL, }; +static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_metal_buffer_shared_free_buffer || + buffer->iface.free_buffer == ggml_backend_metal_buffer_private_free_buffer; +} + // // buffer types // +struct ggml_backend_metal_buffer_type { + int device; + std::string name; +}; + +struct ggml_backend_metal_buffer_type_deleter { + void operator()(ggml_backend_metal_buffer_type * ctx) const { + delete ctx; + } +}; + +typedef std::unique_ptr ggml_backend_metal_buffer_type_ptr; + // common method for allocating shread or private Metal buffers static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; @@ -218,9 +240,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ // default (shared) buffer type static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { - return "Metal"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -249,29 +271,54 @@ static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_ty GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) { - static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, - }; +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector bufts; + static std::vector ctxs; + + static bool initialized = false; + if (!initialized) { + bufts.reserve(g_devices); + ctxs.reserve(g_devices); + + for (int i = 0; i < g_devices; ++i) { + ggml_backend_metal_buffer_type * raw_ctx = + new ggml_backend_metal_buffer_type { + /* .device = */ i, + /* .name = */ GGML_METAL_NAME + std::to_string(i), + }; + ctxs.emplace_back(raw_ctx); + + ggml_backend_buffer_type buft = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ raw_ctx, + }; + + bufts.emplace_back(buft); + } + + initialized = true; + } - return &ggml_backend_buffer_type_metal; + return &bufts[device]; } // default (private) buffer type static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Private"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -300,29 +347,53 @@ static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_t GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) { - static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, - }; +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector bufts; + static std::vector ctxs; + + static bool initialized = false; + if (!initialized) { + bufts.reserve(g_devices); + ctxs.reserve(g_devices); + + for (int i = 0; i < g_devices; ++i) { + ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{ + /* .device = */ i, + /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Private" + }; + ctxs.emplace_back(raw_ctx); + + ggml_backend_buffer_type buft = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ raw_ctx, + }; + + bufts.emplace_back(buft); + } + + initialized = true; + } - return &ggml_backend_buffer_type_metal; + return &bufts[device]; } // mapped buffer type static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Mapped"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -352,31 +423,55 @@ static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_ty GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) { - // note: not obvious, but this buffer type still needs to implement .alloc_buffer: - // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 - static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, - }; +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector bufts; + static std::vector ctxs; + + static bool initialized = false; + if (!initialized) { + bufts.reserve(g_devices); + ctxs.reserve(g_devices); + + for (int i = 0; i < g_devices; ++i) { + ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{ + /* .device = */ i, + /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Mapped" + }; + ctxs.emplace_back(raw_ctx); + + // note: not obvious, but this buffer type still needs to implement .alloc_buffer: + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 + ggml_backend_buffer_type buft = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ raw_ctx, + }; + + bufts.emplace_back(buft); + } + + initialized = true; + } - return &ggml_backend_buffer_type_mapped_metal; + return &bufts[device]; } // backend static const char * ggml_backend_metal_name(ggml_backend_t backend) { - return "Metal"; + ggml_metal_t ctx = (ggml_metal_t)backend->context; - GGML_UNUSED(backend); + return ggml_metal_get_name(ctx); } static void ggml_backend_metal_free(ggml_backend_t backend) { @@ -409,12 +504,24 @@ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const gg } static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { - return false; + if (!ggml_backend_is_metal(backend_src) || !ggml_backend_is_metal(backend_dst)) { + return false; + } - GGML_UNUSED(backend_src); - GGML_UNUSED(backend_dst); - GGML_UNUSED(src); - GGML_UNUSED(dst); + if (!ggml_backend_buffer_is_metal(src->buffer) || !ggml_backend_buffer_is_metal(dst->buffer)) { + return false; + } + + ggml_metal_t ctx_src = (ggml_metal_t)backend_src->context; + ggml_metal_t ctx_dst = (ggml_metal_t)backend_dst->context; + + //ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; + //ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; + + //ggml_metal_buffer_t buf_ctx_src = (ggml_metal_buffer_t)buf_src->context; + //ggml_metal_buffer_t buf_ctx_dst = (ggml_metal_buffer_t)buf_dst->context; + + return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst); } static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { @@ -423,6 +530,20 @@ static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, return ggml_metal_graph_compute(ctx, cgraph); } +static void ggml_backend_metal_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + ggml_metal_event_t ev = (ggml_metal_event_t)event->context; + + ggml_metal_event_record(ctx, ev); +} + +static void ggml_backend_metal_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + ggml_metal_event_t ev = (ggml_metal_event_t)event->context; + + ggml_metal_event_wait(ctx, ev); +} + static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_metal_t ctx = (ggml_metal_t)backend->context; @@ -435,7 +556,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { ggml_metal_t ctx = (ggml_metal_t)backend->context; ggml_metal_set_n_cb(ctx, n_cb); - } static ggml_backend_i ggml_backend_metal_i = { @@ -450,12 +570,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_metal_graph_compute, - - // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal - // in any case, these docs seem relevant if we ever decide to implement it: - // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_metal_event_record, + /* .event_wait = */ ggml_backend_metal_event_wait, /* .graph_optimize = */ ggml_backend_metal_graph_optimize, }; @@ -519,15 +635,17 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { // backend device static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { - return "Metal"; + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; - GGML_UNUSED(dev); + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return props_dev->name; } static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; - return ggml_metal_device_get_props(ctx_dev)->name; + return ggml_metal_device_get_props(ctx_dev)->desc; } static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { @@ -550,14 +668,14 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { - /* .async = */ true, - /* .host_buffer = */ false, - /* .buffer_from_host_ptr = */ true, - /* .events = */ false, + /* .async = */ true, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ true, }; } -static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { +static ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; ggml_metal_t ctx = ggml_metal_init(ctx_dev); @@ -587,7 +705,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); - return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private(); + return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared(props_dev->device) : ggml_backend_metal_buffer_type_private(props_dev->device); } static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { @@ -595,7 +713,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backen ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size); + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size); } static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { @@ -606,9 +726,10 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { return + buft->device == dev && ( buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name || buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name || - buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name; + buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name); GGML_UNUSED(dev); } @@ -632,45 +753,97 @@ static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const g get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size; } +static ggml_backend_event_t ggml_backend_metal_device_event_new(ggml_backend_dev_t dev) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_event_t event = ggml_metal_device_event_init(ctx_dev); + GGML_ASSERT(event); + + ggml_backend_event_t ev = new ggml_backend_event { + /* .device = */ dev, + /* .context = */ event, + }; + + return ev; +} + +static void ggml_backend_metal_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_event_t ev = (ggml_metal_event_t)event->context; + + ggml_metal_device_event_free(ctx_dev, ev); + + delete event; +} + +static void ggml_backend_metal_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_event_t evt = (ggml_metal_event_t)event->context; + + ggml_metal_device_event_synchronize(ctx_dev, evt); +} + static ggml_backend_device_i ggml_backend_metal_device_i = { /* .get_name = */ ggml_backend_metal_device_get_name, /* .get_description = */ ggml_backend_metal_device_get_description, /* .get_memory = */ ggml_backend_metal_device_get_memory, /* .get_type = */ ggml_backend_metal_device_get_type, /* .get_props = */ ggml_backend_metal_device_get_props, - /* .init_backend = */ ggml_backend_metal_device_init, + /* .init_backend = */ ggml_backend_metal_device_init_backend, /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped, /* .supports_op = */ ggml_backend_metal_device_supports_op, /* .supports_buft = */ ggml_backend_metal_device_supports_buft, /* .offload_op = */ ggml_backend_metal_device_offload_op, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ ggml_backend_metal_device_event_new, + /* .event_free = */ ggml_backend_metal_device_event_free, + /* .event_synchronize = */ ggml_backend_metal_device_event_synchronize, }; // backend registry +struct ggml_backend_metal_reg { + std::vector devices; +}; + +typedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t; + +static ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) { + ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg; + + return ctx; +} + +static void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) { + delete ctx; +} + +struct ggml_backend_metal_reg_deleter { + void operator()(ggml_backend_metal_reg_t ctx) { + ggml_backend_metal_reg_free(ctx); + } +}; + +typedef std::unique_ptr ggml_backend_metal_reg_ptr; + static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { - return "Metal"; + return GGML_METAL_NAME; GGML_UNUSED(reg); } static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { - return 1; - - GGML_UNUSED(reg); + ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; + return ctx->devices.size(); } static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { - GGML_ASSERT(index == 0); - - return &g_ggml_metal_device; - - GGML_UNUSED(reg); - GGML_UNUSED(index); + ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; } static ggml_backend_feature g_ggml_backend_metal_features[] = { @@ -698,27 +871,67 @@ static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const static ggml_backend_reg_i ggml_backend_metal_reg_i = { /* .get_name = */ ggml_backend_metal_reg_get_name, - /* .device_count = */ ggml_backend_metal_reg_device_count, - /* .device_get = */ ggml_backend_metal_reg_device_get, + /* .get_device_count = */ ggml_backend_metal_reg_device_count, + /* .get_device = */ ggml_backend_metal_reg_device_get, /* .get_proc_address = */ ggml_backend_metal_get_proc_address, }; +static ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) { + return new ggml_backend_device { + /* .iface = */ ggml_backend_metal_device_i, + /* .reg = */ reg, + /* .context = */ ggml_metal_device_get(device), + }; +} + +static void ggml_backend_metal_device_free(ggml_backend_dev_t dev) { + delete dev; +} + +struct ggml_backend_device_deleter { + void operator()(ggml_backend_dev_t ctx) { + ggml_backend_metal_device_free(ctx); + } +}; + +typedef std::unique_ptr ggml_backend_device_ptr; + ggml_backend_reg_t ggml_backend_metal_reg(void) { + static ggml_backend_reg reg; + static bool initialized = false; + { - g_ggml_metal_reg = { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_metal_reg_i, - /* .context = */ NULL, - }; - - g_ggml_metal_device = { - /* .iface = */ ggml_backend_metal_device_i, - /* .reg = */ &g_ggml_metal_reg, - /* .context = */ ggml_metal_device_get(), - }; + static std::mutex mutex; + std::lock_guard lock(mutex); + + const char * env = getenv("GGML_METAL_DEVICES"); + if (env) { + g_devices = atoi(env); + } + + static std::vector devs; + + if (!initialized) { + static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); + + for (int i = 0; i < g_devices; ++i) { + auto * dev = ggml_backend_metal_device_init(®, i); + devs.emplace_back(dev); + + reg_ctx->devices.push_back(dev); + } + + reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_metal_reg_i, + /* .context = */ reg_ctx.get(), + }; + } + + initialized = true; } - return &g_ggml_metal_reg; + return ® } GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 17e358d1a8d..35cc3bbdfdf 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -895,11 +895,13 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, subtraction, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -template -kernel void kernel_add_fuse_impl( +// OP: 0 - add, 1 - sub, 2 - mul, 3 - div +constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; +constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; +constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]]; + +template +kernel void kernel_bin_fuse_impl( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -907,139 +909,153 @@ kernel void kernel_add_fuse_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +#define FC_OP FC_bin_op +#define FC_F FC_bin_f +#define FC_RB FC_bin_rb - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_RB) { + // row broadcast + const uint i0 = tgpig.x; + const uint i1 = i0%args.ne10; - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); - device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + device const T0 * src0_row = (device const T0 *) (src0); + device T * dst_row = (device T *) (dst); - device const float * src1_ptr[F]; - for (short j = 0; j < F; ++j) { - src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - } + if (FC_F == 1) { + device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + if (FC_OP == 0) { + dst_row[i0] = src0_row[i0] + src1_row[i1]; + } + + if (FC_OP == 1) { + dst_row[i0] = src0_row[i0] - src1_row[i1]; + } - float res = src0_ptr[i0]; + if (FC_OP == 2) { + dst_row[i0] = src0_row[i0] * src1_row[i1]; + } -#pragma unroll - for (short j = 0; j < F; ++j) { - res += src1_ptr[j][i10]; + if (FC_OP == 3) { + dst_row[i0] = src0_row[i0] / src1_row[i1]; + } + } else { + T0 res = src0_row[i0]; + + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + dst_row[i0] = res; } + } else { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - dst_ptr[i0] = res; - } -} + if (i01 >= args.ne01) { + return; + } -typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; -template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; -template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; -template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; -template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; -template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; -template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; -template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; -template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; + device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); -kernel void kernel_sub_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (FC_F == 1) { + device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + if (FC_OP == 0) { + dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; + } - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); - } -} + if (FC_OP == 1) { + dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10]; + } -kernel void kernel_mul_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (FC_OP == 2) { + dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; + } - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_OP == 3) { + dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10]; + } + } + } else { + device const T1 * src1_ptr[8]; + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; - if (args.ne10 == 1) { - const float x = *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } - } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); - } - } -} + T res = src0_ptr[i0]; -kernel void kernel_div_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += src1_ptr[j][i10]; + } + } - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= src1_ptr[j][i10]; + } + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= src1_ptr[j][i10]; + } + } - if (args.ne10 == 1) { - const float x = 1.0f / *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } - } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= src1_ptr[j][i10]; + } + } + + dst_ptr[i0] = res; + } } } + +#undef FC_OP +#undef FC_F +#undef FC_RB } +typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t; + +template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; +template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; + kernel void kernel_add_id( constant ggml_metal_kargs_add_id & args, device const char * src0, @@ -1057,7 +1073,7 @@ kernel void kernel_add_id( const size_t nb1 = args.ne0 * sizeof(float); const size_t nb2 = args.ne1 * nb1; - device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); + device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11); @@ -1098,141 +1114,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; -// assumption: src1 is a row -// broadcast src1 into src0 -template -kernel void kernel_add_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res += ((device const float4 *) (src1 + args.o1[j]))[i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; - -template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; -template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; -template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; -template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; -template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>; -template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>; -template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>; -template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>; - -template -kernel void kernel_sub_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res -= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; - -template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; - -template -kernel void kernel_mul_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res *= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; - -template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; - -template -kernel void kernel_div_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res /= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; - -template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; - kernel void kernel_scale_f32( constant ggml_metal_kargs_scale & args, device const float * src0, @@ -2737,6 +2618,83 @@ kernel void kernel_rwkv_wkv7_f32( } } +constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; +constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; +constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]]; + +kernel void kernel_solve_tri_f32( + constant ggml_metal_kargs_solve_tri & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + ushort3 tgpig[[threadgroup_position_in_grid]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + constexpr short NW = N_SIMDWIDTH; + + const short NSG = FC_solve_tri_nsg; + const short N = FC_solve_tri_n; + const short K = FC_solve_tri_k; + const short NP = PAD2(N, NW); + + const int32_t ne02 = args.ne02; + const int32_t ne03 = args.ne03; + + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x*NSG + sgitg; + + threadgroup float * sh0 = (threadgroup float *) shmem; + + device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N; + device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01; + device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01; + + for (short rr = 0; rr < N; rr += NSG) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + threadgroup float * sh0_cur = sh0 + sgitg*NP; + + for (short t = 0; t*NW < N; ++t) { + const short idx = t*NW + tiisg; + sh0_cur[idx] = src0_ptr[idx]; + } + + src0_ptr += NSG*N; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (i01 >= args.ne10) { + continue; + } + + for (short ir = 0; ir < NSG && rr + ir < N; ++ir) { + const short r = rr + ir; + + threadgroup float * sh0_cur = sh0 + ir*NP; + + float sum = 0.0f; + + for (short t = 0; t*NW < r; ++t) { + const short idx = t*NW + tiisg; + sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r); + } + + sum = simd_sum(sum); + + if (tiisg == 0) { + const float diag = sh0_cur[r]; + + dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag; + } + } + } +} + kernel void kernel_argmax_f32( constant ggml_metal_kargs_argmax & args, device const char * src0, @@ -5208,6 +5166,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E // scan the blocks of the mask that are not masked // 0 - masked (i.e. full of -INF, skip) // 1 - not masked (i.e. at least one element of the mask is not -INF) +// 2 - all zero kernel void kernel_flash_attn_ext_blk( constant ggml_metal_kargs_flash_attn_ext_blk & args, device const char * mask, @@ -5229,27 +5188,29 @@ kernel void kernel_flash_attn_ext_blk( device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; - // fast route - if (res == 0) { - if (simd_max(*mask_src) > -MAXHALF/2) { - res = 1; - } - } - // detailed check of the elements of the block if ((C > NW || Q > 1) && res == 0) { - half m = -MAXHALF; + half mmin = MAXHALF; + half mmax = -MAXHALF; FOR_UNROLL (short j = 0; j < Q; ++j) { FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { - m = max(m, mask_src[ii*NW]); + mmin = min(mmin, mask_src[ii*NW]); + mmax = max(mmax, mask_src[ii*NW]); } mask_src += args.nb31/2; } - if (simd_max(m) > -MAXHALF/2) { - res = 1; + mmin = simd_min(mmin); + mmax = simd_max(mmax); + + if (mmax > -MAXHALF) { + if (mmin == 0.0 && mmax == 0.0) { + res = 2; + } else { + res = 1; + } } } @@ -5491,9 +5452,13 @@ void kernel_flash_attn_ext_impl( ic = 0; } + char blk_cur = 1; + // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { - if (blk[ic0] == 0) { + blk_cur = blk[ic0]; + + if (blk_cur == 0) { FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { pm2[jj] += NW; } @@ -5501,16 +5466,22 @@ void kernel_flash_attn_ext_impl( continue; } - FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { - const short j = jj*NSG + sgitg; + if (blk_cur == 1) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - if (FC_flash_attn_ext_bc_mask) { - sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); - } else { - sm2[j*SH + tiisg] = pm2[jj][tiisg]; - } + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } - pm2[jj] += NW; + pm2[jj] += NW; + } + } else if (blk_cur == 2) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } } #if 0 @@ -5675,10 +5646,12 @@ void kernel_flash_attn_ext_impl( } // mqk = mqk + slope*mask - if (FC_flash_attn_ext_has_bias) { - s2 += s2_t(sm2[j*SH + tiisg])*slope; - } else { - s2 += s2_t(sm2[j*SH + tiisg]); + if (blk_cur != 2) { + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } } M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); @@ -5931,7 +5904,7 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, @@ -6141,11 +6114,10 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE, // head elements per thread - short Q, // queries per threadgroup - short C, // cache items per threadgroup - short NSG> // number of simd groups -void kernel_flash_attn_ext_vec_impl( + short NE = 4, // head elements per thread + short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, @@ -6162,6 +6134,7 @@ void kernel_flash_attn_ext_vec_impl( static_assert(DV % 32 == 0, "DV must be divisible by 32"); #define NWG (FC_flash_attn_ext_vec_nwg) +#define NSG (FC_flash_attn_ext_vec_nsg) #define NS10 (FC_flash_attn_ext_vec_ns10) #define NS20 (FC_flash_attn_ext_vec_ns20) @@ -6190,12 +6163,12 @@ void kernel_flash_attn_ext_vec_impl( const short T = PK + NSG*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results // store the result for all queries in shared memory (the O matrix from the paper) so4 += tiisg; @@ -6213,11 +6186,13 @@ void kernel_flash_attn_ext_vec_impl( // load heads from Q to shared memory device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < PK4; i += NW) { - if (iq1 < args.ne01 && i < DK4) { - sq4[i] = (q4_t) q4[i]; - } else { - sq4[i] = (q4_t) 0.0f; + if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (i < DK4) { + sq4[i] = (q4_t) q4[i]; + } else { + sq4[i] = (q4_t) 0.0f; + } } } @@ -6295,7 +6270,7 @@ void kernel_flash_attn_ext_vec_impl( } // skip -INF blocks - if (simd_max(sm[tiisg]) == -INFINITY) { + if (simd_max(sm[tiisg]) <= -MAXHALF) { continue; } @@ -6569,57 +6544,11 @@ void kernel_flash_attn_ext_vec_impl( } #undef NWG +#undef NSG #undef NS10 #undef NS20 } -template< - typename q4_t, // query types in shared memory - typename k4_t, // key types in shared memory - typename v4_t, // value types in shared memory - typename qk_t, // Q*K types - typename s_t, // soft-max types - typename s4_t, - typename o4_t, // attention accumulation types - typename kd4_t, // key type in device memory - short nl_k, - void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), - typename vd4_t, // value type in device memory - short nl_v, - void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), - short DK, // K head size - short DV, // V head size - short NE = 4, // head elements per thread - short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup - short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext_vec & args, - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device const char * sinks, - device const char * pad, - device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { -#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg - switch (FC_flash_attn_ext_vec_nsg) { - // note: disabled cases to reduce library load time - case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - case 2: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - case 4: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - //case 8: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - } -#undef FWD_TMPL -#undef FWD_ARGS -} - // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max // @@ -8782,6 +8711,26 @@ kernel void kernel_set_rows_f( } } +kernel void kernel_diag_f32( + constant ggml_metal_kargs_diag & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]]) { + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig.z; + const int32_t i2 = tgpig.y; + const int32_t i1 = tgpig.x; + + device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03); + device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3); + + for (int i0 = tiitg; i0 < args.ne0; i0 += NW) { + dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f; + } +} + constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 307ec08242a..fa5fadd112b 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS add add_id argsort + tri fill clamp cpy @@ -84,7 +85,8 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_0_f32_8x_flat mul_mv_q4_0_f32_1d_8x_flat mul_mv_q4_0_f32_1d_16x_flat - mul_mv_q6_k + mul_mv_q6_k_f32 + mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 mul_mv_q8_0_f32_flat mul_mv_mxfp4_f32 @@ -99,6 +101,8 @@ set(GGML_OPENCL_KERNELS mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_q8_0_f32_8x4 + gemv_noshuffle_general_q8_0_f32 mul norm relu diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 8059240b1c4..508b2b8f037 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -226,7 +226,8 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { return ADRENO_GPU_GEN::A7X; } - if (strstr(device_name, "830")) { + if (strstr(device_name, "830") || + strstr(device_name, "840")) { return ADRENO_GPU_GEN::A8X; } @@ -398,6 +399,7 @@ struct ggml_backend_opencl_context { int adreno_wave_size; cl_bool non_uniform_workgroups; + size_t image_max_buffer_size; cl_context context; cl_command_queue queue; @@ -407,6 +409,10 @@ struct ggml_backend_opencl_context { ggml_cl_buffer prealloc_scales_trans; ggml_cl_buffer prealloc_act_trans; + // prealloc buffers for src0 and src1 + ggml_cl_buffer prealloc_src0; + ggml_cl_buffer prealloc_src1; + cl_program program_add; cl_program program_add_id; cl_program program_clamp; @@ -447,7 +453,6 @@ struct ggml_backend_opencl_context { cl_program program_rms_norm; cl_program program_group_norm; cl_program program_rope; - cl_program program_scale; cl_program program_silu; cl_program program_sigmoid; cl_program program_softmax_f32; @@ -456,11 +461,8 @@ struct ggml_backend_opencl_context { cl_program program_softmax_4_f16; cl_program program_argsort_f32_i32; cl_program program_sum_rows_f32; - cl_program program_repeat; cl_program program_pad; - cl_program program_tanh; cl_program program_upscale; - cl_program program_concat; cl_program program_conv_2d_f16; cl_program program_conv_2d_f32; cl_program program_conv_2d_f16_f32; @@ -479,7 +481,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16; cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16; cl_kernel kernel_add_id; - cl_kernel kernel_scale; + cl_kernel kernel_scale_f32, kernel_scale_f32_4; cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4; cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4; cl_kernel kernel_mean_f32; @@ -489,6 +491,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; + cl_kernel kernel_tri; cl_kernel kernel_fill; cl_kernel kernel_clamp; cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick, @@ -523,30 +526,31 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; - cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0; + cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; + cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; + cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat; cl_kernel kernel_solve_tri_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; cl_kernel kernel_sum_rows_f32; - cl_kernel kernel_repeat; + cl_kernel kernel_repeat_f32; cl_kernel kernel_pad; - cl_kernel kernel_tanh_f32_nd; - cl_kernel kernel_tanh_f16_nd; + cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; + cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc; cl_kernel kernel_expm1_f32_nd; cl_kernel kernel_expm1_f16_nd; cl_kernel kernel_softplus_f32_nd; cl_kernel kernel_softplus_f16_nd; cl_kernel kernel_upscale; cl_kernel kernel_upscale_bilinear; - cl_kernel kernel_concat_f32_contiguous; - cl_kernel kernel_concat_f32_non_contiguous; + cl_kernel kernel_concat_f32; cl_kernel kernel_conv_2d_f16; cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; @@ -688,6 +692,8 @@ struct ggml_backend_opencl_context { cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + cl_kernel kernel_mul_mm_q8_0_f32_8x4; + cl_kernel CL_mul_mat_vec_q8_0_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { @@ -793,6 +799,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // tri + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tri.cl.h" + }; +#else + const std::string kernel_src = read_file("tri.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, "kernel_tri_f32", &err), err)); + GGML_LOG_CONT("."); + + CL_CHECK(clReleaseProgram(prog)); + } + // fill { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -868,6 +892,9 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); GGML_LOG_CONT("."); } @@ -1090,14 +1117,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } - // mul_mv_q6_k + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_q6_k.cl.h" + #include "mul_mv_q6_k_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_q6_k.cl"); + const std::string kernel_src = read_file("mul_mv_q6_k_f32.cl"); #endif backend_ctx->program_mul_mv_q6_K = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); @@ -1106,6 +1133,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q6_k_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q6_k_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q6_k_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q6_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q8_0_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1434,10 +1478,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("scale.cl"); #endif - backend_ctx->program_scale = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err)); + CL_CHECK((backend_ctx->kernel_scale_f32 = clCreateKernel(prog, "kernel_scale_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_scale_f32_4 = clCreateKernel(prog, "kernel_scale_f32_4", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -1765,16 +1811,11 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("repeat.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_repeat = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n"); - backend_ctx->program_repeat = nullptr; - backend_ctx->kernel_repeat = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_repeat_f32 = clCreateKernel(prog, "kernel_repeat_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // pad @@ -1807,18 +1848,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("tanh.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_tanh = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_tanh_f32_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f32_nd", &err), err)); - CL_CHECK((backend_ctx->kernel_tanh_f16_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f16_nd", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: tanh kernel source not found or empty. Tanh operation will not be available.\n"); - backend_ctx->program_tanh = nullptr; - backend_ctx->kernel_tanh_f32_nd = nullptr; - backend_ctx->kernel_tanh_f16_nd = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_tanh_f32 = clCreateKernel(prog, "kernel_tanh_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f32_4 = clCreateKernel(prog, "kernel_tanh_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f32_nc = clCreateKernel(prog, "kernel_tanh_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16 = clCreateKernel(prog, "kernel_tanh_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16_4 = clCreateKernel(prog, "kernel_tanh_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16_nc = clCreateKernel(prog, "kernel_tanh_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // expm1 @@ -1910,22 +1949,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #include "concat.cl.h" }; #else - const std::string kernel_src = read_file("concat.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_concat = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err)); - CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n"); - backend_ctx->program_concat = nullptr; - backend_ctx->kernel_concat_f32_contiguous = nullptr; - backend_ctx->kernel_concat_f32_non_contiguous = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, "kernel_concat_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // timestep_embedding @@ -2245,6 +2275,46 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q8_0_f32_8x4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_q8_8x4_gemm { + #include "mul_mm_q8_0_f32_8x4.cl.h" + }; +#else + const std::string kernel_src_q8_8x4_gemm = read_file("mul_mm_q8_0_f32_8x4.cl"); +#endif + backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_q8_8x4_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mm_q8_0_f32_8x4", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_general_q8_0_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "gemv_noshuffle_general_q8_0_f32.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general_q8_0_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable " " -cl-fast-relaxed-math"; @@ -2639,6 +2709,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); + clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL); + GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size); + clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL); GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size); @@ -2892,6 +2965,50 @@ struct ggml_tensor_extra_cl_q8_0 { } }; +struct ggml_tensor_extra_cl_q6_K { + // Lower 4 bits of quantized weights. + cl_mem ql = nullptr; + // Upper 2 bits of quantized weights. + cl_mem qh = nullptr; + // Scales for each block. + cl_mem s = nullptr; + // Scales for each super block. + cl_mem d = nullptr; + + size_t size_ql = 0; + size_t size_qh = 0; + size_t size_s = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q6_K() { + reset(); + } + + void reset() { + if (ql != nullptr) { + CL_CHECK(clReleaseMemObject(ql)); + ql = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + + size_ql = 0; + size_qh = 0; + size_s = 0; + size_d = 0; + } +}; + //------------------------------------------------------------------------------ // Backend API //------------------------------------------------------------------------------ @@ -3182,8 +3299,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_SIGMOID: return ggml_is_contiguous(op->src[0]); case GGML_UNARY_OP_TANH: - return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_EXPM1: return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); @@ -3205,6 +3321,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te default: return false; } + case GGML_OP_TRI: + return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); case GGML_OP_FILL: return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); case GGML_OP_CLAMP: @@ -3436,6 +3554,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) { + delete e; + } + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { + delete e; + } } ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { @@ -3498,6 +3622,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() { + ggml_tensor_extra_cl_q6_K * extra; + if (temp_tensor_extras_q6_K.empty()) { + extra = new ggml_tensor_extra_cl_q6_K(); + } else { + extra = temp_tensor_extras_q6_K.back(); + temp_tensor_extras_q6_K.pop_back(); + } + + temp_tensor_extras_q6_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + void reset() { for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { temp_tensor_extras.push_back(e); @@ -3518,6 +3657,11 @@ struct ggml_backend_opencl_buffer_context { temp_tensor_extras_q8_0.push_back(e); } temp_tensor_extras_q8_0_in_use.clear(); + + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { + temp_tensor_extras_q6_K.push_back(e); + } + temp_tensor_extras_q6_K_in_use.clear(); } // Pools for extras. Available extras are in `temp_tensor_extras`. Extras @@ -3533,6 +3677,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; std::vector temp_tensor_extras_q8_0_in_use; + std::vector temp_tensor_extras_q6_K; + std::vector temp_tensor_extras_q6_K_in_use; // The buffer_context is initially created by ggml_backend_buft_alloc_buffer // before any tensor is initialized (at the beginning of alloc_tensor_range). @@ -3574,7 +3720,7 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff // Reuse extra of the parent tensor. The offset of this view tensor // becomes `extra->offset + view_offs` and needs to be calculated when // it is used. This changes is needed because of the change to - // ggml_alloc.c in https://github.com/ggerganov/llama.cpp/pull/7640. + // ggml_alloc.c in https://github.com/ggml-org/llama.cpp/pull/7640. // `buffer` passed in here will always be `tensor->buffer`. It is OK // to allocate extras from the same buffer context for ordinary // intermediate tensors. But for views into kv cache tensors, doing so @@ -3623,6 +3769,15 @@ inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ct return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); } +inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + + bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); + + size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; + + return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 +} + static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); @@ -4037,6 +4192,216 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; + // Transpose the weights and scales +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (enable_adreno_trans_weight(backend_ctx, tensor)) { + + int M = tensor->ne[1]; // ne01 + int K = tensor->ne[0]; // ne00 + + GGML_ASSERT(K % 32 == 0); + GGML_ASSERT(M % 4 == 0); + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); + + // Transpose weights + size_t q_size_bytes = K * M / 4 * sizeof(float); + cl_buffer_region region; + region.origin = 0; + region.size = q_size_bytes; + cl_mem qT_d = clCreateSubBuffer( + backend_ctx->prealloc_quant_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &err); + CL_CHECK(err); + + cl_mem q_d_image1D; + cl_mem qT_d_image1D; + + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + + img_fmt_1d = { CL_RGBA, CL_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 4 / 4; + img_desc_1d.buffer = extra->q; + q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 4 / 4; + img_desc_1d.buffer = qT_d; + qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + int height_q = M / 4; + int width_q = K / 4 / 4; + kernel = backend_ctx->kernel_transpose_32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); + + size_t local_size_q[3] = {4, 16, 1}; + size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + // Transpose scales + size_t d_size_bytes = M * (K / 32) * 2; + region.origin = 0; + region.size = d_size_bytes; + cl_mem dT_d = clCreateSubBuffer( + backend_ctx->prealloc_scales_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &err); + CL_CHECK(err); + + cl_mem d_d_image1D; + cl_mem dT_d_image1D; + + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_fmt_1d = { CL_R, CL_HALF_FLOAT }; + img_desc_1d.image_width = M * K / 32; + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.buffer = extra->d; + d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 32 / 4; + img_desc_1d.buffer = dT_d; + dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + int height_s = M / 4; + int width_s = K / 32; + + kernel = backend_ctx->kernel_transpose_16_4x1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); + + size_t local_size_s[3] = {4, 16, 1}; + size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + // copy transposed buffer contents to original buffers + CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + CL_CHECK(clReleaseMemObject(qT_d)); + CL_CHECK(clReleaseMemObject(dT_d)); + + CL_CHECK(clReleaseMemObject(q_d_image1D)); + CL_CHECK(clReleaseMemObject(d_d_image1D)); + CL_CHECK(clReleaseMemObject(qT_d_image1D)); + CL_CHECK(clReleaseMemObject(dT_d_image1D)); + } // end transpose +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + return; + } + if (tensor->type == GGML_TYPE_Q6_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K(); + + size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && + "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Subbuffer for ql + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_ql; + extra->ql = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Subbuffer for qh + region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment); + region.size = size_qh; + extra->qh = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Subbuffer for scales + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_s; + extra->s = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for d. + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Flatten the weights + cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + extra->size_ql = size_ql; + extra->size_qh = size_qh; + extra->size_s = size_s; + extra->size_d = size_d; + + tensor->extra = extra; return; } #endif // GGML_OPENCL_SOA_Q @@ -4240,27 +4605,85 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); - cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); - - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {1, 1, 1}; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (enable_adreno_trans_weight(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0_trans; - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clEnqueueReadBuffer( - queue, data_device, CL_TRUE, offset, - size, data, 0, NULL, NULL)); - CL_CHECK(clReleaseMemObject(data_device)); - return; - } -#endif // GGML_OPENCL_SOA_Q + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + GGML_ASSERT(tensor->ne[2] == 1); // ??? + GGML_ASSERT(tensor->ne[3] == 1); // ??? - ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), 1, 1}; + size_t local_work_size[3] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif + cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_Q6_K) { + ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_SOA_Q + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; CL_CHECK(clEnqueueReadBuffer( queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset, @@ -4690,6 +5113,81 @@ static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct gg (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); } +// Copy a noncontiguous tensor to contiguous tensor. ne[] remains the same but +// nb[] is recalculated such that tensor is contiguous. +static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor * src, cl_mem dst, + cl_ulong &nb0, cl_ulong &nb1, cl_ulong &nb2, cl_ulong &nb3) { + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + const int tensor_type_size = ggml_type_size(src->type); + + const int ne00 = src->ne[0]; + const int ne01 = src->ne[1]; + const int ne02 = src->ne[2]; + const int ne03 = src->ne[3]; + + const cl_ulong nb00 = src->nb[0]; + const cl_ulong nb01 = src->nb[1]; + const cl_ulong nb02 = src->nb[2]; + const cl_ulong nb03 = src->nb[3]; + + const int ne0 = src->ne[0]; + const int ne1 = src->ne[1]; + const int ne2 = src->ne[2]; + const int ne3 = src->ne[3]; + + nb0 = tensor_type_size; + nb1 = tensor_type_size*ne00; + nb2 = tensor_type_size*ne00*ne01; + nb3 = tensor_type_size*ne00*ne01*ne02; + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra; + + cl_ulong offset0 = extra->offset + src->view_offs; + cl_ulong offsetd = 0; + + cl_kernel kernel; + + switch (src->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_cpy_f32_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_cpy_f16_f16; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &dst)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3)); + + const int nth = MIN(64, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src); +} + static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { UNUSED(backend); UNUSED(src0); @@ -5965,6 +6463,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } +static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int tri_type = ggml_get_op_params_i32(dst, 0); + const int64_t n = ggml_nelements(dst); + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + cl_kernel kernel = backend_ctx->kernel_tri; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &tri_type)); + + size_t local_work_size[1] = { 256 }; + size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); +} + static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(dst); GGML_ASSERT(dst->extra); @@ -6473,79 +7009,87 @@ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - cl_ulong offset0_abs = extra0->offset + src0->view_offs; - cl_ulong offsetd_abs = extrad->offset + dst->view_offs; + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel; - if (dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_tanh_f32_nd; - } else if (dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_tanh_f16_nd; - } else { - GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh"); - } - GGML_ASSERT(kernel != nullptr); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3]; - const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; - const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3]; - const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3]; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); + cl_kernel kernel; - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32_4; + } else { + kernel = backend_ctx->kernel_tanh_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32; + } else { + kernel = backend_ctx->kernel_tanh_f16; + } + } - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - size_t global_work_size[3]; - if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements - return; - } - global_work_size[0] = (size_t)ne10; - global_work_size[1] = (size_t)ne11; - global_work_size[2] = (size_t)ne12; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; - size_t lws0 = 16, lws1 = 4, lws2 = 1; - if (ne10 < 16) lws0 = ne10; - if (ne11 < 4) lws1 = ne11; - if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } - while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32_nc; + } else { + kernel = backend_ctx->kernel_tanh_f16_nc; + } + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; - size_t local_work_size[] = {lws0, lws1, lws2}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - size_t* local_work_size_ptr = local_work_size; - if (!backend_ctx->non_uniform_workgroups) { - if (global_work_size[0] % local_work_size[0] != 0 || - global_work_size[1] % local_work_size[1] != 0 || - global_work_size[2] % local_work_size[2] != 0) { - local_work_size_ptr = NULL; - } + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } - if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -6763,53 +7307,58 @@ static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, con ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - if (backend_ctx->kernel_repeat == nullptr) { - GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__); - return; - } + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_ulong off_src0 = extra_src0->offset + src0->view_offs; - cl_ulong off_dst = extra_dst->offset + dst->view_offs; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3]; - const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3]; + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; - const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3]; - const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3]; + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; - cl_kernel kernel = backend_ctx->kernel_repeat; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &src0_ne0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &src0_ne1)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &src0_ne2)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &src0_ne3)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &src0_nb0)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &src0_nb1)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &dst_ne0)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &dst_ne1)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &dst_ne2)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dst_ne3)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3)); - - size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1; - size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1; - size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1; - - size_t global_work_size[] = { gws0, gws1, gws2 }; + cl_kernel kernel = backend_ctx->kernel_repeat_f32; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { @@ -7033,121 +7582,76 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con GGML_ASSERT(dst->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - cl_command_queue queue = backend_ctx->queue; - if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) { - GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__); - return; - } + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; - ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra; + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; - cl_ulong off_src0 = extra0_cl->offset + src0->view_offs; - cl_ulong off_src1 = extra1_cl->offset + src1->view_offs; - cl_ulong off_dst = extrad_cl->offset + dst->view_offs; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - const int32_t dim = ((const int32_t *) dst->op_params)[0]; + const cl_int dim = ((const int32_t *) dst->op_params)[0]; GGML_ASSERT(dim >= 0 && dim <= 3); - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - if (dim == 3) { + int nth = MIN(64, ne0); - size_t nbytes_src0 = ggml_nbytes(src0); - size_t nbytes_src1 = ggml_nbytes(src1); + cl_kernel kernel = backend_ctx->kernel_concat_f32; - CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device, - off_src0, off_dst, nbytes_src0, 0, NULL, NULL)); - CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device, - off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL)); - } else { + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int), &dim)); + + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous; - size_t global_work_size[3]; - - for (int i3 = 0; i3 < dst->ne[3]; ++i3) { - cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]); - cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]); - cl_ulong current_off_dst = off_dst + (i3 * dst->nb[3]); - - int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2]; - int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2]; - int d_ne0 = dst->ne[0]; int d_ne1 = dst->ne[1]; int d_ne2 = dst->ne[2]; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), ¤t_off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), ¤t_off_src1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), ¤t_off_dst)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &d_ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne10)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &d_ne11)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &d_ne12)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dim)); - - global_work_size[0] = d_ne0; - global_work_size[1] = d_ne1; - global_work_size[2] = d_ne2; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); - } - } - } else { - cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous; - - cl_long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3]; - cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3]; - - cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3]; - - cl_long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3]; - cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3]; - - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_src1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst)); - - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); - - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); - - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_long), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_long), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_long), &d_ne2)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_long), &d_ne3)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2)); - CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong), &d_nb3)); - CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &dim)); - - size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1, - d_ne2 > 0 ? (size_t)d_ne2 : 1, - d_ne3 > 0 ? (size_t)d_ne3 : 1 }; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_nc, NULL, dst); - } + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { @@ -7598,6 +8102,253 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten CL_CHECK(clReleaseMemObject(D_sub_buffer)); } +static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const enum ggml_type src0t = src0->type; + const enum ggml_type src1t = src1->type; + + GGML_ASSERT(src0t == GGML_TYPE_Q8_0); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + + GGML_ASSERT(src1->view_offs == 0); + GGML_ASSERT(dst->view_offs == 0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + + const int ne10 = src1->ne[0]; + const int ne12 = src1->ne[2]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 == ne10); + GGML_ASSERT((ne00 % 32) == 0); + GGML_ASSERT(ne0 == ne01); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + // init CL objects + cl_int status; + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + cl_buffer_region region; + cl_mem A_image1d; + cl_mem B_image1d; + cl_mem B_sub_buffer; + cl_mem S_image1d; + + cl_mem D_image1d; + cl_mem D_sub_buffer; + + int M = ne01; + int N = ne1; + int K = ne00; + + // create an image for A + img_fmt_1d = { CL_R, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 4; // Divide by 4 for char -> float + img_desc_1d.buffer = extra0_q8_0->q; + A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + // create an image for Scale + img_fmt_1d = { CL_R, CL_HALF_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 32; // Block size is 32 + img_desc_1d.buffer = extra0_q8_0->d; + S_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + // create a sub_buffer for B + region.origin = (extra1->offset); // + src1->view_offs); + region.size = K * N * sizeof(float); + B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create an image for B from sub_buffer: RGBA (OCL) + img_fmt_1d = {CL_RGBA, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = K * N / 4; + img_desc_1d.buffer = B_sub_buffer; + B_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + // Create subbuffer and image1d_buffer for dst + region.origin = (extrad->offset); // + dst->view_offs; + region.size = M * N * sizeof(float); + D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + img_fmt_1d = {CL_R, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * N; + img_desc_1d.buffer = D_sub_buffer; + D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + size_t local_work_size[3] = {1, 1, 1}; + size_t global_work_size[3] = {1, 1, 1}; + + if (N == 1) { + kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32; + + int r2 = 1; + int r3 = 1; + cl_uint k_arg = 0; + + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); + + size_t wavesize = backend_ctx->adreno_wave_size; + local_work_size[0] = wavesize; + local_work_size[1] = 4; // reduce factor + local_work_size[2] = 1; + + global_work_size[0] = ((M + wavesize - 1) / wavesize) * wavesize; + global_work_size[1] = 4; // reduce factor + global_work_size[2] = 1; + } else { + cl_ulong offsetd = extrad->offset + dst->view_offs; + cl_mem B_image1d_trans = nullptr; + // for B transpose + cl_mem B_d = nullptr; + int padding; + + //how many extra elements beyond multiple of 8 + int extra_elements = N % 8; + + //how much padding to add + padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // Specify the starting offset (in bytes) + region.origin = 0; + // Specify the size of the sub-buffer (divide by 2 for FP16) + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + B_d = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + + cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) + cl_image_desc image_desc_B_d_output = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(K * (N + padding)/4), + 0, 0, 0, 0, 0, 0, 0, { B_d } + }; + B_image1d_trans = clCreateImage( + context, + 0, + &image_format_B_d_output, + &image_desc_B_d_output, + NULL, + &status); + CL_CHECK(status); + + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_size_t[2] = { 1, 16 }; + size_t global_size_t[2] = { + static_cast(width_B), + static_cast(padded_height_B) + }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); + + kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4; + + int N_with_padding = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &N_with_padding)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); + + global_work_size[0] = (size_t)(N + 7) / 8; + global_work_size[1] = (size_t)(M + 3) / 4; + global_work_size[2] = 1; + + local_work_size[0] = 2; + local_work_size[1] = 128; + local_work_size[2] = 1; + } + + // enqueue kernel with profiling + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(A_image1d)); + CL_CHECK(clReleaseMemObject(B_sub_buffer)); + CL_CHECK(clReleaseMemObject(B_image1d)); + CL_CHECK(clReleaseMemObject(S_image1d)); + CL_CHECK(clReleaseMemObject(D_sub_buffer)); + CL_CHECK(clReleaseMemObject(D_image1d)); +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -7623,6 +8374,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif const int ne00 = src0 ? src0->ne[0] : 0; @@ -7665,9 +8417,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_context context = backend_ctx->context; if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ - if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) { + if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0 && + // dst is wrapped with image1d_buffer, the size limit applies, also src0 + (ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4 <= backend_ctx->image_max_buffer_size)) { // For KQ if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && + ((nb01 * ne01 / 4)/4 <= backend_ctx->image_max_buffer_size) && nb00 <= nb02 && nb02 <= nb01 && nb01 <= nb03 && @@ -7678,7 +8433,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } // For KQV - if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && + ((nb02 * ne02 / 4)/4 <= backend_ctx->image_max_buffer_size)) { ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); return; } @@ -7710,6 +8466,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co int padding; // <--------------------------------------------> // + // q8_0 x fp32 + if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && + enable_adreno_trans_weight(backend_ctx, src0)) { + ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst); + return; + } + // q4_0 x fp32 if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { // TODO: remove duplicate definitions of image description + format -- move to top @@ -7984,9 +8747,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co // GEMM using local memory // Current BK = 16, so ne00 % 16 == 0 - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && - src1t == GGML_TYPE_F32 && + if (src1t == GGML_TYPE_F32 && ne00 % 16 == 0 && ne11 > 1) { switch(src0t) { @@ -7998,10 +8759,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co int batch_stride_b = ne10*ne11; int batch_stride_d = ne0*ne1; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + cl_mem mem_src0 = extra0->data_device; + cl_mem mem_src1 = extra1->data_device; + + cl_ulong nb00_cont = nb00; + cl_ulong nb01_cont = nb01; + cl_ulong nb02_cont = nb02; + cl_ulong nb03_cont = nb03; + + cl_ulong nb10_cont = nb10; + cl_ulong nb11_cont = nb11; + cl_ulong nb12_cont = nb12; + cl_ulong nb13_cont = nb13; + + cl_ulong offset0_cont = offset0; + cl_ulong offset1_cont = offset1; + + if (!ggml_is_contiguous(src0)) { + backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0)); + ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer, + nb00_cont, nb01_cont, nb02_cont, nb03_cont); + mem_src0 = backend_ctx->prealloc_src0.buffer; + offset0_cont = 0; + } + + if (!ggml_is_contiguous(src1)) { + backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1)); + ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer, + nb10_cont, nb11_cont, nb12_cont, nb13_cont); + mem_src1 = backend_ctx->prealloc_src1.buffer; + offset1_cont = 0; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); @@ -8033,10 +8826,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co int batch_stride_b = ne10*ne11; int batch_stride_d = ne0*ne1; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + cl_mem mem_src0 = extra0->data_device; + cl_mem mem_src1 = extra1->data_device; + + cl_ulong nb00_cont = nb00; + cl_ulong nb01_cont = nb01; + cl_ulong nb02_cont = nb02; + cl_ulong nb03_cont = nb03; + + cl_ulong nb10_cont = nb10; + cl_ulong nb11_cont = nb11; + cl_ulong nb12_cont = nb12; + cl_ulong nb13_cont = nb13; + + cl_ulong offset0_cont = offset0; + cl_ulong offset1_cont = offset1; + + if (!ggml_is_contiguous(src0)) { + backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0)); + ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer, + nb00_cont, nb01_cont, nb02_cont, nb03_cont); + mem_src0 = backend_ctx->prealloc_src0.buffer; + offset0_cont = 0; + } + + if (!ggml_is_contiguous(src1)) { + backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1)); + ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer, + nb10_cont, nb11_cont, nb12_cont, nb13_cont); + mem_src1 = backend_ctx->prealloc_src1.buffer; + offset1_cont = 0; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); @@ -8064,6 +8889,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co if (ne11 < 32) { break; } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm; nth0 = 128; // calculated as (BM*BN)/(TM*TN) @@ -8436,14 +9265,49 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 2; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3)); +#else kernel = backend_ctx->kernel_mul_mv_q6_K_f32; if (backend_ctx->gpu_family == INTEL) { - nth0 = 2; - nth1 = 16; + nth0 = 16; + nth1 = 2; + ndst = 1; } else if (backend_ctx->gpu_family == ADRENO) { - nth0 = 2; - nth1 = 64; + nth0 = 64; + nth1 = 2; + ndst = 1; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } @@ -8463,6 +9327,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q break; case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_SOA_Q @@ -8565,7 +9430,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } else if (src0t == GGML_TYPE_Q5_K) { GGML_ASSERT(false && "not implemented"); } else if (src0t == GGML_TYPE_Q6_K) { - size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); @@ -8997,7 +9862,16 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel = backend_ctx->kernel_scale; + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_scale_f32_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_scale_f32; + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -9006,8 +9880,6 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias)); - int n = ggml_nelements(dst)/4; - size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; @@ -10012,6 +10884,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_glu; break; + case GGML_OP_TRI: + if (!any_on_device) { + return false; + } + func = ggml_cl_tri; + break; case GGML_OP_FILL: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl index 132758469c6..0c1b3d785ca 100644 --- a/ggml/src/ggml-opencl/kernels/concat.cl +++ b/ggml/src/ggml-opencl/kernels/concat.cl @@ -1,109 +1,51 @@ -kernel void kernel_concat_f32_contiguous( - global const char * p_src0, ulong off_src0, - global const char * p_src1, ulong off_src1, - global char * p_dst, ulong off_dst, - int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice - int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes) - int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice - int dim +kernel void kernel_concat_f32( + global const char * src0, + ulong offset0, + global const char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int dim ) { - global const float * src0 = (global const float*)((global char*)p_src0 + off_src0); - global const float * src1 = (global const float*)((global char*)p_src1 + off_src1); - global float * dst = (global float*)((global char*)p_dst + off_dst); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; - int i0 = get_global_id(0); // Index along dst's 0th dimension - int i1 = get_global_id(1); // Index along dst's 1st dimension - int i2 = get_global_id(2); // Index along dst's 2nd dimension + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) { - return; - } - - ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0; - ulong src_idx; - - if (dim == 0) { - if (i0 < d_ne00) { // Data from src0 - src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; - dst[dst_idx] = src0[src_idx]; - } else { // Data from src1 - src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00); - dst[dst_idx] = src1[src_idx]; - } - } else if (dim == 1) { - if (i1 < d_ne01) { // Data from src0 - src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; - dst[dst_idx] = src0[src_idx]; - } else { // Data from src1 - src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0; - dst[dst_idx] = src1[src_idx]; - } - } else if (dim == 2) { - if (i2 < d_ne02) { // Data from src0 - src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; - dst[dst_idx] = src0[src_idx]; - } else { // Data from src1 - - src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0; - dst[dst_idx] = src1[src_idx]; - } - } -} - -kernel void kernel_concat_f32_non_contiguous( - global const char * p_src0, ulong off_src0, - global const char * p_src1, ulong off_src1, - global char * p_dst, ulong off_dst, - - long ne00, long ne01, long ne02, long ne03, - ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1 + global const float * x; - long d_ne0, long d_ne1, long d_ne2, long d_ne3, - ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3, - int dim -) { - global const char * src0_base = p_src0 + off_src0; - global const char * src1_base = p_src1 + off_src1; - global char * dst_base = p_dst + off_dst; - - long current_i1 = get_global_id(0); // Index for dst_dim_1 - long current_i2 = get_global_id(1); // Index for dst_dim_2 - long current_i3 = get_global_id(2); // Index for dst_dim_3 - - if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) { - return; - } - - global const float * x_val_ptr; - global float * y_val_ptr; - - for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) { - bool use_src0; - long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3; - - if (dim == 0) { - use_src0 = (current_i0 < ne00); - if (!use_src0) { s_i0 = current_i0 - ne00; } - } else if (dim == 1) { - use_src0 = (current_i1 < ne01); - if (!use_src0) { s_i1 = current_i1 - ne01; } - } else if (dim == 2) { - use_src0 = (current_i2 < ne02); - if (!use_src0) { s_i2 = current_i2 - ne02; } - } else { // dim == 3 - use_src0 = (current_i3 < ne03); - if (!use_src0) { s_i3 = current_i3 - ne03; } - } - - if (use_src0) { - x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00); + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (global const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10); + x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0); - *y_val_ptr = *x_val_ptr; + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; } } diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 513a4d3e28f..9fb434713df 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -46,6 +46,16 @@ struct block_q4_0 uint8_t qs[QK4_0 / 2]; }; +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +struct block_q6_K { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +}; + //------------------------------------------------------------------------------ // kernel_convert_block_q4_0 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). @@ -263,3 +273,94 @@ kernel void kernel_restore_block_q8_0( b->qs[i] = q[i]; } } + +kernel void kernel_restore_block_q8_0_trans( + global uchar * src_q, + global half * src_d, + global block_q8_0 * dst, + uint ne00, + uint ne01 +){ + uint num_blk_per_row = ne00 / QK8_0; + + global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0) * num_blk_per_row; + global uchar * q = (global uchar *) src_q + get_global_id(0) * 4; // 4 8-bit packed + global half * d = (global half *) src_d + get_global_id(0); + + for (uint blk = 0; blk < num_blk_per_row; blk++) { + b->d = *d; + + for (uint i = 0; i < QK8_0; i+=4) { + b->qs[i] = q[0]; + b->qs[i+1] = q[1]; + b->qs[i+2] = q[2]; + b->qs[i+3] = q[3]; + + q += 4 * ne01; // M stride + } + + d += ne01; + + b++; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q6_K +// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +// Each thread processes a super block. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q6_K( + global struct block_q6_K * src0, + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d +) { + global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK_K/2; ++i) { + ql[i] = b->ql[i]; + } + for (int i = 0; i < QK_K/4; ++i) { + qh[i] = b->qh[i]; + } + for (int i = 0; i < QK_K/16; ++i) { + s[i] = b->scales[i]; + } +} + +// Restore block_q6_K from flattened arrays. +// Each thread processes a super block. +kernel void kernel_restore_block_q6_K( + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d, + global struct block_q6_K * dst +) { + global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK_K/2; ++i) { + b->ql[i] = ql[i]; + } + for (int i = 0; i < QK_K/4; ++i) { + b->qh[i] = qh[i]; + } + for (int i = 0; i < QK_K/16; ++i) { + b->scales[i] = s[i]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl new file mode 100644 index 00000000000..f944ef3a992 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl @@ -0,0 +1,195 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK8_0 32 +#define N_SIMDGROUP 4 + +#define dequantizeBlockAccum_ns_sgbroadcast_1(total_sums, bits8, scale, y) \ + float shared_y; \ + char elem; \ + \ + shared_y = sub_group_broadcast(y.s0, 0); \ + elem = (char)(bits8.s0 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + elem = (char)((bits8.s0 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + elem = (char)((bits8.s0 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + elem = (char)((bits8.s0 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 0); \ + elem = (char)(bits8.s1 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + elem = (char)((bits8.s1 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + elem = (char)((bits8.s1 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + elem = (char)((bits8.s1 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s0, 1); \ + elem = (char)(bits8.s2 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + elem = (char)((bits8.s2 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + elem = (char)((bits8.s2 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + elem = (char)((bits8.s2 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 1); \ + elem = (char)(bits8.s3 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + elem = (char)((bits8.s3 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + elem = (char)((bits8.s3 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + elem = (char)((bits8.s3 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s0, 2); \ + elem = (char)(bits8.s4 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + elem = (char)((bits8.s4 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + elem = (char)((bits8.s4 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + elem = (char)((bits8.s4 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 2); \ + elem = (char)(bits8.s5 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + elem = (char)((bits8.s5 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + elem = (char)((bits8.s5 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + elem = (char)((bits8.s5 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s0, 3); \ + elem = (char)(bits8.s6 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + elem = (char)((bits8.s6 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + elem = (char)((bits8.s6 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + elem = (char)((bits8.s6 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 3); \ + elem = (char)(bits8.s7 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + elem = (char)((bits8.s7 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + elem = (char)((bits8.s7 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + elem = (char)((bits8.s7 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle( + __read_only image1d_buffer_t src0_q, // quantized A + global half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B + ulong offset1, // offset to B (0) + global float * dst, // C + ulong offsetd, // offset to C + int ne00, // K + int ne01, // M + int ne02, // 1 + int ne10, // K + int ne12, // 1 + int ne0, // M + int ne1, // N + int r2, // 1 + int r3) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M; + uint BLOCK_STRIDE_A = 8 * M; // 32 / 4 = 8 + + __private uint8 regA; + __private half regS; + __private float8 regB; + + __private float totalSum = (float)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + #pragma unroll 1 /* tell compiler not to unroll */ + for (uint k = groupId; k < (K / QK8_0); k += N_SIMDGROUP) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of one rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load weights for one block in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + regA.s4 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s5 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; + + dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB); + } + + // reduction in local memory, assumes #wave=4 + __local float reduceLM[SIMDGROUP_WIDTH * 3]; + if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum; + if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum; + if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum; + barrier(CLK_LOCAL_MEM_FENCE); + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + dst[gid] = totalSum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl new file mode 100644 index 00000000000..51ce2121ce2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl @@ -0,0 +1,129 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_mul_mm_q8_0_f32_8x4( + global const uint * src0_q, + global const half * src0_d, + __read_only image1d_buffer_t src1, + global float * dst, + int k, + int m, + int n, + int n_no_padding, + ulong offsetd +) { + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + dst = (global float *)((global char*)dst + offsetd); + + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 deq; + + __global const uint* wptr = src0_q + gx_2; + __global const half* sptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + uint4 pack4 = vload4(0, wptr + (i / 4) * m); + half4 scale = vload4(0, sptr + (i / 32) * m); + + char4 p0 = as_char4(pack4.s0); + char4 p1 = as_char4(pack4.s1); + char4 p2 = as_char4(pack4.s2); + char4 p3 = as_char4(pack4.s3); + + // ------------------- j = 0 (k = i+0) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 0) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 0) * n_4 + 1); + + half4 wj0 = convert_half4((char4)(p0.s0, p1.s0, p2.s0, p3.s0)) * scale; + + c0 += B * wj0.s0; + c1 += B * wj0.s1; + c2 += B * wj0.s2; + c3 += B * wj0.s3; + + // ------------------- j = 1 (k = i+1) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 1) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 1) * n_4 + 1); + + half4 wj1 = convert_half4((char4)(p0.s1, p1.s1, p2.s1, p3.s1)) * scale; + + c0 += B * wj1.s0; + c1 += B * wj1.s1; + c2 += B * wj1.s2; + c3 += B * wj1.s3; + + // ------------------- j = 2 (k = i+2) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 2) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 2) * n_4 + 1); + + half4 wj2 = convert_half4((char4)(p0.s2, p1.s2, p2.s2, p3.s2)) * scale; + + c0 += B * wj2.s0; + c1 += B * wj2.s1; + c2 += B * wj2.s2; + c3 += B * wj2.s3; + + // ------------------- j = 3 (k = i+3) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 3) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 3) * n_4 + 1); + + half4 wj3 = convert_half4((char4)(p0.s3, p1.s3, p2.s3, p3.s3)) * scale; + + c0 += B * wj3.s0; + c1 += B * wj3.s1; + c2 += B * wj3.s2; + c3 += B * wj3.s3; + } + + int idx = (gy << 3) * m + (gx << 2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl rename to ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl new file mode 100644 index 00000000000..86fe09c6dd6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl @@ -0,0 +1,194 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// kernel_mul_mv_q6_K_f32_flat +//------------------------------------------------------------------------------ +#define Q6_K_MASK1 0x03 +#define Q6_K_MASK2 0x0C +#define Q6_K_MASK3 0x30 +#define Q6_K_MASK4 0xC0 + +#define QK_K 256 + +inline float block_q_6_K_dot_y_flat( + global uchar * blk_ql, + global uchar * blk_qh, + global char * blk_scales, + global half * blk_d, + global float * yy, + int ib, + int ip, + int is, + int l0 +) { + int y_offset = 128*ip + l0; + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + global uchar * q1 = blk_ql + ib*128 + q_offset_l; + global uchar * q2 = q1 + QK_K/8; + global uchar * qh = blk_qh + ib*64 + q_offset_h; + global char * sc = blk_scales + ib*16 + is; + + global float * y = yy + ib * QK_K + y_offset; + + float dall = blk_d[ib]; + + float sumf = 0; + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f); + sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f); + sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f); + sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f); + + sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f); + sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f); + sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f); + sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f); + + sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f); + sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f); + sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f); + sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f); + + sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f); + sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f); + sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f); + sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f); + + sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); + + return sumf; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q6_K_f32_flat( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + int first_row = (N_SIMDGROUP * r0 + get_sub_group_id()) * N_DST; + + ulong offset_src0 = first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + ulong offset_src0_ql = offset_src0 * 128; + ulong offset_src0_qh = offset_src0 * 64; + ulong offset_src0_s = offset_src0 * 16; + ulong offset_src0_d = offset_src0; + + global uchar * blk_ql = (global uchar *) src0_ql + offset_src0_ql; + global uchar * blk_qh = (global uchar *) src0_qh + offset_src0_qh; + global char * blk_scales = (global char *) src0_s + offset_src0_s; + global half * blk_d = (global half *) src0_d + offset_src0_d; + global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 + int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + float4 sumf = 0; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + if (first_row + 0 < ne01) { + sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0); + } + if (first_row + 1 < ne01) { + sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0); + } + if (first_row + 2 < ne01) { + sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0); + } + if (first_row + 3 < ne01) { + sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0); + } + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), + sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), + sub_group_reduce_add(sumf.s3) + ); + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/repeat.cl b/ggml/src/ggml-opencl/kernels/repeat.cl index 079498f5ab9..53951a55434 100644 --- a/ggml/src/ggml-opencl/kernels/repeat.cl +++ b/ggml/src/ggml-opencl/kernels/repeat.cl @@ -1,39 +1,38 @@ -kernel void kernel_repeat( - global const char * src0_data_in, - global char * dst_data_in, - ulong src0_offset, - ulong dst_offset, - int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3, - ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3, - int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3, - ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3 +kernel void kernel_repeat_f32( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - global const char * src0_data = src0_data_in + src0_offset; - global char * dst_data = dst_data_in + dst_offset; + src0 = src0 + offset0; + dst = dst + offsetd; - const int d3 = get_global_id(2); - const int d2 = get_global_id(1); - const int d1 = get_global_id(0); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) { - return; - } - - const int s3 = d3 % src0_ne3; - const int s2 = d2 % src0_ne2; - const int s1 = d1 % src0_ne1; - - const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1; - global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1; + const int i03 = i3%ne03; + const int i02 = i2%ne02; + const int i01 = i1%ne01; - for (int d0 = 0; d0 < dst_ne0; ++d0) { - // Determine source index for dimension 0 based on tiling/broadcasting. - const int s0 = d0 % src0_ne0; + global const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1; - const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0; - global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0; - for (int k = 0; k < src0_nb0; ++k) { - current_dst_el_ptr[k] = current_src_el_ptr[k]; - } + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i00 = i0%ne00; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i00*nb00)); } } diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl index aeca8a456e4..17ed97f0d66 100644 --- a/ggml/src/ggml-opencl/kernels/scale.cl +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -1,9 +1,19 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -//------------------------------------------------------------------------------ -// scale -//------------------------------------------------------------------------------ -kernel void kernel_scale( +kernel void kernel_scale_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float scale, + float bias +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias; +} + +kernel void kernel_scale_f32_4( global float4 * src0, ulong offset0, global float4 * dst, diff --git a/ggml/src/ggml-opencl/kernels/tanh.cl b/ggml/src/ggml-opencl/kernels/tanh.cl index d9da86b1489..2c4887ad3e0 100644 --- a/ggml/src/ggml-opencl/kernels/tanh.cl +++ b/ggml/src/ggml-opencl/kernels/tanh.cl @@ -1,63 +1,109 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -#ifdef cl_intel_required_subgroup_size -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#endif - -kernel void kernel_tanh_f32_nd( - global void * p_src0_base, ulong off_src0_abs, - global void * p_dst_base, ulong off_dst_abs, - int ne00, int ne01, int ne02, int ne03, - ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, int ne11, int ne12, int ne13, - ulong nb10, ulong nb11, ulong nb12, ulong nb13 +kernel void kernel_tanh_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = tanh(*src_val_ptr); - } + *y = tanh(*x); } } -kernel void kernel_tanh_f16_nd( - global void * p_src0_base, ulong off_src0_abs, - global void * p_dst_base, ulong off_dst_abs, - int ne00, int ne01, int ne02, int ne03, - ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, int ne11, int ne12, int ne13, - ulong nb10, ulong nb11, ulong nb12, ulong nb13 +kernel void kernel_tanh_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = tanh(*src_val_ptr); - } + *y = tanh(*x); } } diff --git a/ggml/src/ggml-opencl/kernels/tri.cl b/ggml/src/ggml-opencl/kernels/tri.cl new file mode 100644 index 00000000000..35cdd543bc5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tri.cl @@ -0,0 +1,32 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// tri +//------------------------------------------------------------------------------ +__kernel void kernel_tri_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n, + int ne0, + int ne1, + int tri_type +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int idx = get_global_id(0); + if (idx >= n) return; + + int i0 = idx % ne0; + int i1 = (idx / ne0) % ne1; + + int keep = 0; + if (tri_type == 0) keep = (i0 >= i1); + else if (tri_type == 1) keep = (i0 > i1); + else if (tri_type == 2) keep = (i0 <= i1); + else keep = (i0 < i1); + + dst[idx] = keep ? src0[idx] : 0.0f; +} diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 5a89d8dd688..eefdd9725ca 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -1,7 +1,7 @@ message(STATUS "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}") -if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$") - message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD") +if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL)$") + message(FATAL_ERROR "GGML_SYCL_TARGET: Invalid target, the supported options are [INTEL]") endif() check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL) @@ -125,106 +125,27 @@ endif() target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL}) if (GGML_SYCL_F16) - if (GGML_SYCL_TARGET STREQUAL "AMD") - message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.") - endif() add_compile_definitions(GGML_SYCL_F16) endif() if (GGML_SYCL_TARGET STREQUAL "INTEL") add_compile_definitions(GGML_SYCL_WARP_SIZE=16) target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) -elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") - add_compile_definitions(GGML_SYCL_WARP_SIZE=32) -elseif (GGML_SYCL_TARGET STREQUAL "AMD") - # INFO: Allowed Sub_group_sizes are not consistent through all - # hip targets. For example, 64 is used for certain models, but the backend - # does not support it. - # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) - add_compile_definitions(GGML_SYCL_WARP_SIZE=32) -else() - # default for other target - add_compile_definitions(GGML_SYCL_WARP_SIZE=32) -endif() - -if (GGML_SYCL_GRAPH) - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) -endif() -# Link against Intel oneMKL or oneMath -if (GGML_SYCL_TARGET STREQUAL "INTEL") - # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically - # See https://github.com/uxlfoundation/oneMath/issues/654 + # Link against Intel oneMKL if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") set(SYCL_COMPILER ON) endif() find_package(MKL REQUIRED) target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS) - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL) else() - find_package(oneMath QUIET) - if (NOT oneMath_FOUND) - message(STATUS "oneMath not found: oneMath will be automatically downloaded") - # Use FetchContent to automatically pull and build oneMath - include(FetchContent) - set(BUILD_FUNCTIONAL_TESTS False) - set(BUILD_EXAMPLES False) - set(TARGET_DOMAINS blas) - if (GGML_SYCL_TARGET STREQUAL "NVIDIA") - set(ENABLE_MKLCPU_BACKEND False) - set(ENABLE_MKLGPU_BACKEND False) - set(ENABLE_CUBLAS_BACKEND True) - elseif (GGML_SYCL_TARGET STREQUAL "AMD") - set(ENABLE_MKLCPU_BACKEND False) - set(ENABLE_MKLGPU_BACKEND False) - set(ENABLE_ROCBLAS_BACKEND True) - # Ensure setting a string variable here is not overriden by oneMath CACHE variables - cmake_policy(SET CMP0126 NEW) - # Setting the device architecture is only needed and useful for AMD devices in oneMath - set(HIP_TARGETS ${GGML_SYCL_DEVICE_ARCH} CACHE STRING "oneMath HIP target" FORCE) - endif() - FetchContent_Declare( - ONEMATH - GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git - GIT_TAG 8efe85f5aaebb37f1d8c503b7af66315feabf142 - ) - FetchContent_MakeAvailable(ONEMATH) - # Create alias to match with find_package targets name - function(onemath_alias target) - if (TARGET ${target}_obj) - # Silence verbose warnings from external libraries - target_compile_options(${target}_obj PRIVATE -w) - endif() - if (TARGET ${target}) - add_library(ONEMATH::${target} ALIAS ${target}) - endif() - endfunction() - onemath_alias(onemath) - onemath_alias(onemath_blas_mklcpu) - onemath_alias(onemath_blas_mklgpu) - onemath_alias(onemath_blas_cublas) - onemath_alias(onemath_blas_rocblas) - endif() + # default for other target + message(FATAL_ERROR "GGML_SYCL_TARGET is not supported") + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) +endif() - # Below oneMath compile-time dispatching is used for better performance - if (GGML_SYCL_TARGET STREQUAL "NVIDIA") - target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_cublas) - target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") - target_link_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA) - elseif (GGML_SYCL_TARGET STREQUAL "AMD") - if (NOT GGML_SYCL_DEVICE_ARCH) - message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") - endif() - target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas) - target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") - target_link_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_AMD) - else() - # Fallback to oneMath runtime dispatcher - target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath) - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GENERIC) - endif() +if (GGML_SYCL_GRAPH) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) endif() if (GGML_SYCL_DEVICE_ARCH) diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index 30ec1e8dafc..ece66a7ac1f 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -15,18 +15,9 @@ #include #include -#include -#include - -#ifdef GGML_SYCL_USE_INTEL_ONEMKL #include -// Allow to use the same namespace for Intel oneMKL and oneMath -namespace oneapi { - namespace math = mkl; -} -#else -#include -#endif + +#include #include "ggml.h" @@ -92,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { } template struct matrix_info_t { - oneapi::math::transpose transpose_info[2]; + oneapi::mkl::transpose transpose_info[2]; Ts value_info[2]; std::int64_t size_info[3]; std::int64_t ld_info[3]; std::int64_t groupsize_info; }; -inline auto get_onemath_backend(sycl::queue& queue) -#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) - -> sycl::queue& -#endif -{ -// If the backend is known at compile-time, use oneMath backend_selector to use -// compile-time dispatching and avoid the need to dlopen libraries. Otherwise -// fallback to runtime dispatching. -#if defined(GGML_SYCL_NVIDIA) - return oneapi::math::backend_selector{ queue }; -#elif defined(GGML_SYCL_AMD) - return oneapi::math::backend_selector{ queue }; -#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) - return queue; -#else - static_assert(false, "Unsupported backend"); -#endif -} - namespace dpct { typedef sycl::queue *queue_ptr; @@ -1735,7 +1707,7 @@ namespace dpct namespace detail { template - inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb, const void * beta, void * c, int ldc) { Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); @@ -1743,7 +1715,7 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); - oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a, + oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, beta_value, data_c, ldc); } @@ -1775,7 +1747,7 @@ namespace dpct }; template - inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, int ldb, const void * beta, void ** c, int ldc, int batch_size, matrix_info_t * matrix_info) { @@ -1794,8 +1766,8 @@ namespace dpct matrix_info->ld_info[2] = ldc; matrix_info->groupsize_info = batch_size; - sycl::event e = oneapi::math::blas::column_major::gemm_batch( - get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1, + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, @@ -1804,7 +1776,7 @@ namespace dpct } template - inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, int lda, long long int stride_a, const void * b, int ldb, long long int stride_b, const void * beta, void * c, int ldc, long long int stride_c, int batch_size) { @@ -1813,7 +1785,7 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); - oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, + oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c, batch_size); } @@ -2300,7 +2272,7 @@ namespace dpct sycl::range<3>(x, y, 1), direction); } - inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n, + inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b, library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc, library_data_t scaling_type) { @@ -2367,7 +2339,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_impl( + detail::gemm_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); break; } @@ -2406,7 +2378,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_impl( + detail::gemm_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); break; } @@ -2448,7 +2420,7 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, @@ -2486,7 +2458,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } @@ -2494,7 +2466,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } @@ -2570,7 +2542,7 @@ namespace dpct /// \param [in] stride_c Stride between the different C matrices. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda, long long int stride_a, const void * b, library_data_t b_type, int ldb, long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc, @@ -2643,7 +2615,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); break; @@ -2652,7 +2624,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); break; diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 8d83b2446bd..00d54b83f82 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -123,6 +123,15 @@ static __dpct_inline__ T op_log(T x) { return sycl::log(x); } +template +static __dpct_inline__ T op_softplus(T x) { + const float xf = (float) x; + const float ax = sycl::fabs(xf); + const float m = sycl::fmax(xf, 0.0f); + const float y = m + sycl::log1p(sycl::exp(-ax)); + return (T) y; +} + template static __dpct_inline__ T op_neg(T x) { return -x; @@ -695,6 +704,12 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor }); } +static inline void ggml_sycl_op_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_softplus(x); + }); +} + static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { return op_neg(x); @@ -821,16 +836,9 @@ static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tens } static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, - [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { - const int num_blocks = ceil_div(k_elements, 256); - stream->parallel_for( - sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), - sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { - unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1); - }); - }); + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_ceil(x); + }); } static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { @@ -1101,6 +1109,11 @@ void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_log(ctx, dst); } +void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_softplus(ctx, dst); +} + void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_neg(ctx, dst); diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 0913a2e529b..7c71974687a 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -61,6 +61,8 @@ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index bb8acc922b9..0614d7e8f3a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1157,13 +1157,28 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ GGML_UNUSED(buft); } +inline void * aligned_malloc_host(size_t alignment, size_t size) { +#ifdef _WIN32 + return _aligned_malloc(size, alignment); +#else + return aligned_alloc(alignment, size); +#endif +} + +inline void free_aligned_mem_host(void * memblock) { +#ifdef _WIN32 + _aligned_free(memblock); +#else + free(memblock); +#endif +} + static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_sycl_host_free(buffer->context); + free_aligned_mem_host((void *)buffer->context); } static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - void * ptr = ggml_sycl_host_malloc(size); - + void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size); if (ptr == nullptr) { // fallback to cpu buffer return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); @@ -1825,6 +1840,110 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, } } +static void top_k_f32_sycl( + const float * src, + int32_t * dst_indices, + const int64_t ncols, + const int64_t nrows, + const int k, + dpct::queue_ptr main_stream +) { + const int block_size = 128; + + const sycl::range<1> block_dims(block_size); + const sycl::range<1> grid_dims(nrows); + + main_stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor shared_vals(sycl::range<1>(block_size * k), cgh); + sycl::local_accessor shared_idx(sycl::range<1>(block_size * k), cgh); + + cgh.parallel_for( + sycl::nd_range<1>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<1> item_ct1) { + const int row = item_ct1.get_group(0); + const int tid = item_ct1.get_local_id(0); + + if (row >= nrows) return; + + const float * src_row = src + row * ncols; + int32_t * dst_idx_row = dst_indices + row * k; + + float local_vals[32]; + int local_idx[32]; + + for (int i = 0; i < k; i++) { + local_vals[i] = -FLT_MAX; + local_idx[i] = -1; + } + + for (int col = tid; col < ncols; col += block_size) { + float val = src_row[col]; + + if (val > local_vals[k-1]) { + int pos = k - 1; + while (pos > 0 && val > local_vals[pos - 1]) { + pos--; + } + + for (int i = k - 1; i > pos; i--) { + local_vals[i] = local_vals[i - 1]; + local_idx[i] = local_idx[i - 1]; + } + local_vals[pos] = val; + local_idx[pos] = col; + } + } + + for (int i = 0; i < k; i++) { + shared_vals[tid * k + i] = local_vals[i]; + shared_idx[tid * k + i] = local_idx[i]; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + if (tid == 0) { + float final_vals[32]; + int final_idx[32]; + + for (int i = 0; i < k; i++) { + final_vals[i] = -FLT_MAX; + final_idx[i] = -1; + } + + for (int t = 0; t < block_size; t++) { + for (int i = 0; i < k; i++) { + float val = shared_vals[t * k + i]; + int idx = shared_idx[t * k + i]; + + if (val > final_vals[k-1]) { + int pos = k - 1; + while (pos > 0 && val > final_vals[pos - 1]) { + pos--; + } + + for (int j = k - 1; j > pos; j--) { + final_vals[j] = final_vals[j - 1]; + final_idx[j] = final_idx[j - 1]; + } + final_vals[pos] = val; + final_idx[pos] = idx; + } + } + } + + for (int i = 0; i < k; i++) { + dst_idx_row[i] = final_idx[i]; + } + + if (k > 1) { + int32_t temp = dst_idx_row[0]; + dst_idx_row[0] = dst_idx_row[1]; + dst_idx_row[1] = temp; + } + } + }); + }); +} + static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols, const int nrows, queue_ptr stream) { const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE); @@ -2048,8 +2167,8 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( - *stream, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, + *stream, oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, dst_f16.get(), dpct::library_data_t::real_half, ldc, @@ -2092,8 +2211,8 @@ inline void ggml_sycl_op_mul_mat_sycl( { const float alpha = 1.0f; const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( - get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); } @@ -2216,6 +2335,30 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * main_stream, ctx.device); } +static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast(src0->data); + int32_t * dst_dd = static_cast(dst->data); + + const int k = dst->ne[0]; + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + GGML_ASSERT(k > 0 && k <= 32); + GGML_ASSERT(k <= ncols); + + top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream); +} + inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_I32); @@ -2248,6 +2391,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); } +static void tri_f32_sycl( + const float * src, + float * dst, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + const ggml_tri_type ttype, + dpct::queue_ptr main_stream +) { + const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3; + + main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) { + const int64_t idx = (int64_t) tid[0]; + + const int64_t i0 = idx % ne0; + const int64_t t1 = idx / ne0; + const int64_t i1 = t1 % ne1; + + bool keep = false; + switch (ttype) { + case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break; + case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break; + case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break; + case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break; + default: keep = false; break; + } + + dst[idx] = keep ? src[idx] : 0.0f; + }); +} + +static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(src0); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast(src0->data); + float * dst_dd = static_cast(dst->data); + + const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0); + + const int64_t ne0 = src0->ne[0]; + const int64_t ne1 = src0->ne[1]; + const int64_t ne2 = src0->ne[2]; + const int64_t ne3 = src0->ne[3]; + + tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream); +} + + inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -2963,8 +3165,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons const int64_t smb = ne12 == 1 ? s13 : s12; // there is no broadcast and src0, src1 are contiguous across dims 2, 3 - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma, src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf, mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); @@ -2988,7 +3190,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons }); SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); @@ -3316,18 +3518,17 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // mmvq and mmq need the __dp4a instruction which is available for gen12+ - // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e + // Workaround in https://github.com/ggml-org/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); #ifdef SYCL_USE_XMX use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); #endif // SYCL_USE_XMX - // mmvq path is faster in the CUDA backend. - if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda - // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization - // is enabled takes precedence over DMMV, the current if-else implementation - // requires disabling DMMV if both conditions are met - || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { + // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization + // is enabled takes precedence over DMMV, the current if-else implementation + // requires disabling DMMV if both conditions are met + if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) && + ggml_sycl_supports_reorder_mmvq(src0->type)))) { use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; } @@ -3771,6 +3972,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_sycl_exp(ctx, dst); break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_sycl_softplus(ctx, dst); + break; case GGML_UNARY_OP_SGN: ggml_sycl_sgn(ctx, dst); break; @@ -3897,6 +4101,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_TRANSPOSE: GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__); break; + case GGML_OP_TRI: + ggml_sycl_op_tri(ctx, dst); + break; case GGML_OP_DIAG_MASK_INF: ggml_sycl_diag_mask_inf(ctx, dst); break; @@ -3927,6 +4134,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_sycl_argsort(ctx, dst); break; + case GGML_OP_TOP_K: + ggml_sycl_op_top_k(ctx, dst); + break; case GGML_OP_TIMESTEP_EMBEDDING: ggml_sycl_op_timestep_embedding(ctx, dst); break; @@ -3978,16 +4188,6 @@ void ggml_backend_sycl_get_device_memory(int device, size_t *free, GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n"); ggml_sycl_set_device(device); - /* - DPCT1009:218: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string was - inserted. You need to rewrite this code. - */ - /* - DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for - device information which may not be supported by all compilers or runtimes. - You may need to adjust the code. - */ SYCL_CHECK(CHECK_TRY_ERROR( dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total))); } @@ -4389,10 +4589,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_CEIL: return true; case GGML_UNARY_OP_FLOOR: - case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: #if defined (GGML_SYCL_F16) @@ -4591,18 +4792,23 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); #endif case GGML_OP_NORM: - return true; case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: - return ggml_is_contiguous(op->src[0]); case GGML_OP_RMS_NORM: - return ((op->src[0]->ne[0] % WARP_SIZE) == 0); + return true; case GGML_OP_RMS_NORM_BACK: - return ((op->src[0]->ne[0] % WARP_SIZE) == 0); + return ggml_is_contiguous(op->src[0]); case GGML_OP_SCALE: return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; + case GGML_OP_TRI: + { + const ggml_tensor * src0 = op->src[0]; + return src0 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0); + } case GGML_OP_DIAG_MASK_INF: return true; case GGML_OP_SOFT_MAX: @@ -4624,6 +4830,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ARGSORT: return op->src[0]->ne[0] * sizeof(int) <= ggml_sycl_info().devices[device].smpbo; + case GGML_OP_TOP_K: { + const ggml_tensor * src0 = op->src[0]; + const int k = op->ne[0]; + return src0 && + op->type == GGML_TYPE_I32 && + src0->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && + k > 0 && k <= 32; + } case GGML_OP_POOL_2D: case GGML_OP_ACC: return true; diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 823d3a4828c..00702b5d09c 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -251,7 +251,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i const float eps, queue_ptr stream, int device) { const sycl::range<3> global_dims(nsamples, nchannels, nrows); - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { @@ -334,7 +333,6 @@ static void group_norm_f32_sycl(const float* x, float* dst, static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) { - GGML_ASSERT(ncols % WARP_SIZE == 0); // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); const sycl::range<3> global_dims(nsamples, nchannels, nrows); @@ -374,7 +372,6 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const float eps, queue_ptr stream, int device) { - GGML_ASSERT(ncols % WARP_SIZE == 0); // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index 3a17f3a1b88..f52b11f0d6e 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -32,12 +32,12 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { // Handle transposition of src1 const bool src1_T = ggml_is_transposed(src1); - const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans; + const oneapi::mkl::transpose src1_op = src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans; const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); try { - // Perform matrix multiplication using oneMath GEMM - oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op, + // Perform matrix multiplication using oneMKL GEMM + oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0); } catch (sycl::exception const& exc) { diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index 69140b19a4c..aeaa58b95b3 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -207,7 +207,6 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons const int p = sector; theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p); } else { - // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0] const int p = sector - sections.v[0]; theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p); } diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp index c10e2f7645e..b56e0c2400f 100644 --- a/ggml/src/ggml-sycl/wkv.cpp +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -1,7 +1,7 @@ #include #include "wkv.hpp" -constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE +constexpr int WKV_BLOCK_SIZE = 64; // Helper function for the main kernel template diff --git a/ggml/src/ggml-virtgpu/CMakeLists.txt b/ggml/src/ggml-virtgpu/CMakeLists.txt new file mode 100644 index 00000000000..e6b020beb5b --- /dev/null +++ b/ggml/src/ggml-virtgpu/CMakeLists.txt @@ -0,0 +1,70 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + +include(ExternalProject) + +message(STATUS "Including the VirtGPU/Virglrenderer API Remoting") + +# Download venus_hw.h from virglrenderer repository +ExternalProject_Add( + venus_hw_header + URL https://gitlab.freedesktop.org/virgl/virglrenderer/-/raw/virglrenderer-1.2.0/src/venus_hw.h + DOWNLOAD_NO_EXTRACT YES + DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include + DOWNLOAD_NAME venus_hw.h + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + LOG_DOWNLOAD ON +) + +if (NOT GGML_VIRTGPU_BACKEND STREQUAL "ONLY") + message(STATUS "Enable the VirtGPU/Virglrenderer API Remoting frontend library") + + find_package(PkgConfig REQUIRED) + pkg_check_modules(DRM REQUIRED libdrm) + if (NOT GGML_BACKEND_DL) + # cannot simply use USE_VIRTGPU, as in the 'else()' case the + # frontend isn't compiled + target_compile_definitions(ggml PUBLIC "GGML_USE_VIRTGPU_FRONTEND") + endif() + + ggml_add_backend_library(ggml-virtgpu + ggml-backend-buffer.cpp + ggml-backend.cpp + ggml-backend-device.cpp + ggml-backend-reg.cpp + ggml-backend-buffer-type.cpp + virtgpu-apir.h + virtgpu-forward.gen.h + virtgpu.cpp + virtgpu-shm.cpp + virtgpu-utils.cpp + virtgpu-forward-device.cpp + virtgpu-forward-buffer-type.cpp + virtgpu-forward-buffer.cpp + virtgpu-forward-backend.cpp + virtgpu-forward-impl.h + apir_cs_ggml-rpc-front.cpp + ../../include/ggml-virtgpu.h) + + target_include_directories(ggml-virtgpu PUBLIC /usr/include/libdrm/) + + target_link_libraries(ggml-virtgpu PUBLIC ${DRM_LIBRARIES}) + target_include_directories(ggml-virtgpu PUBLIC ${DRM_INCLUDE_DIRS}) + target_compile_options(ggml-virtgpu PUBLIC ${DRM_CFLAGS_OTHER}) + + target_include_directories(ggml-virtgpu PUBLIC ./include) + target_include_directories(ggml-virtgpu PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + + # Ensure venus_hw.h is downloaded before building ggml-virtgpu + add_dependencies(ggml-virtgpu venus_hw_header) + + target_compile_options(ggml-virtgpu PRIVATE -std=c++20) +else() + message(STATUS "Not building the VirtGPU/Virglrenderer API Remoting frontend library") +endif() + +if (NOT GGML_VIRTGPU_BACKEND STREQUAL "OFF") + add_subdirectory("backend") +endif() diff --git a/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp new file mode 100644 index 00000000000..d2e87330a63 --- /dev/null +++ b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp @@ -0,0 +1,87 @@ +#include "backend/shared/apir_cs_rpc.h" +#include "ggml-backend-impl.h" +#include "ggml-impl.h" +#include "ggml-remoting.h" + +#include +#include +#include +#include + +apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor) { + apir_rpc_tensor result; + result.id = reinterpret_cast(tensor); + result.type = tensor->type; + if (tensor->buffer) { + ggml_backend_buffer_t buffer = tensor->buffer; + + result.buffer = BUFFER_TO_HOST_HANDLE(buffer); + } else { + result.buffer = 0; + } + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result.ne[i] = tensor->ne[i]; + result.nb[i] = tensor->nb[i]; + } + result.op = tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result.op_params[i] = tensor->op_params[i]; + } + result.flags = tensor->flags; + for (uint32_t i = 0; i < GGML_MAX_SRC; i++) { + result.src[i] = reinterpret_cast(tensor->src[i]); + } + result.view_src = reinterpret_cast(tensor->view_src); + result.view_offs = tensor->view_offs; + result.data = reinterpret_cast(tensor->data); + if (tensor->data) { + if (!tensor->buffer) { + GGML_ABORT("%s: tensor has data but not buffer", __func__); + } + // tensor->data is serialized as an offset to the buffer base address + result.data -= reinterpret_cast(BUFFER_TO_GGML_CONTEXT(tensor->buffer)->base); + } + snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name); + return result; +} + +void apir_add_tensor(ggml_tensor * tensor, + std::vector & tensors, + std::unordered_set & visited) { + if (tensor == nullptr) { + return; + } + if (visited.find(tensor) != visited.end()) { + return; + } + visited.insert(tensor); + for (int i = 0; i < GGML_MAX_SRC; i++) { + apir_add_tensor(tensor->src[i], tensors, visited); + } + apir_add_tensor(tensor->view_src, tensors, visited); + tensors.push_back(apir_serialize_tensor(tensor)); +} + +void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector & output) { + uint32_t n_nodes = cgraph->n_nodes; + std::vector tensors; + std::unordered_set visited; + for (uint32_t i = 0; i < n_nodes; i++) { + apir_add_tensor(cgraph->nodes[i], tensors, visited); + } + // serialization format: + // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(apir_rpc_tensor)) | + uint32_t n_tensors = tensors.size(); + int output_size = + sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(apir_rpc_tensor); + output.resize(output_size, 0); + memcpy(output.data(), &n_nodes, sizeof(n_nodes)); + for (uint32_t i = 0; i < n_nodes; i++) { + memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); + } + uint32_t * out_ntensors = (uint32_t *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); + *out_ntensors = n_tensors; + apir_rpc_tensor * out_tensors = + (apir_rpc_tensor *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); + memcpy(out_tensors, tensors.data(), n_tensors * sizeof(apir_rpc_tensor)); +} diff --git a/ggml/src/ggml-virtgpu/backend/CMakeLists.txt b/ggml/src/ggml-virtgpu/backend/CMakeLists.txt new file mode 100644 index 00000000000..0b49c403b9a --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + +message(STATUS "Enable the VirtGPU/Virglrenderer backend library") + +ggml_add_backend_library(ggml-virtgpu-backend + backend.cpp + backend-dispatched.cpp + backend-dispatched-backend.cpp + backend-dispatched-device.cpp + backend-dispatched-buffer.cpp + backend-dispatched-buffer-type.cpp + shared/api_remoting.h + shared/apir_backend.h + shared/apir_cs.h + apir_cs_ggml-rpc-back.cpp) + +target_compile_options(ggml-virtgpu-backend PRIVATE -std=c++20) + +# Add include directory for ggml-backend-impl.h and other core headers +target_include_directories(ggml-virtgpu-backend PRIVATE ../..) diff --git a/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp b/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp new file mode 100644 index 00000000000..60a8a93bfb8 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp @@ -0,0 +1,115 @@ +#include "ggml-backend-impl.h" +#include "ggml-impl.h" +#include "shared/apir_cs_rpc.h" + +#include +#include +#include +#include + +std::unordered_set backend_buffers; + +void apir_track_backend_buffer(ggml_backend_buffer_t buffer) { + backend_buffers.insert(buffer); +} + +bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer) { + auto it = backend_buffers.find(buffer); + if (it == backend_buffers.end()) { + return false; + } + + backend_buffers.erase(it); + return true; +} + +std::unordered_set apir_get_track_backend_buffers() { + return backend_buffers; +} + +ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor) { + ggml_tensor * result = + ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result->nb[i] = tensor->nb[i]; + } + result->buffer = reinterpret_cast(tensor->buffer); + if (result->buffer && backend_buffers.find(result->buffer) == backend_buffers.end()) { + printf("WARNING: HOST BUFFER NOT FOUND | %p\n", (void *) result->buffer); + result->buffer = nullptr; + } + + uint64_t tensor_data = tensor->data; + if (result->buffer) { + // require that the tensor data does not go beyond the buffer end + uint64_t tensor_size = (uint64_t) ggml_nbytes(result); + uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer); + uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer); + + // tensor->data is serialized as an offset to the buffer base address + tensor_data += buffer_start; + + GGML_ASSERT(tensor_data + tensor_size >= tensor_data); // check for overflow + GGML_ASSERT(tensor_data >= buffer_start && tensor_data + tensor_size <= buffer_start + buffer_size); + } + + result->op = (ggml_op) tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result->op_params[i] = tensor->op_params[i]; + } + result->flags = tensor->flags; + result->data = reinterpret_cast(tensor_data); + ggml_set_name(result, tensor->name); + return result; +} + +ggml_tensor * apir_create_node(uint64_t id, + ggml_context * ctx, + const std::unordered_map & tensor_ptrs, + std::unordered_map & tensor_map) { + if (id == 0) { + return nullptr; + } + if (tensor_map.find(id) != tensor_map.end()) { + return tensor_map[id]; + } + const apir_rpc_tensor * tensor = tensor_ptrs.at(id); + ggml_tensor * result = apir_deserialize_tensor(ctx, tensor); + if (result == nullptr) { + return nullptr; + } + tensor_map[id] = result; + for (int i = 0; i < GGML_MAX_SRC; i++) { + result->src[i] = apir_create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map); + } + result->view_src = apir_create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map); + result->view_offs = tensor->view_offs; + return result; +} + +ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes, + uint32_t n_tensors, + const apir_rpc_tensor * tensors, + const uint64_t * nodes) { + size_t buf_size = ggml_tensor_overhead() * (n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); + ggml_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/NULL, + /*.no_alloc =*/true, + }; + ggml_context * ctx = ggml_init(params); + ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); + graph->n_nodes = n_nodes; + std::unordered_map tensor_ptrs; + for (uint32_t i = 0; i < n_tensors; i++) { + tensor_ptrs[tensors[i].id] = &tensors[i]; + } + std::unordered_map tensor_map; + for (uint32_t i = 0; i < n_nodes; i++) { + int64_t id; + memcpy(&id, &nodes[i], sizeof(id)); + graph->nodes[i] = apir_create_node(id, ctx, tensor_ptrs, tensor_map); + } + + return graph; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-convert.h b/ggml/src/ggml-virtgpu/backend/backend-convert.h new file mode 100644 index 00000000000..1978d21f7ef --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-convert.h @@ -0,0 +1,13 @@ +#include "shared/apir_backend.h" + +#define BUFFER_TO_HOST_HANDLE(name) ggml_buffer_to_apir_handle(name) + +static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) { + // in the backend, the buffer handle is the buffer pointer + return (apir_buffer_host_handle_t) buffer; +} + +static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) { + // in the backend, the buffer handle is the buffer pointer + return (apir_buffer_type_host_handle_t) buft; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp new file mode 100644 index 00000000000..cc879e51d04 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp @@ -0,0 +1,65 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "shared/apir_backend.h" + +#include + +uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + static bool async_backend_initialized = false; + static bool async_backend; + + if (!async_backend_initialized) { + ggml_backend_dev_props props; + + dev->iface.get_props(dev, &props); + async_backend = props.caps.async; + async_backend_initialized = true; + } + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + const void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + if (!shmem_data) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); + apir_decoder_set_fatal(dec); + return 1; + } + size_t cgraph_size; + apir_decode_size_t(dec, &cgraph_size); + + apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size); + + ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size); + + ggml_status status; +#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1 + for (int idx = 0; idx < cgraph->n_nodes; idx++) { + ggml_tensor * op = ggml_graph_node(cgraph, idx); + if (dev->iface.supports_op(dev, op)) { + continue; + } + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op)); + + status = GGML_STATUS_ABORTED; + apir_encode_ggml_status(enc, &status); + + return 0; + } +#endif + status = bck->iface.graph_compute(bck, cgraph); + + if (async_backend) { + bck->iface.synchronize(bck); + } + + apir_encode_ggml_status(enc, &status); + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp new file mode 100644 index 00000000000..d55eec27610 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp @@ -0,0 +1,93 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include + +uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + const char * string = buft->iface.get_name(buft); + + const size_t string_size = strlen(string) + 1; + apir_encode_array_size(enc, string_size); + apir_encode_char_array(enc, string, string_size); + + return 0; +} + +uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + size_t value = buft->iface.get_alignment(buft); + apir_encode_size_t(enc, &value); + + return 0; +} + +uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + size_t value = SIZE_MAX; + if (buft->iface.get_max_size) { + value = buft->iface.get_max_size(buft); + } + + apir_encode_size_t(enc, &value); + + return 0; +} + +/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. */ +uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + const bool is_host = false; + + apir_encode_bool_t(enc, &is_host); + + return 0; +} + +uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + size_t size; + apir_decode_size_t(dec, &size); + + ggml_backend_buffer_t buffer; + + buffer = buft->iface.alloc_buffer(buft, size); + + apir_encode_ggml_buffer(enc, buffer); + + if (buffer) { + apir_track_backend_buffer(buffer); + } + + return 0; +} + +uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec); + + size_t value = buft->iface.get_alloc_size(buft, op); + + apir_encode_size_t(enc, &value); + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp new file mode 100644 index 00000000000..8cc063ff0a6 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp @@ -0,0 +1,131 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include + +uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer); + apir_encode_uintptr_t(enc, &base); + + return 0; +} + +uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + ggml_tensor * tensor; + // safe to remove the const qualifier here + tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec); + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + size_t offset; + apir_decode_size_t(dec, &offset); + + size_t size; + apir_decode_size_t(dec, &size); + + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + + if (!shmem_data) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); + return 1; + } + + buffer->iface.set_tensor(buffer, tensor, shmem_data, offset, size); + + return 0; +} + +uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + const ggml_tensor * tensor; + // safe to remove the const qualifier here + tensor = apir_decode_ggml_tensor(dec); + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + size_t offset; + apir_decode_size_t(dec, &offset); + + size_t size; + apir_decode_size_t(dec, &size); + + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + if (!shmem_data) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); + return 1; + } + + buffer->iface.get_tensor(buffer, tensor, shmem_data, offset, size); + + return 0; +} + +uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + const ggml_tensor * src; + // safe to remove the const qualifier here + src = apir_decode_ggml_tensor(dec); + ggml_tensor * dst = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec); + + bool ret = buffer->iface.cpy_tensor(buffer, src, (ggml_tensor *) dst); + + apir_encode_bool_t(enc, &ret); + + return 0; +} + +uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + uint8_t value; + apir_decode_uint8_t(dec, &value); + + buffer->iface.clear(buffer, value); + + return 0; +} + +uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + if (!apir_untrack_backend_buffer(buffer)) { + GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer); + return 1; + } + + buffer->iface.free_buffer(buffer); + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp new file mode 100644 index 00000000000..c7acb8b51ce --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp @@ -0,0 +1,148 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include + +uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + int32_t dev_count = reg->iface.get_device_count(reg); + apir_encode_int32_t(enc, &dev_count); + + return 0; +} + +uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + int32_t dev_count = reg->iface.get_device_count(reg); + apir_encode_int32_t(enc, &dev_count); + + return 0; +} + +uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + const char * string = dev->iface.get_name(dev); + + const size_t string_size = strlen(string) + 1; + apir_encode_array_size(enc, string_size); + apir_encode_char_array(enc, string, string_size); + + return 0; +} + +uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + const char * string = dev->iface.get_description(dev); + + const size_t string_size = strlen(string) + 1; + apir_encode_array_size(enc, string_size); + apir_encode_char_array(enc, string, string_size); + + return 0; +} + +uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + uint32_t type = dev->iface.get_type(dev); + apir_encode_uint32_t(enc, &type); + + return 0; +} + +uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + size_t free, total; + dev->iface.get_memory(dev, &free, &total); + + apir_encode_size_t(enc, &free); + apir_encode_size_t(enc, &total); + + return 0; +} + +uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + + const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec); + + bool supports_op = dev->iface.supports_op(dev, op); + + apir_encode_bool_t(enc, &supports_op); + + return 0; +} + +uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + ggml_backend_buffer_type_t bufft = dev->iface.get_buffer_type(dev); + + apir_encode_ggml_buffer_type(enc, bufft); + + return 0; +} + +uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + ggml_backend_dev_props props; + dev->iface.get_props(dev, &props); + + apir_encode_bool_t(enc, &props.caps.async); + apir_encode_bool_t(enc, &props.caps.host_buffer); + apir_encode_bool_t(enc, &props.caps.buffer_from_host_ptr); + apir_encode_bool_t(enc, &props.caps.events); + + return 0; +} + +uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + void * shmem_ptr = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + if (!shmem_ptr) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); + apir_decoder_set_fatal(dec); + return 1; + } + + size_t size; + apir_decode_size_t(dec, &size); + size_t max_tensor_size; + apir_decode_size_t(dec, &max_tensor_size); + + ggml_backend_buffer_t buffer; + buffer = dev->iface.buffer_from_host_ptr(dev, shmem_ptr, size, max_tensor_size); + + apir_encode_ggml_buffer(enc, buffer); + apir_encode_ggml_buffer_type(enc, buffer->buft); + + if (buffer) { + apir_track_backend_buffer(buffer); + } + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp new file mode 100644 index 00000000000..64152eef0d8 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp @@ -0,0 +1,46 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include + +ggml_backend_reg_t reg = NULL; +ggml_backend_dev_t dev = NULL; +ggml_backend_t bck = NULL; + +uint64_t timer_start = 0; +uint64_t timer_total = 0; +uint64_t timer_count = 0; + +uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) { + if (reg != NULL) { + GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: already initialized\n", __func__); + return APIR_BACKEND_INITIALIZE_ALREADY_INITED; + } + ggml_backend_reg_t (*ggml_backend_reg_fct)(void) = (ggml_backend_reg_t (*)()) ggml_backend_reg_fct_p; + + reg = ggml_backend_reg_fct(); + if (reg == NULL) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend registration failed\n", __func__); + return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED; + } + + if (!reg->iface.get_device_count(reg)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device found\n", __func__); + return APIR_BACKEND_INITIALIZE_NO_DEVICE; + } + + dev = reg->iface.get_device(reg, 0); + + if (!dev) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device received\n", __func__); + return APIR_BACKEND_INITIALIZE_NO_DEVICE; + } + + bck = dev->iface.init_backend(dev, NULL); + + return APIR_BACKEND_INITIALIZE_SUCCESS; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h new file mode 100644 index 00000000000..481d7f3150d --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h @@ -0,0 +1,131 @@ +#pragma once + +/* device */ +uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +/* buffer-type */ +uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. */ +uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +/* buffer */ +uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +/* backend */ +uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +static inline const char * backend_dispatch_command_name(ApirBackendCommandType type) { + switch (type) { + /* device */ + case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT: + return "backend_device_get_device_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_COUNT: + return "backend_device_get_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_NAME: + return "backend_device_get_name"; + case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION: + return "backend_device_get_description"; + case APIR_COMMAND_TYPE_DEVICE_GET_TYPE: + return "backend_device_get_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY: + return "backend_device_get_memory"; + case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP: + return "backend_device_supports_op"; + case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE: + return "backend_device_get_buffer_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_PROPS: + return "backend_device_get_props"; + case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR: + return "backend_device_buffer_from_ptr"; + /* buffer-type */ + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME: + return "backend_buffer_type_get_name"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT: + return "backend_buffer_type_get_alignment"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE: + return "backend_buffer_type_get_max_size"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST: + return "backend_buffer_type_is_host (DEPRECATED)"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER: + return "backend_buffer_type_alloc_buffer"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE: + return "backend_buffer_type_get_alloc_size"; + /* buffer */ + case APIR_COMMAND_TYPE_BUFFER_GET_BASE: + return "backend_buffer_get_base"; + case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR: + return "backend_buffer_set_tensor"; + case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR: + return "backend_buffer_get_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR: + return "backend_buffer_cpy_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CLEAR: + return "backend_buffer_clear"; + case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER: + return "backend_buffer_free_buffer"; + /* backend */ + case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE: + return "backend_backend_graph_compute"; + + default: + return "unknown"; + } +} + +extern "C" { +static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = { + + /* device */ + + /* APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = */ backend_device_get_device_count, + /* APIR_COMMAND_TYPE_DEVICE_GET_COUNT = */ backend_device_get_count, + /* APIR_COMMAND_TYPE_DEVICE_GET_NAME = */ backend_device_get_name, + /* APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = */ backend_device_get_description, + /* APIR_COMMAND_TYPE_DEVICE_GET_TYPE = */ backend_device_get_type, + /* APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = */ backend_device_get_memory, + /* APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = */ backend_device_supports_op, + /* APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = */ backend_device_get_buffer_type, + /* APIR_COMMAND_TYPE_DEVICE_GET_PROPS = */ backend_device_get_props, + /* APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = */ backend_device_buffer_from_ptr, + + /* buffer-type */ + + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = */ backend_buffer_type_get_name, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = */ backend_buffer_type_get_alignment, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = */ backend_buffer_type_get_max_size, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = */ backend_buffer_type_is_host /* DEPRECATED */, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = */ backend_buffer_type_alloc_buffer, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = */ backend_buffer_type_get_alloc_size, + + /* buffer */ + + /* APIR_COMMAND_TYPE_BUFFER_GET_BASE = */ backend_buffer_get_base, + /* APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = */ backend_buffer_set_tensor, + /* APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = */ backend_buffer_get_tensor, + /* APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = */ backend_buffer_cpy_tensor, + /* APIR_COMMAND_TYPE_BUFFER_CLEAR = */ backend_buffer_clear, + /* APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = */ backend_buffer_free_buffer, + + /* backend */ + + /* APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = */ backend_backend_graph_compute, +}; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h new file mode 100644 index 00000000000..10311631d4f --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include + +#include "backend-convert.h" +#include "backend-virgl-apir.h" +#include "shared/apir_backend.h" +#include "shared/apir_cs.h" +#include "shared/apir_cs_ggml.h" + +#define GGML_VIRTGPU_BCK "ggml-virtgpu-backend: " + +struct virgl_apir_context { + uint32_t ctx_id; + virgl_apir_callbacks * iface; +}; + +typedef uint32_t (*backend_dispatch_t)(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +#include "backend-dispatched.gen.h" + +uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p); diff --git a/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h new file mode 100644 index 00000000000..44b347f853f --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h @@ -0,0 +1,32 @@ +#pragma once + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "shared/api_remoting.h" + +#include +#include +#include + +extern ggml_backend_reg_t reg; +extern ggml_backend_dev_t dev; +extern ggml_backend_t bck; + +struct virgl_apir_callbacks { + const char * (*get_config)(uint32_t virgl_ctx_id, const char * key); + void * (*get_shmem_ptr)(uint32_t virgl_ctx_id, uint32_t res_id); +}; + +extern "C" { +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs); +void apir_backend_deinit(uint32_t virgl_ctx_id); +uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, + virgl_apir_callbacks * virgl_cbs, + uint32_t cmd_type, + char * dec_cur, + const char * dec_end, + char * enc_cur, + const char * enc_end, + char ** enc_cur_after); +} diff --git a/ggml/src/ggml-virtgpu/backend/backend.cpp b/ggml/src/ggml-virtgpu/backend/backend.cpp new file mode 100644 index 00000000000..d93414a078b --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend.cpp @@ -0,0 +1,152 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" + +#include "shared/api_remoting.h" +#include "shared/apir_backend.h" +#include "shared/apir_cs.h" + +#include +#include + +#include + +#define APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_PATH" +#define APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_REG" +#define APIR_LLAMA_CPP_LOG_TO_FILE_ENV "APIR_LLAMA_CPP_LOG_TO_FILE" + +#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init" + +static void * backend_library_handle = NULL; +static FILE * apir_logfile = NULL; + +static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) { + FILE * logfile = (FILE *)user_data; + fprintf(logfile, "[%d] %s", level, text); + fflush(logfile); +} + +extern "C" { +void apir_backend_deinit(uint32_t virgl_ctx_id) { + GGML_UNUSED(virgl_ctx_id); + + auto buffers = apir_get_track_backend_buffers(); + for (const auto & buffer : buffers) { + apir_untrack_backend_buffer(buffer); + buffer->iface.free_buffer(buffer); + } + + if (backend_library_handle) { + GGML_LOG_INFO(GGML_VIRTGPU_BCK "The GGML backend library was loaded. Unloading it.\n"); + dlclose(backend_library_handle); + backend_library_handle = NULL; + } + + if (apir_logfile) { + fclose(apir_logfile); + apir_logfile = NULL; + } +} + +#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path" +#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg" + +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) { + const char * dlsym_error; + + const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV); + if (apir_log_to_file) { + apir_logfile = fopen(apir_log_to_file, "w"); + if (apir_logfile) { + ggml_log_set(log_to_file_callback, apir_logfile); + } else { + GGML_LOG_INFO(GGML_VIRTGPU_BCK "Could not open the log file at '%s'\n", apir_log_to_file); + } + } + + const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY); + const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY); + const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG; + + if (!library_name) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: cannot open the GGML library: env var '%s' not defined\n", + __func__, APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV); + + + return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; + } + + backend_library_handle = dlopen(library_name, RTLD_LAZY); + + if (!backend_library_handle) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: cannot open the GGML library: %s\n", __func__, dlerror()); + + return APIR_LOAD_LIBRARY_CANNOT_OPEN; + } + + if (!library_reg) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: cannot register the GGML library: env var '%s' not defined\n", + __func__, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV); + + return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; + } + + void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg); + dlsym_error = dlerror(); + if (dlsym_error) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n", + __func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error); + + + return APIR_LOAD_LIBRARY_SYMBOL_MISSING; + } + + uint32_t ret = backend_dispatch_initialize(ggml_backend_reg_fct); + + return (ApirLoadLibraryReturnCode) (APIR_LOAD_LIBRARY_INIT_BASE_INDEX + ret); +} + +uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, + virgl_apir_callbacks * virgl_cbs, + uint32_t cmd_type, + char * dec_cur, + const char * dec_end, + char * enc_cur, + const char * enc_end, + char ** enc_cur_after) { + apir_encoder enc = { + .cur = enc_cur, + .start = enc_cur, + .end = enc_end, + .fatal = false, + }; + + apir_decoder dec = { + .cur = dec_cur, + .end = dec_end, + .fatal = false, + }; + + virgl_apir_context ctx = { + .ctx_id = virgl_ctx_id, + .iface = virgl_cbs, + }; + + if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: Received an invalid dispatch index (%d >= %d)\n", + __func__, cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT); + return APIR_BACKEND_FORWARD_INDEX_INVALID; + } + + backend_dispatch_t forward_fct = apir_backend_dispatch_table[cmd_type]; + uint32_t ret = forward_fct(&enc, &dec, &ctx); + + *enc_cur_after = enc.cur; + + return ret; +} +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h new file mode 100644 index 00000000000..f19a5d12d17 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h @@ -0,0 +1,90 @@ +#pragma once + +/* the rest of this file must match virglrenderer/src/apir-protocol.h */ + +#include + +#include + +#define APIR_PROTOCOL_MAJOR 0 +#define APIR_PROTOCOL_MINOR 1 + +#define APIR_HANDSHAKE_MAGIC 0xab1e + +enum ApirCommandType { + APIR_COMMAND_TYPE_HANDSHAKE = 0, + APIR_COMMAND_TYPE_LOADLIBRARY = 1, + APIR_COMMAND_TYPE_FORWARD = 2, + + APIR_COMMAND_TYPE_LENGTH = 3, +}; + +typedef uint64_t ApirCommandFlags; + +enum ApirLoadLibraryReturnCode { + APIR_LOAD_LIBRARY_SUCCESS = 0, + APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1, + APIR_LOAD_LIBRARY_ALREADY_LOADED = 2, + APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3, + APIR_LOAD_LIBRARY_CANNOT_OPEN = 4, + APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5, + APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, // anything above this is a APIR backend library initialization return code +}; + +enum ApirForwardReturnCode { + APIR_FORWARD_SUCCESS = 0, + APIR_FORWARD_NO_DISPATCH_FCT = 1, + APIR_FORWARD_TIMEOUT = 2, + + APIR_FORWARD_BASE_INDEX = 3, // anything above this is a APIR backend library forward return code +} ; + +__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) { + switch (type) { + case APIR_COMMAND_TYPE_HANDSHAKE: + return "HandShake"; + case APIR_COMMAND_TYPE_LOADLIBRARY: + return "LoadLibrary"; + case APIR_COMMAND_TYPE_FORWARD: + return "Forward"; + default: + return "unknown"; + } +} + +__attribute__((unused)) static const char * apir_load_library_error(ApirLoadLibraryReturnCode code) { +#define APIR_LOAD_LIBRARY_ERROR(code_name) \ + do { \ + if (code == code_name) \ + return #code_name; \ + } while (0) + + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SUCCESS); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ALREADY_LOADED); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ENV_VAR_MISSING); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_CANNOT_OPEN); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SYMBOL_MISSING); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_INIT_BASE_INDEX); + + return "Unknown APIR_COMMAND_TYPE_LoadLibrary error"; + +#undef APIR_LOAD_LIBRARY_ERROR +} + +__attribute__((unused)) static const char * apir_forward_error(ApirForwardReturnCode code) { +#define APIR_FORWARD_ERROR(code_name) \ + do { \ + if (code == code_name) \ + return #code_name; \ + } while (0) + + APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS); + APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT); + APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT); + APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX); + + return "Unknown APIR_COMMAND_TYPE_FORWARD error"; + +#undef APIR_FORWARD_ERROR +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h new file mode 100644 index 00000000000..d214b6f2a90 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h @@ -0,0 +1,36 @@ +typedef enum ApirBackendCommandType { + + /* device */ + APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = 0, + APIR_COMMAND_TYPE_DEVICE_GET_COUNT = 1, + APIR_COMMAND_TYPE_DEVICE_GET_NAME = 2, + APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = 3, + APIR_COMMAND_TYPE_DEVICE_GET_TYPE = 4, + APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = 5, + APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = 6, + APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = 7, + APIR_COMMAND_TYPE_DEVICE_GET_PROPS = 8, + APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = 9, + + /* buffer-type */ + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = 10, + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = 11, + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = 12, + APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = 13, + APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = 14, + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = 15, + + /* buffer */ + APIR_COMMAND_TYPE_BUFFER_GET_BASE = 16, + APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = 17, + APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = 18, + APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = 19, + APIR_COMMAND_TYPE_BUFFER_CLEAR = 20, + APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = 21, + + /* backend */ + APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = 22, + + // last command_type index + 1 + APIR_BACKEND_DISPATCH_TABLE_COUNT = 23, +} ApirBackendCommandType; diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h new file mode 100644 index 00000000000..f3efa52c721 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h @@ -0,0 +1,46 @@ +#pragma once + +#include "apir_backend.gen.h" + +#include // for uintptr_t +#include // for timespec, clock_gettime + +#define APIR_BACKEND_INITIALIZE_SUCCESS 0 +#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY 1 +#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY 2 +#define APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS 3 +#define APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS 4 +#define APIR_BACKEND_INITIALIZE_BACKEND_FAILED 5 +#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6 +#define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7 +#define APIR_BACKEND_INITIALIZE_NO_DEVICE 8 + + +// new entries here need to be added to the apir_backend_initialize_error function below + +#define APIR_BACKEND_FORWARD_INDEX_INVALID 6 + +// 0 is fast, 1 avoids the backend to crash if an unsupported tensor is received +#define APIR_BACKEND_CHECK_SUPPORTS_OP 0 + +typedef uintptr_t apir_buffer_type_host_handle_t; +typedef uintptr_t apir_buffer_host_handle_t; + +static const char * apir_backend_initialize_error(int code) { +#define APIR_BACKEND_INITIALIZE_ERROR(code_name) \ + do { \ + if (code == code_name) \ + return #code_name; \ + } while (0) + + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_SUCCESS); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED); + + return "Unknown APIR_BACKEND_INITIALIZE error:/"; + +#undef APIR_BACKEND_INITIALIZE_ERROR +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h new file mode 100644 index 00000000000..1bc3a5f685b --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h @@ -0,0 +1,384 @@ +#pragma once + +#include "ggml-impl.h" + +#include +#include + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + +struct apir_encoder { + char * cur; + const char * start; + const char * end; + bool fatal; + +}; + +struct apir_decoder { + const char * cur; + const char * end; + bool fatal; +}; + +/* + * new encoder and decoder + */ + +static apir_decoder apir_new_decoder(const char * ptr, size_t size) { + apir_decoder dec = { + .cur = ptr, + .end = ptr + size, + .fatal = false, + }; + + return dec; +} + +static apir_encoder apir_new_encoder(char * ptr, size_t size) { + apir_encoder enc = { + .cur = ptr, + .start = ptr, + .end = ptr + size, + .fatal = false, + }; + + return enc; +} + +/* + * fatal flag handling + */ + +static inline void apir_encoder_reset_fatal(apir_encoder * enc) { + enc->fatal = false; +} + +static inline void apir_encoder_set_fatal(apir_encoder * enc) { + enc->fatal = true; +} + +static inline bool apir_encoder_get_fatal(const apir_encoder * enc) { + return enc->fatal; +} + +static inline void apir_decoder_reset_fatal(apir_decoder * dec) { + dec->fatal = false; +} + +static inline void apir_decoder_set_fatal(apir_decoder * dec) { + dec->fatal = true; +} + +static inline bool apir_decoder_get_fatal(const apir_decoder * dec) { + return dec->fatal; +} + +/* + * encode peek + */ + +static inline bool apir_decoder_peek_internal(apir_decoder * dec, + size_t size, + void * val, + size_t val_size) { + assert(val_size <= size); + + if (unlikely(size > (size_t) (dec->end - dec->cur))) { + GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__); + apir_decoder_set_fatal(dec); + memset(val, 0, val_size); + return false; + } + + /* we should not rely on the compiler to optimize away memcpy... */ + memcpy(val, dec->cur, val_size); + return true; +} + +static inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val, size_t val_size) { + apir_decoder_peek_internal(dec, size, val, val_size); +} + +static inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) { + if (unlikely(size > (size_t) (dec->end - dec->cur))) { + GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__); + apir_decoder_set_fatal(dec); + return NULL; + } + const void * addr = dec->cur; + dec->cur += size; + + return addr; +} + +/* + * read/write + */ + +static inline void apir_decoder_read(apir_decoder * dec, size_t size, void * val, size_t val_size) { + if (apir_decoder_peek_internal(dec, size, val, val_size)) { + dec->cur += size; + } +} + +static inline char * apir_encoder_write(apir_encoder * enc, size_t size, const void * val, size_t val_size) { + assert(val_size <= size); + assert(size <= ((size_t) (enc->end - enc->cur))); + + char * write_addr = enc->cur; + /* we should not rely on the compiler to optimize away memcpy... */ + memcpy(write_addr, val, val_size); + enc->cur += size; + + return write_addr; +} + +/* + * encode/decode + */ + +static inline void apir_decode(apir_decoder * dec, size_t size, void * data, size_t data_size) { + assert(size % 4 == 0); + apir_decoder_read(dec, size, data, data_size); +} + +static inline void apir_encode(apir_encoder * enc, size_t size, const void * data, size_t data_size) { + assert(size % 4 == 0); + apir_encoder_write(enc, size, data, data_size); +} + +/* + * typed encode/decode + */ + +/* uint8_t */ + +static inline void apir_encode_uint8_t(apir_encoder * enc, const uint8_t * val) { + apir_encode(enc, sizeof(int), val, sizeof(*val)); +} + +static inline void apir_decode_uint8_t(apir_decoder * dec, uint8_t * val) { + apir_decode(dec, sizeof(int), val, sizeof(*val)); +} + +/* uint64_t */ + +static inline void apir_encode_uint64_t(apir_encoder * enc, const uint64_t * val) { + apir_encode(enc, 8, val, sizeof(*val)); +} + +static inline void apir_decode_uint64_t(apir_decoder * dec, uint64_t * val) { + apir_decode(dec, 8, val, sizeof(*val)); +} + +static inline void apir_encode_uint64_t_array(apir_encoder * enc, const uint64_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_encode(enc, size, val, size); +} + +static inline void apir_decode_uint64_t_array(apir_decoder * dec, uint64_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_decode(dec, size, val, size); +} + +static inline const uint64_t * apir_decode_uint64_t_array_inplace(apir_decoder * dec, uint32_t count) { + return (uint64_t *) (uintptr_t) apir_decoder_use_inplace(dec, count * sizeof(uint64_t)); +} + +/* int32_t */ + +static inline void apir_encode_int32_t(apir_encoder * enc, const int32_t * val) { + apir_encode(enc, 4, val, sizeof(*val)); +} + +static inline void apir_decode_int32_t(apir_decoder * dec, int32_t * val) { + apir_decode(dec, 4, val, sizeof(*val)); +} + +static inline void apir_encode_int32_t_array(apir_encoder * enc, const int32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_encode(enc, size, val, size); +} + +static inline void apir_decode_int32_t_array(apir_decoder * dec, int32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_decode(dec, size, val, size); +} + +/* array size (uint64_t) */ + +static inline void apir_encode_array_size(apir_encoder * enc, uint64_t size) { + apir_encode_uint64_t(enc, &size); +} + +static inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expected_size) { + uint64_t size; + apir_decode_uint64_t(dec, &size); + if (size != expected_size) { + GGML_LOG_ERROR("%s: Couldn't decode array from the decoder\n", __func__); + apir_decoder_set_fatal(dec); + size = 0; + } + return size; +} + +static inline uint64_t apir_decode_array_size_unchecked(apir_decoder * dec) { + uint64_t size; + apir_decode_uint64_t(dec, &size); + return size; +} + +/* non-array pointer */ + +static inline bool apir_encode_simple_pointer(apir_encoder * enc, const void * val) { + apir_encode_array_size(enc, val ? 1 : 0); + return val; +} + +static inline bool apir_decode_simple_pointer(apir_decoder * dec) { + return apir_decode_array_size_unchecked(dec); +} + +/* uint32_t */ + +static inline void apir_encode_uint32_t(apir_encoder * enc, const uint32_t * val) { + apir_encode(enc, 4, val, sizeof(*val)); +} + +static inline void apir_decode_uint32_t(apir_decoder * dec, uint32_t * val) { + apir_decode(dec, 4, val, sizeof(*val)); +} + +static inline void apir_encode_uint32_t_array(apir_encoder * enc, const uint32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_encode(enc, size, val, size); +} + +static inline void apir_decode_uint32_t_array(apir_decoder * dec, uint32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_decode(dec, size, val, size); +} + +/* size_t */ + +static inline void apir_encode_size_t(apir_encoder * enc, const size_t * val) { + const uint64_t tmp = *val; + apir_encode_uint64_t(enc, &tmp); +} + +static inline void apir_decode_size_t(apir_decoder * dec, size_t * val) { + uint64_t tmp; + apir_decode_uint64_t(dec, &tmp); + *val = tmp; +} + +static inline void apir_encode_size_t_array(apir_encoder * enc, const size_t * val, uint32_t count) { + if (sizeof(size_t) == sizeof(uint64_t)) { + apir_encode_uint64_t_array(enc, (const uint64_t *) val, count); + } else { + for (uint32_t i = 0; i < count; i++) { + apir_encode_size_t(enc, &val[i]); + } + } +} + +static inline void apir_decode_size_t_array(apir_decoder * dec, size_t * val, uint32_t count) { + if (sizeof(size_t) == sizeof(uint64_t)) { + apir_decode_uint64_t_array(dec, (uint64_t *) val, count); + } else { + for (uint32_t i = 0; i < count; i++) { + apir_decode_size_t(dec, &val[i]); + } + } +} + +/* opaque blob */ + +static inline void apir_encode_blob_array(apir_encoder * enc, const void * val, size_t size) { + apir_encode(enc, (size + 3) & ~3, val, size); +} + +static inline void apir_decode_blob_array(apir_decoder * dec, void * val, size_t size) { + apir_decode(dec, (size + 3) & ~3, val, size); +} + +/* string */ + +static inline void apir_encode_char_array(apir_encoder * enc, const char * val, size_t size) { + assert(size && strlen(val) < size); + apir_encode_blob_array(enc, val, size); +} + +static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t size) { + apir_decode_blob_array(dec, val, size); + if (size) { + val[size - 1] = '\0'; + } else { + GGML_LOG_ERROR("%s: Couldn't decode the blog array\n", __func__); + apir_decoder_set_fatal(dec); + } +} + +/* (temp) buffer allocation */ + +static inline void * apir_decoder_alloc_array(size_t size, size_t count) { + size_t alloc_size; + if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) { + GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", + __func__, size, count); + return NULL; + } + + return malloc(alloc_size); +} + +/* bool */ + +static inline void apir_encode_bool_t(apir_encoder * enc, const bool * val) { + apir_encode(enc, sizeof(int), val, sizeof(bool)); +} + +static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) { + apir_decode(dec, sizeof(int), val, sizeof(bool)); +} + +/* apir_buffer_type_host_handle_t */ + +static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc, + const apir_buffer_type_host_handle_t * val) { + apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); +} + +static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec, + apir_buffer_type_host_handle_t * val) { + apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); +} + +/* apir_buffer_host_handle_t */ + +static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, + const apir_buffer_host_handle_t * val) { + apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t)); +} + +static inline void apir_decode_apir_buffer_host_handle_t(apir_decoder * dec, apir_buffer_host_handle_t * val) { + apir_decode(dec, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t)); +} + +/* uintptr_t */ + +static inline void apir_encode_uintptr_t(apir_encoder * enc, const uintptr_t * val) { + apir_encode(enc, sizeof(*val), val, sizeof(*val)); +} + +static inline void apir_decode_uintptr_t(apir_decoder * dec, uintptr_t * val) { + apir_decode(dec, sizeof(*val), val, sizeof(*val)); +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h new file mode 100644 index 00000000000..289f4b77d74 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h @@ -0,0 +1,221 @@ +#include "ggml-impl.h" +#include "apir_cs.h" +#include "apir_cs_rpc.h" + +// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer); + +static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, + const apir_buffer_host_handle_t * handle); + +static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec); + +/* apir_rpc_tensor */ + +static inline void apir_encode_rcp_tensor(apir_encoder * enc, const apir_rpc_tensor * apir_rpc_tensor) { + size_t apir_rpc_tensor_size = sizeof(*apir_rpc_tensor); + apir_encode(enc, apir_rpc_tensor_size, apir_rpc_tensor, apir_rpc_tensor_size); +} + +static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder * dec) { + size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor); + + return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); +} + +static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, + uint32_t n_tensors) { + size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors; + + return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); +} + +/* ggml_tensor */ + +static inline void apir_encode_ggml_tensor(apir_encoder * enc, const ggml_tensor * tensor) { + apir_rpc_tensor serialized = apir_serialize_tensor(tensor); + + apir_encode_rcp_tensor(enc, &serialized); +} + +static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) { + const apir_rpc_tensor * apir_rpc_tensor = apir_decode_apir_rpc_tensor_inplace(dec); + + if (!apir_rpc_tensor) { + return NULL; + } + + ggml_init_params params{ + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + + const ggml_tensor * tensor = apir_deserialize_tensor(ctx, apir_rpc_tensor); + + return tensor; +} + +/* *** ggml_backend_buffer_type_t *** */ + +// ggml_backend_buffer_type_t is a POINTER (to a struct). +// Only the host pointer is shared between the host and guest. +// The guest stores it in `buft->context`. +// The host simply writes the pointer address in the buffer variable. + +static inline void apir_encode_ggml_buffer_type(apir_encoder * enc, ggml_backend_buffer_type_t buft) { + apir_buffer_type_host_handle_t handle = ggml_buffer_type_to_apir_handle(buft); + apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle)); +} + +static inline ggml_backend_buffer_type_t apir_decode_ggml_buffer_type(apir_decoder * dec) { + apir_buffer_type_host_handle_t handle; + + apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle)); + + return (ggml_backend_buffer_type_t) handle; +} + +static inline void apir_encode_apir_buffer_type_host_handle(apir_encoder * enc, apir_buffer_type_host_handle_t handle) { + apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle)); +} + +static inline apir_buffer_type_host_handle_t apir_decode_apir_buffer_type_host_handle(apir_decoder * dec) { + apir_buffer_type_host_handle_t handle; + + apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle)); + + return handle; +} + +/* *** ggml_backend_type_t *** */ + +// ggml_backend_buffer_t is a POINTER. +// same logic as for ggml_backend_buffer_type_t + +static inline void apir_encode_ggml_buffer(apir_encoder * enc, const ggml_backend_buffer_t buffer) { + apir_buffer_host_handle_t handle = BUFFER_TO_HOST_HANDLE(buffer); + apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle)); +} + +static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec) { + ggml_backend_buffer_t buffer; + size_t buffer_ptr_size = sizeof(buffer); + + apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size); + + return buffer; +} + +/* enum ggml_status */ + +static inline void apir_encode_ggml_status(apir_encoder * enc, const ggml_status * status) { + apir_encoder_write(enc, sizeof(*status), status, sizeof(*status)); +} + +static inline void apir_decode_ggml_status(apir_decoder * dec, ggml_status * status) { + apir_decoder_read(dec, sizeof(*status), status, sizeof(*status)); +} + +/* virtgpu_shmem */ + +static inline void apir_encode_virtgpu_shmem_res_id(apir_encoder * enc, uint32_t shmem_res_id) { + apir_encode_uint32_t(enc, &shmem_res_id); +} + +static inline void apir_decode_virtgpu_shmem_res_id(apir_decoder * dec, uint32_t * shmem_res_id) { + apir_decode_uint32_t(dec, shmem_res_id); +} + +/* ggml_cgraph */ + +static inline size_t apir_serialize_ggml_cgraph(ggml_cgraph * cgraph, std::vector & cgraph_data) { + apir_serialize_graph(cgraph, cgraph_data); + + return cgraph_data.size(); +} + +static inline void apir_encode_cgraph_data(apir_encoder * enc, std::vector & cgraph_data) { + size_t cgraph_size = cgraph_data.size(); + + apir_encode(enc, cgraph_size, cgraph_data.data(), cgraph_size); +} + +static inline ggml_cgraph * apir_decode_ggml_cgraph(apir_decoder * dec, size_t cgraph_size) { + GGML_UNUSED(cgraph_size); + + uint32_t n_nodes; + apir_decode_uint32_t(dec, &n_nodes); + const uint64_t * nodes = apir_decode_uint64_t_array_inplace(dec, n_nodes); + + uint32_t n_tensors; + apir_decode_uint32_t(dec, &n_tensors); + const apir_rpc_tensor * tensors = apir_decode_apir_rpc_tensor_array_inplace(dec, n_tensors); + + return apir_deserialize_graph(n_nodes, n_tensors, tensors, nodes); +} + +static inline void apir_encode_ggml_buffer_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle) { + apir_encoder_write(enc, sizeof(*handle), &handle, sizeof(*handle)); +} + +static inline void apir_encode_ggml_tensor_inline(apir_encoder * enc, const ggml_tensor * tensor) { + size_t tensor_size = sizeof(*tensor); + + if (tensor->extra) { + GGML_ABORT("%s: Cannot pass tensors with extra", __func__); + } + + if (tensor->src[0] && tensor->buffer) { + static int first = 1; + if (first) { + GGML_LOG_WARN("%s: Cannot pass tensors with src and buffer\n", __func__); + first = 0; + } + } + + apir_encoder_write(enc, tensor_size, tensor, tensor_size); + + // tensor->data is a pointer inside the device buffer. No need to touch it + // tensor->buffer is a pointer to a buffer. Encoding the buffer handle in sequence. + // (could also make a copy of the tensor, and update locally.) + + if (tensor->buffer) { + apir_buffer_host_handle_t buffer_handle = ggml_buffer_to_apir_handle(tensor->buffer); + apir_encode_ggml_buffer_handle(enc, &buffer_handle); + } + + if (tensor->view_src) { + apir_encoder_write(enc, tensor_size, tensor->view_src, tensor_size); + } + + for (int i = 0; tensor->src[i]; i++) { + const ggml_tensor * tensor_src = tensor->src[i]; + apir_encoder_write(enc, tensor_size, tensor_src, tensor_size); + } +} + +static inline const ggml_tensor * apir_decode_ggml_tensor_inplace(apir_decoder * dec) { + // it safe to remove the `const` qualifier here, we *do* want to + // modify the shared memory data to fix the `src` pointers. + ggml_tensor * tensor = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor)); + + // tensor->data is a pointer inside the device buffer. No need to touch it + // tensor->buffer is a pointer to a buffer. Decode the buffer handle encoded in sequence. + if (tensor->buffer) { + tensor->buffer = apir_decode_ggml_buffer(dec); + } + + if (tensor->view_src) { + ggml_tensor * tensor_view_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor)); + tensor->view_src = tensor_view_src; + } + + for (int i = 0; tensor->src[i]; i++) { + ggml_tensor * tensor_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor)); + tensor->src[i] = tensor_src; // overwrite op->src[i] pointer with the actual location of the src tensor + } + + return tensor; +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h new file mode 100644 index 00000000000..f6817989528 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h @@ -0,0 +1,54 @@ +#include "ggml.h" +#include "ggml-backend-impl.h" + +#include +#include +#include +#include + +// ggml_tensor is serialized into apir_rpc_tensor +struct apir_rpc_tensor { + uint64_t id; + uint32_t type; + uint64_t buffer; + uint32_t ne[GGML_MAX_DIMS]; + uint32_t nb[GGML_MAX_DIMS]; + uint32_t op; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + int32_t flags; + uint64_t src[GGML_MAX_SRC]; + uint64_t view_src; + uint64_t view_offs; + uint64_t data; + char name[GGML_MAX_NAME]; + + char padding[4]; +}; + +/* frontend */ + +apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor); + +void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector & output); + +/* backend */ + +void apir_track_backend_buffer(ggml_backend_buffer_t buffer); +bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer); +std::unordered_set apir_get_track_backend_buffers(); + +void apir_add_tensor(ggml_tensor * tensor, + std::vector & tensors, + std::unordered_set & visited); + +ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor); + +ggml_tensor * apir_create_node(uint64_t id, + ggml_context * ctx, + const std::unordered_map & tensor_ptrs, + std::unordered_map & tensor_map); + +ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes, + uint32_t n_tensors, + const apir_rpc_tensor * tensors, + const uint64_t * nodes); diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp new file mode 100644 index 00000000000..c493a8e2ae3 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp @@ -0,0 +1,81 @@ +#include "ggml-remoting.h" + +static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context)); + if (!context) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__); + } + + context->gpu = gpu; + + bool async__unused, host_buffer__unused, events__unused; + bool buffer_from_host_ptr; + apir_device_get_props(gpu, &async__unused, &host_buffer__unused, &buffer_from_host_ptr, &events__unused); + + if (buffer_from_host_ptr) { + context->apir_context = apir_device_buffer_from_ptr(gpu, size, size); + context->base = context->apir_context.shmem.mmap_ptr; + context->is_from_ptr = true; + } else { + context->apir_context = apir_buffer_type_alloc_buffer(gpu, gpu->cached_buffer_type.host_handle, size); + context->is_from_ptr = false; + context->base = NULL; + } + + ggml_backend_buffer_t buffer = + ggml_backend_buffer_init(buft, ggml_backend_remoting_buffer_interface, (void *) context, size); + + return buffer; +} + +static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + return gpu->cached_buffer_type.name; +} + +static size_t ggml_backend_remoting_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + return gpu->cached_buffer_type.alignment; +} + +static size_t ggml_backend_remoting_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + return gpu->cached_buffer_type.max_size; +} + +static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, + const ggml_tensor * tensor) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + if (tensor->buffer == NULL + || !tensor->buffer->context + || !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) { + return ggml_nbytes(tensor); + } + + return apir_buffer_type_get_alloc_size(gpu, gpu->cached_buffer_type.host_handle, tensor); +} + +const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface = { + /* .get_name = */ ggml_backend_remoting_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_remoting_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_remoting_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_remoting_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_remoting_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface = { + /* .get_name = */ ggml_backend_remoting_buffer_type_get_name, + /* .alloc_buffer = */ NULL, + /* .get_alignment = */ ggml_backend_remoting_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_remoting_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_remoting_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp new file mode 100644 index 00000000000..6b95362dd80 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp @@ -0,0 +1,119 @@ +#include "ggml-remoting.h" + +#define BUFFER_TO_GPU(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->gpu + +static void * ggml_backend_remoting_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) buffer->context; + if (context->base) { + return context->base; + } + + context->base = apir_buffer_get_base(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer)); + + return context->base; +} + +static void ggml_backend_remoting_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer); + if (context->is_from_ptr) { + memcpy((char *) tensor->data + offset, data, size); + } else { + apir_buffer_set_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size); + } + + return; +} + +static void ggml_backend_remoting_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer); + if (context->is_from_ptr) { + memcpy(data, (const char *) tensor->data + offset, size); + } else { + apir_buffer_get_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size); + } +} + +static void ggml_backend_remoting_buffer_set_tensor_from_ptr(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + UNUSED(buffer); + + memcpy((char *) tensor->data + offset, data, size); + + return; +} + +static void ggml_backend_remoting_buffer_get_tensor_from_ptr(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + UNUSED(buffer); + + memcpy(data, (const char *) tensor->data + offset, size); +} + +static bool ggml_backend_remoting_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * src, + ggml_tensor * dst) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + bool ret = apir_buffer_cpy_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), src, dst); + + return ret; +} + +static void ggml_backend_remoting_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + apir_buffer_clear(gpu, BUFFER_TO_APIR_CONTEXT(buffer), value); + + return; +} + +static void ggml_backend_remoting_buffer_free_buffer(ggml_backend_buffer_t buffer) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + apir_buffer_free_buffer(gpu, BUFFER_TO_APIR_CONTEXT(buffer)); + + ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer); + free(context); + buffer->context = NULL; +} + +const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = { + /* .free_buffer = */ ggml_backend_remoting_buffer_free_buffer, + /* .get_base = */ ggml_backend_remoting_buffer_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, + /* .clear = */ ggml_backend_remoting_buffer_clear, + /* .reset = */ NULL, +}; + +const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = { + /* .free_buffer = */ ggml_backend_remoting_buffer_free_buffer, + /* .get_base = */ ggml_backend_remoting_buffer_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor_from_ptr, + /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor_from_ptr, + /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, + /* .clear = */ ggml_backend_remoting_buffer_clear, + /* .reset = */ NULL, +}; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp new file mode 100644 index 00000000000..c7d2881058b --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp @@ -0,0 +1,157 @@ +#include "ggml-remoting.h" + +static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + return gpu->cached_device_info.name; +} + +static const char * ggml_backend_remoting_device_get_description(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + // Return the pre-cached description from the virtgpu structure + return gpu->cached_device_info.description; +} + +static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + return (enum ggml_backend_dev_type) gpu->cached_device_info.type; +} + +static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + virtgpu * gpu = DEV_TO_GPU(dev); + + *free = gpu->cached_device_info.memory_free; + *total = gpu->cached_device_info.memory_total; +} + +static bool ggml_backend_remoting_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { +#if USE_ALWAYS_TRUE_SUPPORTS_OP == 1 + /* ggml-rpc cheats it like this */ + /* with the current implementation of serialize_tensor, the src/view aren't properly passed */ + UNUSED(dev); + UNUSED(op); + + return true; +#else + virtgpu * gpu = DEV_TO_GPU(dev); + + return apir_device_supports_op(gpu, op); +#endif +} + +static bool ggml_backend_remoting_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + bool supported = buft->device == dev; + + return supported; +} + +static bool ggml_backend_remoting_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + UNUSED(dev); + UNUSED(op); + + return false; +} + +static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_remoting_device_get_name(dev); + props->description = ggml_backend_remoting_device_get_description(dev); + props->type = ggml_backend_remoting_device_get_type(dev); + ggml_backend_remoting_device_get_memory(dev, &props->memory_free, &props->memory_total); + + virtgpu * gpu = DEV_TO_GPU(dev); + apir_device_get_props(gpu, &props->caps.async, &props->caps.host_buffer, &props->caps.buffer_from_host_ptr, + &props->caps.events); + + props->caps.buffer_from_host_ptr = false; + props->caps.async = false; + props->caps.events = false; +} + +ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + static std::atomic initialized = false; + static ggml_backend_buffer_type buft; + + if (!initialized) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + if (!initialized) { + buft = { + /* .iface = */ ggml_backend_remoting_buffer_type_interface, + /* .device = */ dev, + /* .context = */ (void *) gpu->cached_buffer_type.host_handle, + }; + initialized = true; + } + } + + return &buft; +} + +static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + static std::atomic initialized = false; + static ggml_backend_buffer_type buft; + + if (!initialized) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + if (!initialized) { + buft = { + /* .iface = */ ggml_backend_remoting_buffer_from_ptr_type_interface, + /* .device = */ dev, + /* .context = */ (void *) gpu->cached_buffer_type.host_handle, + }; + initialized = true; + } + } + + return &buft; +} + +static ggml_backend_buffer_t ggml_backend_remoting_device_buffer_from_ptr(ggml_backend_dev_t dev, + void * ptr, + size_t size, + size_t max_tensor_size) { + virtgpu * gpu = DEV_TO_GPU(dev); + + ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context)); + if (!context) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__); + } + + context->gpu = gpu; + context->apir_context = apir_device_buffer_from_ptr(gpu, size, max_tensor_size); + context->base = ptr; + context->is_from_ptr = true; + + ggml_backend_buffer_t buffer = + ggml_backend_buffer_init(ggml_backend_remoting_device_get_buffer_from_ptr_type(dev), + ggml_backend_remoting_buffer_from_ptr_interface, (void *) context, size); + + return buffer; +} + +const ggml_backend_device_i ggml_backend_remoting_device_interface = { + /* .get_name = */ ggml_backend_remoting_device_get_name, + /* .get_description = */ ggml_backend_remoting_device_get_description, + /* .get_memory = */ ggml_backend_remoting_device_get_memory, + /* .get_type = */ ggml_backend_remoting_device_get_type, + /* .get_props = */ ggml_backend_remoting_device_get_props, + /* .init_backend = */ ggml_backend_remoting_device_init, + /* .get_buffer_type = */ ggml_backend_remoting_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_remoting_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_remoting_device_supports_op, + /* .supports_buft = */ ggml_backend_remoting_device_supports_buft, + /* .offload_op = */ ggml_backend_remoting_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp new file mode 100644 index 00000000000..2d02cfec1d3 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp @@ -0,0 +1,189 @@ +#include "ggml-remoting.h" +#include "ggml-virtgpu.h" + +#include +#include + +void ggml_virtgpu_cleanup(virtgpu * gpu); + +static virtgpu * apir_initialize() { + static virtgpu * gpu = NULL; + static std::atomic initialized = false; + + if (initialized) { + // fast track + return gpu; + } + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + + if (initialized) { + // thread safe + return gpu; + } + + gpu = create_virtgpu(); + if (!gpu) { + initialized = true; + return NULL; + } + + // Pre-fetch and cache all device information, it will not change + gpu->cached_device_info.description = apir_device_get_description(gpu); + if (!gpu->cached_device_info.description) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__); + } + gpu->cached_device_info.name = apir_device_get_name(gpu); + if (!gpu->cached_device_info.name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device name", __func__); + } + gpu->cached_device_info.device_count = apir_device_get_count(gpu); + gpu->cached_device_info.type = apir_device_get_type(gpu); + + apir_device_get_memory(gpu, + &gpu->cached_device_info.memory_free, + &gpu->cached_device_info.memory_total); + + apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu); + gpu->cached_buffer_type.host_handle = buft_host_handle; + gpu->cached_buffer_type.name = apir_buffer_type_get_name(gpu, buft_host_handle); + if (!gpu->cached_buffer_type.name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu buffer type name", __func__); + } + gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle); + gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle); + + initialized = true; + } + + return gpu; +} + +static int ggml_backend_remoting_get_device_count() { + virtgpu * gpu = apir_initialize(); + if (!gpu) { + return 0; + } + + return gpu->cached_device_info.device_count; +} + +static size_t ggml_backend_remoting_reg_get_device_count(ggml_backend_reg_t reg) { + UNUSED(reg); + + return ggml_backend_remoting_get_device_count(); +} + +static std::vector devices; + +ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device) { + GGML_ASSERT(device < devices.size()); + return devices[device]; +} + +static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) { + if (devices.size() > 0) { + GGML_LOG_INFO(GGML_VIRTGPU "%s: already initialized\n", __func__); + return; + } + + virtgpu * gpu = apir_initialize(); + if (!gpu) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: apir_initialize failed\n", __func__); + return; + } + + static std::atomic initialized = false; + + if (initialized) { + return; // fast track + } + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + for (int i = 0; i < ggml_backend_remoting_get_device_count(); i++) { + ggml_backend_remoting_device_context * ctx = new ggml_backend_remoting_device_context; + char desc[256] = "ggml-virtgpu API Remoting device"; + + ctx->device = i; + ctx->name = GGML_VIRTGPU_NAME + std::to_string(i); + ctx->description = desc; + ctx->gpu = gpu; + + ggml_backend_dev_t dev = new ggml_backend_device{ + /* .iface = */ ggml_backend_remoting_device_interface, + /* .reg = */ reg, + /* .context = */ ctx, + }; + devices.push_back(dev); + } + initialized = true; + } + } +} + +static ggml_backend_dev_t ggml_backend_remoting_reg_get_device(ggml_backend_reg_t reg, size_t device) { + UNUSED(reg); + + return ggml_backend_remoting_get_device(device); +} + +static const char * ggml_backend_remoting_reg_get_name(ggml_backend_reg_t reg) { + UNUSED(reg); + + return GGML_VIRTGPU_NAME; +} + +static const ggml_backend_reg_i ggml_backend_remoting_reg_i = { + /* .get_name = */ ggml_backend_remoting_reg_get_name, + /* .get_device_count = */ ggml_backend_remoting_reg_get_device_count, + /* .get_device = */ ggml_backend_remoting_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_virtgpu_reg() { + virtgpu * gpu = apir_initialize(); + if (!gpu) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: virtgpu_apir_initialize failed\n", __func__); + } + + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_remoting_reg_i, + /* .context = */ gpu, + }; + + static bool initialized = false; + if (initialized) { + return ® + } + initialized = true; + + ggml_backend_remoting_reg_init_devices(®); + + return ® +} + +// public function, not exposed in the GGML interface at the moment +void ggml_virtgpu_cleanup(virtgpu * gpu) { + if (gpu->cached_device_info.name) { + free(gpu->cached_device_info.name); + gpu->cached_device_info.name = NULL; + } + if (gpu->cached_device_info.description) { + free(gpu->cached_device_info.description); + gpu->cached_device_info.description = NULL; + } + if (gpu->cached_buffer_type.name) { + free(gpu->cached_buffer_type.name); + gpu->cached_buffer_type.name = NULL; + } + + mtx_destroy(&gpu->data_shmem_mutex); +} + +GGML_BACKEND_DL_IMPL(ggml_backend_virtgpu_reg) diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp new file mode 100644 index 00000000000..5cd6c0c0608 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -0,0 +1,69 @@ +#include "ggml-remoting.h" +#include "../../include/ggml-virtgpu.h" + +static const char * ggml_backend_remoting_get_name(ggml_backend_t backend) { + UNUSED(backend); + + return "API Remoting backend"; +} + +static void ggml_backend_remoting_free(ggml_backend_t backend) { + delete backend; +} + +static ggml_status ggml_backend_remoting_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + virtgpu * gpu = DEV_TO_GPU(backend->device); + + return apir_backend_graph_compute(gpu, cgraph); +} + +static void ggml_backend_remoting_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + virtgpu * gpu = DEV_TO_GPU(backend->device); +#if true + UNUSED(gpu); + UNUSED(cgraph); +#else + // not working yet + + apir_backend_graph_optimize(gpu, cgraph); +#endif +} + +static ggml_backend_i ggml_backend_remoting_interface = { + /* .get_name = */ ggml_backend_remoting_get_name, + /* .free = */ ggml_backend_remoting_free, + /* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async, + /* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async, + /* .synchronize = */ NULL, // ggml_backend_remoting_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_remoting_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_backend_remoting_graph_optimize, +}; + +static ggml_guid_t ggml_backend_remoting_guid() { + static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x14, 0x03, 0x86, 0x02, + 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; + + return &guid; +} + +ggml_backend_t ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params) { + UNUSED(params); + + ggml_backend_remoting_device_context * ctx = (ggml_backend_remoting_device_context *) dev->context; + + ggml_backend_t remoting_backend = new ggml_backend{ + /* .guid = */ ggml_backend_remoting_guid(), + /* .interface = */ ggml_backend_remoting_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_virtgpu_reg(), ctx->device), + /* .context = */ ctx, + }; + + return remoting_backend; +} diff --git a/ggml/src/ggml-virtgpu/ggml-remoting.h b/ggml/src/ggml-virtgpu/ggml-remoting.h new file mode 100644 index 00000000000..08766408676 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-remoting.h @@ -0,0 +1,71 @@ +#pragma once + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "virtgpu.h" + +#include +#include + +#define GGML_VIRTGPU_NAME "ggml-virtgpu" +#define GGML_VIRTGPU "ggml-virtgpu: " + +// USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes + +#define USE_ALWAYS_TRUE_SUPPORTS_OP 1 +#define USE_METAL_GUEST_SUPPORTS_OP 0 + +#define DEV_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->context)->gpu + +#define BUFFER_TO_GGML_CONTEXT(name) ((ggml_backend_remoting_buffer_context *) (name)->context) + +#define BUFFER_TO_APIR_CONTEXT(name) &((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context + +#define BUFFER_TO_HOST_HANDLE(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context.host_handle + +#define GET_DEVICE_CONTEXT() (ggml_backend_remoting_device_context *) ggml_backend_remoting_get_device(0)->context + +#define BUFT_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->device->context)->gpu + +struct ggml_backend_remoting_device_context { + size_t device; + std::string name; + std::string description; + + std::vector> shared_memory; + + virtgpu * gpu; +}; + +struct ggml_backend_remoting_buffer_context { + apir_buffer_context_t apir_context; + + virtgpu * gpu; + + void * base; + + bool is_from_ptr; +}; + +extern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface; +extern const ggml_backend_device_i ggml_backend_remoting_device_interface; +extern const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface; +extern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface; +extern const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface; + +ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device); +ggml_backend_t ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params); +ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev); + +static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) { + // in the backend, the buffer handle is the buffer pointer + return (apir_buffer_type_host_handle_t) buft->context; +} + +static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) { + if (!buffer->context) { + GGML_ABORT(GGML_VIRTGPU "%s: no context available :/", __func__); + } + return BUFFER_TO_HOST_HANDLE(buffer); +} diff --git a/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml new file mode 100644 index 00000000000..14ef2433e46 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml @@ -0,0 +1,166 @@ +# YAML schema for GGML remoting API functions +# This defines the structure for generating the remoting layer code + +# Configuration for the generated files +config: + # Base path for the generated files + base_path: "ggml/src" + + # Header files to update + files: + apir_backend_header: "ggml-virtgpu-apir/backend/shared/apir_backend.gen.h" + backend_dispatched_header: "ggml-virtgpu-apir/backend/backend-dispatched.gen.h" + virtgpu_forward_header: "ggml-virtgpu-apir/virtgpu-forward.gen.h" + +# Simplified function definitions with grouping and metadata combined +functions: + device: + group_description: "device" + functions: + get_device_count: + # No specific metadata - uses default void return and base params + + get_count: + frontend_return: "int" + + get_name: + frontend_return: "char *" + + get_description: + frontend_return: "char *" + + get_type: + frontend_return: "uint32_t" + + get_memory: + frontend_return: "void" + frontend_extra_params: + - "size_t *free" + - "size_t *total" + + supports_op: + frontend_return: "bool" + frontend_extra_params: + - "const ggml_tensor *op" + + get_buffer_type: + frontend_return: "apir_buffer_type_host_handle_t" + + get_props: + frontend_return: "void" + frontend_extra_params: + - "bool *async" + - "bool *host_buffer" + - "bool *buffer_from_host_ptr" + - "bool *events" + + buffer_from_ptr: + frontend_return: "apir_buffer_context_t" + frontend_extra_params: + - "size_t size" + - "size_t max_tensor_size" + + buffer_type: + group_description: "buffer-type" + functions: + get_name: + frontend_return: "char *" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + + get_alignment: + frontend_return: "size_t" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + + get_max_size: + frontend_return: "size_t" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + + is_host: + deprecated: true + + alloc_buffer: + frontend_return: "apir_buffer_context_t" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + - "size_t size" + + get_alloc_size: + frontend_return: "size_t" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + - "const ggml_tensor *op" + + buffer: + group_description: "buffer" + functions: + get_base: + frontend_return: "void *" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + + set_tensor: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "ggml_tensor *tensor" + - "const void *data" + - "size_t offset" + - "size_t size" + + get_tensor: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "const ggml_tensor *tensor" + - "void *data" + - "size_t offset" + - "size_t size" + + cpy_tensor: + frontend_return: "bool" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "const ggml_tensor *src" + - "const ggml_tensor *dst" + + clear: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "uint8_t value" + + free_buffer: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + + backend: + group_description: "backend" + functions: + graph_compute: + frontend_return: "ggml_status" + frontend_extra_params: + - "ggml_cgraph *cgraph" + + graph_optimize: + frontend_return: "ggml_cgraph *" + frontend_extra_params: + - "ggml_cgraph *cgraph" + enabled: false + +# Naming patterns used for code generation +naming_patterns: + # How to generate enum names + enum_prefix: "APIR_COMMAND_TYPE_" + + # How to generate backend function names + backend_function_prefix: "backend_" + + # How to generate frontend function names + frontend_function_prefix: "apir_" + + # Standard frontend first parameter + frontend_base_param: "struct virtgpu *gpu" diff --git a/ggml/src/ggml-virtgpu/include/apir_hw.h b/ggml/src/ggml-virtgpu/include/apir_hw.h new file mode 100644 index 00000000000..33af045ca2b --- /dev/null +++ b/ggml/src/ggml-virtgpu/include/apir_hw.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +struct virgl_renderer_capset_apir { + uint32_t apir_version; + uint32_t supports_blob_resources; + uint32_t reserved[4]; // For future expansion +}; diff --git a/ggml/src/ggml-virtgpu/regenerate_remoting.py b/ggml/src/ggml-virtgpu/regenerate_remoting.py new file mode 100755 index 00000000000..aeb48a4087e --- /dev/null +++ b/ggml/src/ggml-virtgpu/regenerate_remoting.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +""" +# Generated by Claude AI + +Script to completely regenerate the GGML remoting codebase from YAML configuration. + +This script reads api_functions.yaml and regenerates all the header files and +implementation templates for the GGML remoting layer. + +Usage: + python regenerate_remoting.py + +The script will: +1. Read ggmlremoting_functions.yaml configuration +2. Generate updated header files +3. Generate implementation templates in dedicated files +4. Show a summary of what was generated +""" + +import yaml +from typing import Dict, List, Any +from pathlib import Path +import os +import subprocess +import shutil +import logging + +NL = '\n' # can't have f"{'\n'}" in f-strings + + +class RemotingCodebaseGenerator: + def __init__(self, yaml_path: str = "ggmlremoting_functions.yaml"): + """Initialize the generator with the YAML configuration.""" + self.yaml_path = yaml_path + + if not Path(yaml_path).exists(): + raise FileNotFoundError(f"Configuration file {yaml_path} not found") + + with open(yaml_path, 'r') as f: + self.config = yaml.safe_load(f) + + self.functions = self.config['functions'] + self.naming_patterns = self.config['naming_patterns'] + self.config_data = self.config['config'] + + # Check if clang-format is available + self.clang_format_available = self._check_clang_format_available() + + def _check_clang_format_available(self) -> bool: + """Check if clang-format is available in the system PATH.""" + return shutil.which("clang-format") is not None + + def _format_file_with_clang_format(self, file_path: Path) -> bool: + """Format a file with clang-format -i. Returns True if successful, False otherwise.""" + if not self.clang_format_available: + return False + + try: + subprocess.run( + ["clang-format", "-i", str(file_path)], + check=True, + capture_output=True, + text=True + ) + return True + except subprocess.CalledProcessError: + logging.exception(f" ⚠️ clang-format failed for {file_path}") + return False + except Exception as e: + logging.exception(f" ⚠️ Unexpected error formatting {file_path}: {e}") + return False + + def generate_enum_name(self, group_name: str, function_name: str) -> str: + """Generate the APIR_COMMAND_TYPE enum name for a function.""" + prefix = self.naming_patterns['enum_prefix'] + return f"{prefix}{group_name.upper()}_{function_name.upper()}" + + def generate_backend_function_name(self, group_name: str, function_name: str) -> str: + """Generate the backend function name.""" + function_key = f"{group_name}_{function_name}" + overrides = self.naming_patterns.get('backend_function_overrides', {}) + + if function_key in overrides: + return overrides[function_key] + + prefix = self.naming_patterns['backend_function_prefix'] + return f"{prefix}{group_name}_{function_name}" + + def generate_frontend_function_name(self, group_name: str, function_name: str) -> str: + """Generate the frontend function name.""" + prefix = self.naming_patterns['frontend_function_prefix'] + return f"{prefix}{group_name}_{function_name}" + + def get_enabled_functions(self) -> List[Dict[str, Any]]: + """Get all enabled functions with their metadata.""" + functions = [] + enum_value = 0 + + for group_name, group_data in self.functions.items(): + group_description = group_data['group_description'] + + for function_name, func_metadata in group_data['functions'].items(): + # Handle case where func_metadata is None or empty (functions with only comments) + if func_metadata is None: + func_metadata = {} + + # Functions are enabled by default unless explicitly disabled + if func_metadata.get('enabled', True): + functions.append({ + 'group_name': group_name, + 'function_name': function_name, + 'enum_name': self.generate_enum_name(group_name, function_name), + 'enum_value': enum_value, + 'backend_function': self.generate_backend_function_name(group_name, function_name), + 'frontend_function': self.generate_frontend_function_name(group_name, function_name), + 'frontend_return': func_metadata.get('frontend_return', 'void'), + 'frontend_extra_params': func_metadata.get('frontend_extra_params', []), + 'group_description': group_description, + 'deprecated': func_metadata.get('deprecated', False), + }) + enum_value += 1 + + return functions + + def generate_apir_backend_header(self) -> str: + """Generate the complete apir_backend.h file.""" + functions = self.get_enabled_functions() + + # Generate the enum section + enum_lines = ["typedef enum ApirBackendCommandType {"] + current_group = None + + for func in functions: + # Add comment for new group + if func['group_name'] != current_group: + enum_lines.append("") + enum_lines.append(f" /* {func['group_description']} */") + current_group = func['group_name'] + + enum_lines.append(f" {func['enum_name']} = {func['enum_value']},") + + # Add the count + total_count = len(functions) + enum_lines.append("\n // last command_type index + 1") + enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},") + enum_lines.append("} ApirBackendCommandType;") + + # Full header template + header_content = NL.join(enum_lines) + "\n" + + return header_content + + def generate_backend_dispatched_header(self) -> str: + """Generate the complete backend-dispatched.h file.""" + functions = self.get_enabled_functions() + + # Function declarations + decl_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + decl_lines.append(f"\n/* {func['group_description']} */") + current_group = func['group_name'] + + signature = "uint32_t" + params = "apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx" + if func['deprecated']: + decl_lines.append(f"/* {func['enum_name']} is deprecated. Keeping the handler for backward compatibility. */") + + decl_lines.append(f"{signature} {func['backend_function']}({params});") + + # Switch cases + switch_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + switch_lines.append(f" /* {func['group_description']} */") + current_group = func['group_name'] + + deprecated = " (DEPRECATED)" if func['deprecated'] else "" + + switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}{deprecated}\";") + + # Dispatch table + table_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + table_lines.append(f"\n /* {func['group_description']} */") + table_lines.append("") + current_group = func['group_name'] + + deprecated = " /* DEPRECATED */" if func['deprecated'] else "" + table_lines.append(f" /* {func['enum_name']} = */ {func['backend_function']}{deprecated},") + + header_content = f'''\ +#pragma once + +{NL.join(decl_lines)} + +static inline const char *backend_dispatch_command_name(ApirBackendCommandType type) +{{ + switch (type) {{ +{NL.join(switch_lines)} + + default: return "unknown"; + }} +}} + +extern "C" {{ +static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{ + {NL.join(table_lines)} +}}; +}} +''' + return header_content + + def generate_virtgpu_forward_header(self) -> str: + """Generate the complete virtgpu-forward.gen.h file.""" + functions = self.get_enabled_functions() + + decl_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + decl_lines.append("") + decl_lines.append(f"/* {func['group_description']} */") + current_group = func['group_name'] + + if func['deprecated']: + decl_lines.append(f"/* {func['frontend_function']} is deprecated. */") + continue + + # Build parameter list + params = [self.naming_patterns['frontend_base_param']] + params.extend(func['frontend_extra_params']) + param_str = ', '.join(params) + + decl_lines.append(f"{func['frontend_return']} {func['frontend_function']}({param_str});") + + header_content = f'''\ +#pragma once +{NL.join(decl_lines)} +''' + return header_content + + def regenerate_codebase(self) -> None: + """Regenerate the entire remoting codebase.""" + logging.info("🔄 Regenerating GGML Remoting Codebase...") + logging.info("=" * 50) + + # Detect if we're running from frontend directory + current_dir = os.getcwd() + is_frontend_dir = current_dir.endswith('ggml-virtgpu') + + if is_frontend_dir: + # Running from ggml/src/ggml-virtgpu-apir + logging.info("📍 Detected frontend directory execution") + frontend_base = Path(".") + else: + # Running from project root (fallback to original behavior) + logging.info("📍 Detected project root execution") + base_path = self.config_data.get('base_path', 'ggml/src') + frontend_base = Path(base_path) / "ggml-virtgpu" + + # Compute final file paths + backend_base = frontend_base / "backend" + apir_backend_path = backend_base / "shared" / "apir_backend.gen.h" + backend_dispatched_path = backend_base / "backend-dispatched.gen.h" + virtgpu_forward_path = frontend_base / "virtgpu-forward.gen.h" + + # Create output directories for each file + apir_backend_path.parent.mkdir(parents=True, exist_ok=True) + backend_dispatched_path.parent.mkdir(parents=True, exist_ok=True) + virtgpu_forward_path.parent.mkdir(parents=True, exist_ok=True) + + # Generate header files + logging.info("📁 Generating header files...") + + apir_backend_content = self.generate_apir_backend_header() + apir_backend_path.write_text(apir_backend_content) + logging.info(f" ✅ {apir_backend_path.resolve()}") + + backend_dispatched_content = self.generate_backend_dispatched_header() + backend_dispatched_path.write_text(backend_dispatched_content) + logging.info(f" ✅ {backend_dispatched_path.resolve()}") + + virtgpu_forward_content = self.generate_virtgpu_forward_header() + virtgpu_forward_path.write_text(virtgpu_forward_content) + logging.info(f" ✅ {virtgpu_forward_path.resolve()}") + + # Format generated files with clang-format + generated_files = [apir_backend_path, backend_dispatched_path, virtgpu_forward_path] + + if not self.clang_format_available: + logging.warning("\n⚠️clang-format not found in PATH. Generated files will not be formatted.\n" + " Install clang-format to enable automatic code formatting.") + else: + logging.info("\n🎨 Formatting files with clang-format...") + for file_path in generated_files: + if self._format_file_with_clang_format(file_path): + logging.info(f" ✅ Formatted {file_path.name}") + else: + logging.warning(f" ❌ Failed to format {file_path.name}") + + # Generate summary + functions = self.get_enabled_functions() + total_functions = len(functions) + + logging.info("\n📊 Generation Summary:") + logging.info("=" * 50) + logging.info(f" Total functions: {total_functions}") + logging.info(f" Function groups: {len(self.functions)}") + logging.info(" Header files: 3") + logging.info(f" Working directory: {current_dir}") + + +def main(): + try: + generator = RemotingCodebaseGenerator() + generator.regenerate_codebase() + except Exception as e: + logging.exception(f"❌ Error: {e}") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/ggml/src/ggml-virtgpu/virtgpu-apir.h b/ggml/src/ggml-virtgpu/virtgpu-apir.h new file mode 100644 index 00000000000..238f960acd2 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-apir.h @@ -0,0 +1,15 @@ +#include "backend/shared/apir_backend.h" +#include "ggml-alloc.h" +#include "ggml-impl.h" +#include "ggml.h" +#include "virtgpu-shm.h" +#include "virtgpu-utils.h" + +struct apir_buffer_context_t { + apir_buffer_host_handle_t host_handle; + + struct virtgpu_shmem shmem; + apir_buffer_type_host_handle_t buft_host_handle; +}; + +#include "virtgpu-forward.gen.h" diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp new file mode 100644 index 00000000000..07d9a668496 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp @@ -0,0 +1,58 @@ +#include "virtgpu-forward-impl.h" + +static long long current_time_ms() { + timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); // Use CLOCK_MONOTONIC for elapsed time + return (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec; +} + +ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE); + + std::vector cgraph_data; + size_t cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data); + + virtgpu_shmem temp_shmem; // Local storage for large buffers + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; + + if (cgraph_size <= gpu->data_shmem.mmap_size) { + // Lock mutex before using shared data_shmem buffer + if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) { + GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); + } + using_shared_shmem = true; + shmem = &gpu->data_shmem; + } else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); + } + + apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); + + apir_encode_size_t(encoder, &cgraph_size); + + char * shmem_data = (char *) shmem->mmap_ptr; + apir_encoder secondary_enc = apir_new_encoder(shmem_data, cgraph_size); + + apir_encode_cgraph_data(&secondary_enc, cgraph_data); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + ggml_status status = GGML_STATUS_ABORTED; + apir_decode_ggml_status(decoder, &status); + + remote_call_finish(gpu, encoder, decoder); + + // Unlock mutex before cleanup + if (using_shared_shmem) { + mtx_unlock(&gpu->data_shmem_mutex); + } else { + virtgpu_shmem_destroy(gpu, shmem); + } + + return status; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp new file mode 100644 index 00000000000..cab74fd1707 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp @@ -0,0 +1,106 @@ +#include "virtgpu-forward-impl.h" + +char * apir_buffer_type_get_name(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + const size_t string_size = apir_decode_array_size_unchecked(decoder); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + if (!string) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__); + apir_decoder_set_fatal(decoder); + } + apir_decode_char_array(decoder, string, string_size); + + remote_call_finish(gpu, encoder, decoder); + + return string; +} + +size_t apir_buffer_type_get_alignment(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + size_t alignment; + apir_decode_size_t(decoder, &alignment); + + remote_call_finish(gpu, encoder, decoder); + + return alignment; +} + +size_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + size_t max_size; + apir_decode_size_t(decoder, &max_size); + + remote_call_finish(gpu, encoder, decoder); + + return max_size; +} + +apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, size_t size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + apir_buffer_context_t buffer_context; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + apir_encode_size_t(encoder, &size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle); + + remote_call_finish(gpu, encoder, decoder); + + return buffer_context; +} + +size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, const ggml_tensor * op) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + apir_encode_ggml_tensor_inline(encoder, op); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + size_t alloc_size; + apir_decode_size_t(decoder, &alloc_size); + + remote_call_finish(gpu, encoder, decoder); + + return alloc_size; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp new file mode 100644 index 00000000000..86eee358cf4 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp @@ -0,0 +1,173 @@ +#include "virtgpu-forward-impl.h" + +void * apir_buffer_get_base(virtgpu * gpu, apir_buffer_context_t * buffer_context) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_BASE); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + uintptr_t base; + apir_decode_uintptr_t(decoder, &base); + + remote_call_finish(gpu, encoder, decoder); + + return (void *) base; +} + +void apir_buffer_set_tensor(virtgpu * gpu, + apir_buffer_context_t * buffer_context, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_SET_TENSOR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_ggml_tensor(encoder, tensor); + + virtgpu_shmem temp_shmem; // Local storage for large buffers + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; + + if (size <= gpu->data_shmem.mmap_size) { + // Lock mutex before using shared data_shmem buffer + if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) { + GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); + } + using_shared_shmem = true; + shmem = &gpu->data_shmem; + + } else if (virtgpu_shmem_create(gpu, size, shmem)) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); + } + + memcpy(shmem->mmap_ptr, data, size); + apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); + + apir_encode_size_t(encoder, &offset); + apir_encode_size_t(encoder, &size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + remote_call_finish(gpu, encoder, decoder); + + // Unlock mutex before cleanup + if (using_shared_shmem) { + mtx_unlock(&gpu->data_shmem_mutex); + } else { + virtgpu_shmem_destroy(gpu, shmem); + } + + return; +} + +void apir_buffer_get_tensor(virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_TENSOR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_ggml_tensor(encoder, tensor); + + virtgpu_shmem temp_shmem; // Local storage for large buffers + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; + + if (size <= gpu->data_shmem.mmap_size) { + // Lock mutex before using shared data_shmem buffer + if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) { + GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); + } + using_shared_shmem = true; + shmem = &gpu->data_shmem; + + } else if (virtgpu_shmem_create(gpu, size, shmem)) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); + } + + apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); + apir_encode_size_t(encoder, &offset); + apir_encode_size_t(encoder, &size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + memcpy(data, shmem->mmap_ptr, size); + + remote_call_finish(gpu, encoder, decoder); + + // Unlock mutex before cleanup + if (using_shared_shmem) { + mtx_unlock(&gpu->data_shmem_mutex); + } else { + virtgpu_shmem_destroy(gpu, shmem); + } +} + +bool apir_buffer_cpy_tensor(virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * src, + const ggml_tensor * dst) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_ggml_tensor(encoder, src); + apir_encode_ggml_tensor(encoder, dst); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + bool ret_val; + apir_decode_bool_t(decoder, &ret_val); + + remote_call_finish(gpu, encoder, decoder); + + return ret_val; +} + +void apir_buffer_clear(virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CLEAR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_uint8_t(encoder, &value); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + remote_call_finish(gpu, encoder, decoder); +} + +void apir_buffer_free_buffer(virtgpu * gpu, apir_buffer_context_t * buffer_context) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + remote_call_finish(gpu, encoder, decoder); +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp new file mode 100644 index 00000000000..4b6b8f527be --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp @@ -0,0 +1,192 @@ +#include "virtgpu-forward-impl.h" +#include "virtgpu-shm.h" + +int apir_device_get_count(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_COUNT); + REMOTE_CALL(gpu, encoder, decoder, ret); + + int32_t dev_count = -1; + apir_decode_int32_t(decoder, &dev_count); + + remote_call_finish(gpu, encoder, decoder); + + return dev_count; +} + +char * apir_device_get_name(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_NAME); + REMOTE_CALL(gpu, encoder, decoder, ret); + + const size_t string_size = apir_decode_array_size_unchecked(decoder); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + if (!string) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__); + return NULL; + } + apir_decode_char_array(decoder, string, string_size); + + remote_call_finish(gpu, encoder, decoder); + + return string; +} + +char * apir_device_get_description(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + const size_t string_size = apir_decode_array_size_unchecked(decoder); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + if (!string) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device description buffer\n", __func__); + + return NULL; + } + apir_decode_char_array(decoder, string, string_size); + + remote_call_finish(gpu, encoder, decoder); + + return string; +} + +uint32_t apir_device_get_type(virtgpu * gpu) { + static uint32_t dev_type = 255; + if (dev_type != 255) { + return dev_type; + } + + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_TYPE); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_uint32_t(decoder, &dev_type); + + remote_call_finish(gpu, encoder, decoder); + + return dev_type; +} + +void apir_device_get_memory(virtgpu * gpu, size_t * free, size_t * total) { + static size_t dev_free = 0; + static size_t dev_total = 0; + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_MEMORY); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_size_t(decoder, &dev_free); + apir_decode_size_t(decoder, &dev_total); + + *free = dev_free; + *total = dev_total; + + remote_call_finish(gpu, encoder, decoder); + + return; +} + +bool apir_device_supports_op(virtgpu * gpu, const ggml_tensor * op) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP); + + apir_encode_ggml_tensor_inline(encoder, op); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + bool supports_op; + apir_decode_bool_t(decoder, &supports_op); + + remote_call_finish(gpu, encoder, decoder); + + return supports_op; +} + +apir_buffer_type_host_handle_t apir_device_get_buffer_type(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_buffer_type_host_handle_t buft_handle; + apir_decode_apir_buffer_type_host_handle_t(decoder, &buft_handle); + + remote_call_finish(gpu, encoder, decoder); + + return buft_handle; +} + +void apir_device_get_props(virtgpu * gpu, + bool * async, + bool * host_buffer, + bool * buffer_from_host_ptr, + bool * events) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_PROPS); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_bool_t(decoder, async); + apir_decode_bool_t(decoder, host_buffer); + apir_decode_bool_t(decoder, buffer_from_host_ptr); + apir_decode_bool_t(decoder, events); + + remote_call_finish(gpu, encoder, decoder); + + return; +} + +apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, size_t max_tensor_size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + apir_buffer_context_t buffer_context; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR); + + if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) { + GGML_ABORT(GGML_VIRTGPU "Couldn't allocate the guest-host shared buffer"); + } + + apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id); + + apir_encode_size_t(encoder, &size); + apir_encode_size_t(encoder, &max_tensor_size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle); + buffer_context.buft_host_handle = apir_decode_apir_buffer_type_host_handle(decoder); + + remote_call_finish(gpu, encoder, decoder); + + return buffer_context; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h new file mode 100644 index 00000000000..f23c75bb968 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h @@ -0,0 +1,29 @@ +#include "virtgpu.h" + +#include "ggml-remoting.h" +#include "backend/shared/apir_backend.h" +#include "backend/shared/apir_cs_ggml.h" + +#include "ggml-backend-impl.h" + +#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \ + do { \ + int32_t forward_flag = (int32_t) apir_command_type__; \ + encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, forward_flag); \ + if (!encoder_name) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \ + } \ + } while (0) + +#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \ + do { \ + ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \ + if (!decoder_name) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \ + } \ + if (ret_name < APIR_FORWARD_BASE_INDEX) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \ + apir_forward_error(ret_name), ret_name); \ + } \ + ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \ + } while (0) diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h new file mode 100644 index 00000000000..fe4cae20253 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h @@ -0,0 +1,52 @@ +#pragma once + +/* device */ +void apir_device_get_device_count(struct virtgpu * gpu); +int apir_device_get_count(struct virtgpu * gpu); +char * apir_device_get_name(struct virtgpu * gpu); +char * apir_device_get_description(struct virtgpu * gpu); +uint32_t apir_device_get_type(struct virtgpu * gpu); +void apir_device_get_memory(struct virtgpu * gpu, size_t * free, size_t * total); +bool apir_device_supports_op(struct virtgpu * gpu, const ggml_tensor * op); +apir_buffer_type_host_handle_t apir_device_get_buffer_type(struct virtgpu * gpu); +void apir_device_get_props(struct virtgpu * gpu, + bool * async, + bool * host_buffer, + bool * buffer_from_host_ptr, + bool * events); +apir_buffer_context_t apir_device_buffer_from_ptr(struct virtgpu * gpu, size_t size, size_t max_tensor_size); + +/* buffer-type */ +char * apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + size_t size); +size_t apir_buffer_type_get_alloc_size(struct virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + const ggml_tensor * op); + +/* buffer */ +void * apir_buffer_get_base(struct virtgpu * gpu, apir_buffer_context_t * buffer_context); +void apir_buffer_set_tensor(struct virtgpu * gpu, + apir_buffer_context_t * buffer_context, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size); +void apir_buffer_get_tensor(struct virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size); +bool apir_buffer_cpy_tensor(struct virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * src, + const ggml_tensor * dst); +void apir_buffer_clear(struct virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value); +void apir_buffer_free_buffer(struct virtgpu * gpu, apir_buffer_context_t * buffer_context); + +/* backend */ +ggml_status apir_backend_graph_compute(struct virtgpu * gpu, ggml_cgraph * cgraph); diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp new file mode 100644 index 00000000000..ce6b3b3e607 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp @@ -0,0 +1,98 @@ +#include "virtgpu-shm.h" + +#include "virtgpu.h" + +#include + +static uint32_t virtgpu_ioctl_resource_create_blob(virtgpu * gpu, + uint32_t blob_mem, + uint32_t blob_flags, + size_t blob_size, + uint64_t blob_id, + uint32_t * res_id) { +#ifdef SIMULATE_BO_SIZE_FIX + blob_size = align64(blob_size, 4096); +#endif + + drm_virtgpu_resource_create_blob args = { + .blob_mem = blob_mem, + .blob_flags = blob_flags, + .bo_handle = 0, + .res_handle = 0, + .size = blob_size, + .pad = 0, + .cmd_size = 0, + .cmd = 0, + .blob_id = blob_id, + }; + + if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_RESOURCE_CREATE_BLOB, &args)) { + return 0; + } + + *res_id = args.res_handle; + return args.bo_handle; +} + +static void virtgpu_ioctl_gem_close(virtgpu * gpu, uint32_t gem_handle) { + drm_gem_close args = { + .handle = gem_handle, + .pad = 0, + }; + + const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_GEM_CLOSE, &args); + assert(!ret); +#ifdef NDEBUG + UNUSED(ret); +#endif +} + +static void * virtgpu_ioctl_map(virtgpu * gpu, uint32_t gem_handle, size_t size) { + drm_virtgpu_map args = { + .offset = 0, + .handle = gem_handle, + .pad = 0, + }; + + if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_MAP, &args)) { + return NULL; + } + + void * ptr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, gpu->fd, args.offset); + if (ptr == MAP_FAILED) { + return NULL; + } + + return ptr; +} + +void virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem) { + munmap(shmem->mmap_ptr, shmem->mmap_size); + virtgpu_ioctl_gem_close(gpu, shmem->gem_handle); +} + +int virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem) { + size = align64(size, 16384); + + uint32_t res_id; + uint32_t gem_handle = virtgpu_ioctl_resource_create_blob(gpu, VIRTGPU_BLOB_MEM_HOST3D, + VIRTGPU_BLOB_FLAG_USE_MAPPABLE, size, 0, &res_id); + + if (!gem_handle) { + return 1; + } + + void * ptr = virtgpu_ioctl_map(gpu, gem_handle, size); + if (!ptr) { + virtgpu_ioctl_gem_close(gpu, gem_handle); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: virtgpu_ioctl_map failed\n", __func__); + return 1; + } + + shmem->res_id = res_id; + shmem->mmap_size = size; + shmem->mmap_ptr = ptr; + shmem->gem_handle = gem_handle; + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.h b/ggml/src/ggml-virtgpu/virtgpu-shm.h new file mode 100644 index 00000000000..606860a0946 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-shm.h @@ -0,0 +1,23 @@ +#pragma once + +#include "virtgpu-utils.h" + +#include + +#include +#include +#include +#include + +struct virtgpu; + +struct virtgpu_shmem { + uint32_t res_id; + size_t mmap_size; + void * mmap_ptr; + + uint32_t gem_handle; +}; + +int virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem); +void virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem); diff --git a/ggml/src/ggml-virtgpu/virtgpu-utils.cpp b/ggml/src/ggml-virtgpu/virtgpu-utils.cpp new file mode 100644 index 00000000000..8a2805e9902 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-utils.cpp @@ -0,0 +1,179 @@ +#include "virtgpu-utils.h" + +#include +#include + +#include + +#define NODE_ALLOC_ALIGN 64 +#define NODE_PTR_MASK (~((uintptr_t) NODE_ALLOC_ALIGN - 1)) +#define NODE_LEVEL_MASK ((uintptr_t) NODE_ALLOC_ALIGN - 1) +#define NULL_NODE 0 + +#define os_malloc_aligned(_size, _align) _aligned_malloc(_size, _align) +#define os_free_aligned(_ptr) free(_ptr) +#define p_atomic_cmpxchg(v, old, _new) __sync_val_compare_and_swap((v), (old), (_new)) + +static inline uint64_t util_logbase2_64(uint64_t n) { +#if defined(HAVE___BUILTIN_CLZLL) + return ((sizeof(uint64_t) * 8 - 1) - __builtin_clzll(n | 1)); +#else + uint64_t pos = 0ull; + if (n >= 1ull << 32) { + n >>= 32; + pos += 32; + } + if (n >= 1ull << 16) { + n >>= 16; + pos += 16; + } + if (n >= 1ull << 8) { + n >>= 8; + pos += 8; + } + if (n >= 1ull << 4) { + n >>= 4; + pos += 4; + } + if (n >= 1ull << 2) { + n >>= 2; + pos += 2; + } + if (n >= 1ull << 1) { + pos += 1; + } + return pos; +#endif +} + +void util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size) { + memset(arr, 0, sizeof(*arr)); + arr->elem_size = elem_size; + arr->node_size_log2 = util_logbase2_64(node_size); + assert(node_size >= 2 && node_size == (1ull << arr->node_size_log2)); +} + +static inline void * os_malloc_aligned(size_t size, size_t alignment) { + void * ptr; + alignment = (alignment + sizeof(void *) - 1) & ~(sizeof(void *) - 1); + if (posix_memalign(&ptr, alignment, size) != 0) { + return NULL; + } + return ptr; +} + +static inline void * _util_sparse_array_node_data(uintptr_t handle) { + return (void *) (handle & NODE_PTR_MASK); +} + +static inline unsigned _util_sparse_array_node_level(uintptr_t handle) { + return handle & NODE_LEVEL_MASK; +} + +static inline void _util_sparse_array_node_finish(util_sparse_array * arr, uintptr_t node) { + if (_util_sparse_array_node_level(node) > 0) { + uintptr_t * children = (uintptr_t *) _util_sparse_array_node_data(node); + size_t node_size = 1ull << arr->node_size_log2; + for (size_t i = 0; i < node_size; i++) { + if (children[i]) { + _util_sparse_array_node_finish(arr, children[i]); + } + } + } + + os_free_aligned(_util_sparse_array_node_data(node)); +} + +static inline uintptr_t _util_sparse_array_node(void * data, unsigned level) { + assert(data != NULL); + assert(((uintptr_t) data & NODE_LEVEL_MASK) == 0); + assert((level & NODE_PTR_MASK) == 0); + return (uintptr_t) data | level; +} + +inline uintptr_t _util_sparse_array_node_alloc(util_sparse_array * arr, unsigned level) { + size_t size; + if (level == 0) { + size = arr->elem_size << arr->node_size_log2; + } else { + size = sizeof(uintptr_t) << arr->node_size_log2; + } + + void * data = os_malloc_aligned(size, NODE_ALLOC_ALIGN); + memset(data, 0, size); + + return _util_sparse_array_node(data, level); +} + +static inline uintptr_t _util_sparse_array_set_or_free_node(uintptr_t * node_ptr, uintptr_t cmp_node, uintptr_t node) { + uintptr_t prev_node = p_atomic_cmpxchg(node_ptr, cmp_node, node); + + if (prev_node != cmp_node) { + /* We lost the race. Free this one and return the one that was already + * allocated. + */ + os_free_aligned(_util_sparse_array_node_data(node)); + return prev_node; + } else { + return node; + } +} + +void * util_sparse_array_get(util_sparse_array * arr, uint64_t idx) { + const unsigned node_size_log2 = arr->node_size_log2; + uintptr_t root = p_atomic_read(&arr->root); + if (unlikely(!root)) { + unsigned root_level = 0; + uint64_t idx_iter = idx >> node_size_log2; + while (idx_iter) { + idx_iter >>= node_size_log2; + root_level++; + } + uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level); + root = _util_sparse_array_set_or_free_node(&arr->root, NULL_NODE, new_root); + } + + while (1) { + unsigned root_level = _util_sparse_array_node_level(root); + uint64_t root_idx = idx >> (root_level * node_size_log2); + if (likely(root_idx < (1ull << node_size_log2))) { + break; + } + + /* In this case, we have a root but its level is low enough that the + * requested index is out-of-bounds. + */ + uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level + 1); + + uintptr_t * new_root_children = (uintptr_t *) _util_sparse_array_node_data(new_root); + new_root_children[0] = root; + + /* We only add one at a time instead of the whole tree because it's + * easier to ensure correctness of both the tree building and the + * clean-up path. Because we're only adding one node we never have to + * worry about trying to free multiple things without freeing the old + * things. + */ + root = _util_sparse_array_set_or_free_node(&arr->root, root, new_root); + } + + void * node_data = _util_sparse_array_node_data(root); + unsigned node_level = _util_sparse_array_node_level(root); + while (node_level > 0) { + uint64_t child_idx = (idx >> (node_level * node_size_log2)) & ((1ull << node_size_log2) - 1); + + uintptr_t * children = (uintptr_t *) node_data; + uintptr_t child = p_atomic_read(&children[child_idx]); + + if (unlikely(!child)) { + child = _util_sparse_array_node_alloc(arr, node_level - 1); + child = _util_sparse_array_set_or_free_node(&children[child_idx], NULL_NODE, child); + } + + node_data = _util_sparse_array_node_data(child); + node_level = _util_sparse_array_node_level(child); + } + + uint64_t elem_idx = idx & ((1ull << node_size_log2) - 1); + return (void *) ((char *) node_data + (elem_idx * arr->elem_size)); +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-utils.h b/ggml/src/ggml-virtgpu/virtgpu-utils.h new file mode 100644 index 00000000000..a0036b4e2bc --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-utils.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define unlikely(x) __builtin_expect(!!(x), 0) +#define likely(x) __builtin_expect(!!(x), 1) + +#ifndef UNUSED +# define UNUSED(x) (void) (x) +#endif + +/** Checks is a value is a power of two. Does not handle zero. */ +#define IS_POT(v) (((v) & ((v) - 1)) == 0) + +/** Checks is a value is a power of two. Zero handled. */ +#define IS_POT_NONZERO(v) ((v) != 0 && IS_POT(v)) + +/** Align a value to a power of two */ +#define ALIGN_POT(x, pot_align) (((x) + (pot_align) - 1) & ~((pot_align) - 1)) + +#define p_atomic_read(_v) __atomic_load_n((_v), __ATOMIC_ACQUIRE) + +static inline bool util_is_power_of_two_nonzero64(uint64_t v) { + return IS_POT_NONZERO(v); +} + +static inline uint64_t align64(uint64_t value, uint64_t alignment) { + assert(util_is_power_of_two_nonzero64(alignment)); + return ALIGN_POT(value, alignment); +} + +struct list_head { + list_head * prev; + list_head * next; +}; + +struct util_sparse_array { + size_t elem_size; + unsigned node_size_log2; + + uintptr_t root; +}; + +void * util_sparse_array_get(util_sparse_array * arr, uint64_t idx); +void util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size); + +inline void os_time_sleep(int64_t usecs) { + timespec time; + time.tv_sec = usecs / 1000000; + time.tv_nsec = (usecs % 1000000) * 1000; + while (clock_nanosleep(CLOCK_MONOTONIC, 0, &time, &time) == EINTR) + ; +} + +struct timer_data { + long long start; + long long total; + long long count; +}; + +static inline void start_timer(timer_data * timer) { + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + timer->start = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec; +} + +// returns the duration in ns +static inline long long stop_timer(timer_data * timer) { + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + long long timer_end = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec; + + long long duration = (timer_end - timer->start); + timer->total += duration; + timer->count += 1; + + return duration; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu.cpp b/ggml/src/ggml-virtgpu/virtgpu.cpp new file mode 100644 index 00000000000..1e650dc65b2 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu.cpp @@ -0,0 +1,559 @@ +#include "virtgpu.h" + +#include +#include + +#include +#include +#include + +static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev); +static virt_gpu_result_t virtgpu_open(virtgpu * gpu); + +static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu); +static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu); + +static int virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id); +static int virtgpu_ioctl_get_caps(virtgpu * gpu, + virgl_renderer_capset id, + uint32_t version, + void * capset, + size_t capset_size); +static uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param); +static void virtgpu_init_renderer_info(virtgpu * gpu); + +static void log_call_duration(long long call_duration_ns, const char * name); + +const uint64_t APIR_HANDSHAKE_MAX_WAIT_MS = 2 * 1000; // 2s +const uint64_t APIR_LOADLIBRARY_MAX_WAIT_MS = 60 * 1000; // 60s + +static int virtgpu_handshake(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + + encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_HANDSHAKE, 0); + if (!encoder) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); + return 1; + } + + /* write handshake props */ + + uint32_t guest_major = APIR_PROTOCOL_MAJOR; + uint32_t guest_minor = APIR_PROTOCOL_MINOR; + apir_encode_uint32_t(encoder, &guest_major); + apir_encode_uint32_t(encoder, &guest_minor); + + /* *** */ + + uint32_t ret_magic; + long long call_duration_ns; + ret_magic = remote_call(gpu, encoder, &decoder, APIR_HANDSHAKE_MAX_WAIT_MS, &call_duration_ns); + log_call_duration(call_duration_ns, "API Remoting handshake"); + + if (!decoder) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initiate the communication with the virglrenderer library. " + "Most likely, the wrong virglrenderer library was loaded in the hypervisor.", + __func__); + return 1; + } + + /* read handshake return values */ + + uint32_t host_major; + uint32_t host_minor; + + if (ret_magic != APIR_HANDSHAKE_MAGIC) { + GGML_ABORT(GGML_VIRTGPU + "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic, + apir_backend_initialize_error(ret_magic)); + } else { + apir_decode_uint32_t(decoder, &host_major); + apir_decode_uint32_t(decoder, &host_minor); + } + + remote_call_finish(gpu, encoder, decoder); + + if (ret_magic != APIR_HANDSHAKE_MAGIC) { + return 1; + } + + GGML_LOG_INFO(GGML_VIRTGPU "%s: Guest is running with %u.%u\n", __func__, guest_major, guest_minor); + GGML_LOG_INFO(GGML_VIRTGPU "%s: Host is running with %u.%u\n", __func__, host_major, host_minor); + + if (guest_major != host_major) { + GGML_LOG_ERROR(GGML_VIRTGPU "Host major (%d) and guest major (%d) version differ\n", host_major, guest_major); + } else if (guest_minor != host_minor) { + GGML_LOG_WARN(GGML_VIRTGPU "Host minor (%d) and guest minor (%d) version differ\n", host_minor, guest_minor); + } + + return 0; +} + +static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirLoadLibraryReturnCode ret; + + encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_LOADLIBRARY, 0); + if (!encoder) { + GGML_ABORT(GGML_VIRTGPU "%s: hypercall error: failed to prepare the API Remoting command encoder", __func__); + return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR; + } + + long long call_duration_ns; + + ret = (ApirLoadLibraryReturnCode) remote_call(gpu, encoder, &decoder, APIR_LOADLIBRARY_MAX_WAIT_MS, + &call_duration_ns); + log_call_duration(call_duration_ns, "API Remoting LoadLibrary"); + + if (!decoder) { + GGML_ABORT(GGML_VIRTGPU "%s: hypercall error: failed to trigger the API Remoting hypercall.\n", __func__); + return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR; + } + + remote_call_finish(gpu, encoder, decoder); + + if (ret == APIR_LOAD_LIBRARY_SUCCESS) { + GGML_LOG_INFO(GGML_VIRTGPU "The API Remoting backend was successfully loaded and initialized\n"); + + return ret; + } + + // something wrong happened, find out what. + if (ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { + if (ret == APIR_LOAD_LIBRARY_ENV_VAR_MISSING) { + GGML_ABORT(GGML_VIRTGPU + "%s: virglrenderer could not open the API Remoting backend library, " + "some environment variables are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(ret)); + } else if (ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) { + GGML_ABORT(GGML_VIRTGPU + "%s: virglrenderer could not open the API Remoting backend library. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(ret)); + } else if (ret == APIR_LOAD_LIBRARY_ENV_VAR_MISSING) { + GGML_ABORT(GGML_VIRTGPU + "%s: could not load the backend library, some symbols are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s) ", + __func__, apir_load_library_error(ret)); + } else { + GGML_ABORT(GGML_VIRTGPU + "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", __func__, + apir_load_library_error(ret), ret); + } + return ret; + } + + GGML_LOG_INFO(GGML_VIRTGPU + "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__); + + ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX); + + if (apir_ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) { + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(apir_ret)); + } else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) { + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(apir_ret)); + } else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library: apir code=%d | %s)", + __func__, apir_ret, apir_load_library_error(apir_ret)); + } else { + uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX; + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library initialize its backend library: apir code=%d)", __func__, + lib_ret); + } + return ret; +} + +virtgpu * create_virtgpu() { + virtgpu * gpu = new virtgpu(); + + gpu->use_apir_capset = getenv("GGML_REMOTING_USE_APIR_CAPSET") != nullptr; + util_sparse_array_init(&gpu->shmem_array, sizeof(virtgpu_shmem), 1024); + + // Initialize mutex to protect shared data_shmem buffer + if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) { + delete gpu; + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initialize data_shmem mutex", __func__); + return NULL; + } + + if (virtgpu_open(gpu) != APIR_SUCCESS) { + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: failed to open the virtgpu device\n", __func__); + return NULL; + } + + if (virtgpu_init_capset(gpu) != APIR_SUCCESS) { + if (gpu->use_apir_capset) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library supports it.", __func__); + } else { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initialize the virtgpu Venus capset", __func__); + } + return NULL; + } + + if (virtgpu_init_context(gpu) != APIR_SUCCESS) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initialize the GPU context", __func__); + return NULL; + } + + if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to create the shared reply memory pages", __func__); + return NULL; + } + + if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to create the shared data memory pages", __func__); + return NULL; + } + + if (virtgpu_handshake(gpu)) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to handshake with the virglrenderer library", __func__); + return NULL; + } + + if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to load the backend library", __func__); + return NULL; + } + + return gpu; +} + +static virt_gpu_result_t virtgpu_open(virtgpu * gpu) { + drmDevicePtr devs[8]; + int count = drmGetDevices2(0, devs, ARRAY_SIZE(devs)); + if (count < 0) { + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: failed to enumerate DRM devices\n", __func__); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + virt_gpu_result_t result = APIR_ERROR_INITIALIZATION_FAILED; + for (int i = 0; i < count; i++) { + result = virtgpu_open_device(gpu, devs[i]); + if (result == APIR_SUCCESS) { + break; + } + } + + drmFreeDevices(devs, count); + + return result; +} + +static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev) { + const char * node_path = dev->nodes[DRM_NODE_RENDER]; + + int fd = open(node_path, O_RDWR | O_CLOEXEC); + if (fd < 0) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to open %s", __func__, node_path); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + drmVersionPtr version = drmGetVersion(fd); + if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) { + if (version) { + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: unknown DRM driver %s version %d\n", __func__, version->name, version->version_major); + } else { + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: failed to get DRM driver version\n", __func__); + } + + if (version) { + drmFreeVersion(version); + } + close(fd); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + gpu->fd = fd; + + drmFreeVersion(version); + + GGML_LOG_INFO(GGML_VIRTGPU "using DRM device %s\n", node_path); + + return APIR_SUCCESS; +} + +static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu) { + assert(!gpu->capset.version); + const int ret = virtgpu_ioctl_context_init(gpu, gpu->capset.id); + if (ret) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to initialize context: %s\n", __func__, strerror(errno)); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + return APIR_SUCCESS; +} + +static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) { + if (gpu->use_apir_capset) { + GGML_LOG_INFO(GGML_VIRTGPU "Using the APIR capset\n"); + gpu->capset.id = VIRTGPU_DRM_CAPSET_APIR; + } else { + GGML_LOG_INFO(GGML_VIRTGPU "Using the Venus capset\n"); + gpu->capset.id = VIRTGPU_DRM_CAPSET_VENUS; + } + gpu->capset.version = 0; + + int ret = + virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data)); + + if (ret) { + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: failed to get APIR v%d capset: %s\n", + __func__, gpu->capset.version, strerror(errno)); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + assert(gpu->capset.data.supports_blob_resources); + + return APIR_SUCCESS; +} + +static int virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id) { + drm_virtgpu_context_set_param ctx_set_params[3] = { + { + .param = VIRTGPU_CONTEXT_PARAM_CAPSET_ID, + .value = capset_id, + }, + { + .param = VIRTGPU_CONTEXT_PARAM_NUM_RINGS, + .value = 1, + }, + { + .param = VIRTGPU_CONTEXT_PARAM_POLL_RINGS_MASK, + .value = 0, /* don't generate drm_events on fence signaling */ + }, + }; + + drm_virtgpu_context_init args = { + .num_params = ARRAY_SIZE(ctx_set_params), + .pad = 0, + .ctx_set_params = (uintptr_t) &ctx_set_params, + }; + + return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_CONTEXT_INIT, &args); +} + +static int virtgpu_ioctl_get_caps(virtgpu * gpu, + virgl_renderer_capset id, + uint32_t version, + void * capset, + size_t capset_size) { + drm_virtgpu_get_caps args = { + .cap_set_id = id, + .cap_set_ver = version, + .addr = (uintptr_t) capset, + .size = (__u32) capset_size, + .pad = 0, + }; + + return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GET_CAPS, &args); +} + +static uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param) { + /* val must be zeroed because kernel only writes the lower 32 bits */ + uint64_t val = 0; + drm_virtgpu_getparam args = { + .param = param, + .value = (uintptr_t) &val, + }; + + const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GETPARAM, &args); + return ret ? 0 : val; +} + +apir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags) { + /* + * Prepare the command encoder and its buffer + */ + + thread_local char encoder_buffer[4096]; + + thread_local apir_encoder enc; + enc = { + .cur = encoder_buffer, + .start = encoder_buffer, + .end = encoder_buffer + sizeof(encoder_buffer), + .fatal = false, + }; + + /* + * Fill the command encoder with the common args: + * - cmd_type (int32_t) + * - cmd_flags (int32_t) + * - reply res id (uint32_t) + */ + + int32_t cmd_type = apir_cmd_type; + + // for testing during the hypervisor transition + if (!gpu->use_apir_capset) { + cmd_type += VENUS_COMMAND_TYPE_LENGTH; + } + apir_encode_int32_t(&enc, &cmd_type); + apir_encode_int32_t(&enc, &cmd_flags); + + uint32_t reply_res_id = gpu->reply_shmem.res_id; + apir_encode_uint32_t(&enc, &reply_res_id); + + return &enc; +} + +void remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec) { + UNUSED(gpu); + + if (!enc) { + GGML_ABORT(GGML_VIRTGPU "%s: Invalid (null) encoder", __func__); + } + + if (!dec) { + GGML_ABORT(GGML_VIRTGPU "%s: Invalid (null) decoder", __func__); + } + + if (apir_encoder_get_fatal(enc)) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Failed to encode the output parameters.", __func__); + } + + if (apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Failed to decode the input parameters.", __func__); + } +} + +uint32_t remote_call(virtgpu * gpu, + apir_encoder * encoder, + apir_decoder ** decoder, + float max_wait_ms, + long long * call_duration_ns) { + /* + * Prepare the reply notification pointer + */ + + volatile std::atomic_uint * atomic_reply_notif = (volatile std::atomic_uint *) gpu->reply_shmem.mmap_ptr; + *atomic_reply_notif = 0; + + /* + * Trigger the execbuf ioctl + */ + + drm_virtgpu_execbuffer args = { + .flags = VIRTGPU_EXECBUF_RING_IDX, + .size = (uint32_t) (encoder->cur - encoder->start), + .command = (uintptr_t) encoder->start, + + .bo_handles = 0, + .num_bo_handles = 0, + + .fence_fd = 0, + .ring_idx = 0, + .syncobj_stride = 0, + .num_in_syncobjs = 0, + .num_out_syncobjs = 0, + .in_syncobjs = 0, + .out_syncobjs = 0, + }; + + *decoder = NULL; + + int ret = drmIoctl(gpu->fd, DRM_IOCTL_VIRTGPU_EXECBUFFER, &args); + + if (ret != 0) { + GGML_ABORT(GGML_VIRTGPU "%s: the virtgpu EXECBUFFER ioctl failed (%d)", __func__, ret); + } + + /* + * Wait for the response notification + */ + timer_data wait_host_reply_timer = { 0, 0, 0 }; + + start_timer(&wait_host_reply_timer); + + timespec ts_start, ts_end; + clock_gettime(CLOCK_MONOTONIC, &ts_start); + long long start_time = (long long) ts_start.tv_sec * 1000000000LL + ts_start.tv_nsec; + + bool timedout = false; + uint32_t notif_value = 0; + while (true) { + notif_value = std::atomic_load_explicit(atomic_reply_notif, std::memory_order_acquire); + + if (notif_value != 0) { + break; + } + + int64_t base_sleep_us = 15; + + os_time_sleep(base_sleep_us); + + if (max_wait_ms) { + clock_gettime(CLOCK_MONOTONIC, &ts_end); + long long end_time = (long long) ts_end.tv_sec * 1000000000LL + ts_end.tv_nsec; + float duration_ms = (end_time - start_time) / 1000000; + + if (duration_ms > max_wait_ms) { + timedout = true; + break; + } + } + } + + if (call_duration_ns) { + *call_duration_ns = stop_timer(&wait_host_reply_timer); + } + + if (max_wait_ms && timedout) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: timed out waiting for the host answer...\n", __func__); + return APIR_FORWARD_TIMEOUT; + } + + /* + * Prepare the decoder + */ + static apir_decoder response_dec; + response_dec.cur = (char *) gpu->reply_shmem.mmap_ptr + sizeof(*atomic_reply_notif); + response_dec.end = (char *) gpu->reply_shmem.mmap_ptr + gpu->reply_shmem.mmap_size; + *decoder = &response_dec; + + // extract the actual return value from the notif flag + uint32_t returned_value = notif_value - 1; + return returned_value; +} + +static void log_call_duration(long long call_duration_ns, const char * name) { + double call_duration_ms = (double) call_duration_ns / 1e6; // 1 millisecond = 1e6 nanoseconds + double call_duration_s = (double) call_duration_ns / 1e9; // 1 second = 1e9 nanoseconds + + if (call_duration_s > 1) { + GGML_LOG_INFO(GGML_VIRTGPU + "waited %.2fs for the %s host reply...\n", call_duration_s, name); + } else if (call_duration_ms > 1) { + GGML_LOG_INFO(GGML_VIRTGPU + "waited %.2fms for the %s host reply...\n", call_duration_ms, name); + } else { + GGML_LOG_INFO(GGML_VIRTGPU + "waited %lldns for the %s host reply...\n", call_duration_ns, name); + } +} diff --git a/ggml/src/ggml-virtgpu/virtgpu.h b/ggml/src/ggml-virtgpu/virtgpu.h new file mode 100644 index 00000000000..68e0f3a376e --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu.h @@ -0,0 +1,115 @@ +#pragma once + +#include "virtgpu-utils.h" +#include "virtgpu-shm.h" +#include "virtgpu-apir.h" + +#include "backend/shared/api_remoting.h" +#include "backend/shared/apir_cs.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "ggml-remoting.h" + +#define VIRGL_RENDERER_UNSTABLE_APIS 1 +#include "apir_hw.h" +#include +#include "venus_hw.h" + +#ifndef VIRTGPU_DRM_CAPSET_APIR +// Will be defined include/drm/virtgpu_drm.h when +// https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs +// is merged +#define VIRTGPU_DRM_CAPSET_APIR 10 +#endif + +// Mesa/Virlgrenderer Venus internal. Only necessary during the +// Venus->APIR transition in Virglrenderer +#define VENUS_COMMAND_TYPE_LENGTH 331 + +#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16 +#define VIRTGPU_DRM_CAPSET_VENUS 4 +#endif + +typedef uint32_t virgl_renderer_capset; + +/* from src/virtio/vulkan/vn_renderer_virtgpu.c */ +#define VIRTGPU_PCI_VENDOR_ID 0x1af4 +#define VIRTGPU_PCI_DEVICE_ID 0x1050 +#define VIRTGPU_BLOB_MEM_GUEST_VRAM 0x0004 +#define VIRTGPU_PARAM_GUEST_VRAM 9 + +#define SHMEM_DATA_SIZE 0x1830000 // 24MiB +#define SHMEM_REPLY_SIZE 0x4000 + +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) + +enum virt_gpu_result_t { + APIR_SUCCESS = 0, + APIR_ERROR_INITIALIZATION_FAILED = -1, +}; + +#define PRINTFLIKE(f, a) __attribute__((format(__printf__, f, a))) + +struct virtgpu { + bool use_apir_capset; + + int fd; + + struct { + virgl_renderer_capset id; + uint32_t version; + virgl_renderer_capset_apir data; + } capset; + + util_sparse_array shmem_array; + + /* APIR communication pages */ + virtgpu_shmem reply_shmem; + virtgpu_shmem data_shmem; + + /* Mutex to protect shared data_shmem buffer from concurrent access */ + mtx_t data_shmem_mutex; + + /* Cached device information to prevent memory leaks and race conditions */ + struct { + char * description; + char * name; + int32_t device_count; + uint32_t type; + size_t memory_free; + size_t memory_total; + } cached_device_info; + + /* Cached buffer type information to prevent memory leaks and race conditions */ + struct { + apir_buffer_type_host_handle_t host_handle; + char * name; + size_t alignment; + size_t max_size; + } cached_buffer_type; +}; + +static inline int virtgpu_ioctl(virtgpu * gpu, unsigned long request, void * args) { + return drmIoctl(gpu->fd, request, args); +} + +virtgpu * create_virtgpu(); + +apir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags); + +uint32_t remote_call(virtgpu * gpu, + apir_encoder * enc, + apir_decoder ** dec, + float max_wait_ms, + long long * call_duration_ns); + +void remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 08fd044ca03..72097ffd0ff 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -254,6 +254,7 @@ enum vk_device_architecture { AMD_RDNA3, INTEL_XE2, NVIDIA_PRE_TURING, + NVIDIA_TURING, }; static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { @@ -336,18 +337,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& const std::vector ext_props = device.enumerateDeviceExtensionProperties(); bool cooperative_matrix = false; + bool sm_builtins = false; // Detect "pre-turing" based on lack of coopmat support. for (const auto& properties : ext_props) { if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { cooperative_matrix = true; - break; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; } } if (!cooperative_matrix) { return vk_device_architecture::NVIDIA_PRE_TURING; } + + if (sm_builtins) { + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + + props2.pNext = &sm_props; + + device.getProperties2(&props2); + + // Turing has 32, following architectures have 48 + if (sm_props.shaderWarpsPerSM == 32) { + return vk_device_architecture::NVIDIA_TURING; + } + } } return vk_device_architecture::OTHER; } @@ -385,18 +402,19 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc) - : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {} + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) + : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} uint32_t HSK, HSV; bool small_rows, small_cache; FaCodePath path; bool aligned; bool f32acc; + uint32_t flags; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) < - std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc); + return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) < + std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); } }; @@ -803,6 +821,8 @@ struct vk_device_struct { std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; + std::map, vk_pipeline> pipeline_fa_mask_opt; + vk_pipeline pipeline_flash_attn_split_k_reduce; vk_pipeline pipeline_count_experts; @@ -991,6 +1011,8 @@ struct vk_mat_vec_id_push_constants { uint32_t fusion_flags; uint32_t nei0; uint32_t ne11; + uint32_t expert_i1; + uint32_t nbi1; }; struct vk_flash_attn_push_constants { @@ -1244,25 +1266,30 @@ struct vk_op_diag_mask_push_constants { struct vk_op_rope_push_constants { uint32_t rope_mode; - uint32_t ncols; uint32_t nrows; uint32_t n_dims; float freq_scale; - uint32_t p_delta_rows; float freq_base; float ext_factor; float attn_factor; float corr_dims[2]; float theta_scale; uint32_t has_ff; - uint32_t ne02; - uint32_t s1; - uint32_t s2; int32_t sections[4]; uint32_t is_imrope; uint32_t is_back; uint32_t set_rows_stride; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; }; +static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128"); // For fused rms_norm+mul+rope(+view+set_rows) struct vk_op_rms_norm_mul_rope_push_constants { @@ -1516,6 +1543,27 @@ struct vk_quantize_q8_1_push_constants { uint32_t num_blocks; }; +struct vk_op_flash_attn_split_k_reduce_push_constants { + uint32_t D; + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + uint32_t k_num; + uint32_t sinks; +}; + +struct vk_op_flash_attn_mask_opt_push_constants { + uint32_t nem0; + uint32_t nem1; + uint32_t nem2; + uint32_t nbm1; + uint32_t nbm2; + uint32_t nbm3; + uint32_t nbd1; + uint32_t nbd2; + uint32_t nbd3; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -1724,6 +1772,7 @@ class vk_perf_logger { " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " << " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " << " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")"; + *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3]; return name.str(); } if (node->op == GGML_OP_TOP_K) { @@ -1802,7 +1851,6 @@ struct ggml_backend_vk_context { bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; vk_context_ref compute_ctx; - vk_context_ref transfer_ctx; std::vector tensor_ctxs; @@ -1812,7 +1860,6 @@ struct ggml_backend_vk_context { uint32_t pipeline_descriptor_set_requirements {}; vk_command_pool compute_cmd_pool; - vk_command_pool transfer_cmd_pool; // number of additional consecutive nodes that are being fused with the // node currently being processed @@ -3146,24 +3193,39 @@ static void ggml_vk_load_shaders(vk_device& device) { return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. // For scalar, use 128 (arbitrary) // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. const uint32_t D = (hsk|hsv); - uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) - ? scalar_flash_attention_workgroup_size - : ((small_rows && (D % 32) == 0) ? 256 : 128); auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); + uint32_t wg_size; + switch (path) { + case FA_COOPMAT2: + wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128); + break; + case FA_COOPMAT1: + wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc + break; + default: + wg_size = scalar_flash_attention_workgroup_size; + break; + } + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. const uint32_t D_lsb = D ^ (D & (D-1)); uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; + // Nvidia prefers shared memory use to load large tiles of K. + // Switch to loading from global memory when it would use too much shared memory. + // AMD prefers loading K directly from global memory + const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; + + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags}; }; #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ @@ -3175,18 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) { FaCodePath path = fa.first.path; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ + uint32_t flags = fa.first.flags; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -3980,7 +4043,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + + for (auto &it : device->pipeline_fa_mask_opt) { + auto BrBc = it.first; + ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size); + } if (device->subgroup_clustered && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); @@ -5513,22 +5581,30 @@ static void ggml_vk_instance_init() { if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) { // Check if there are two physical devices corresponding to the same GPU + // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux), + // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication. + // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards, + // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip when both old/new + // driver is MoltenVK auto old_device = std::find_if( vk_instance.device_indices.begin(), vk_instance.device_indices.end(), - [&devices, &new_id](const size_t k){ + [&devices, &new_id, &new_driver](const size_t k){ vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceDriverProperties old_driver; vk::PhysicalDeviceIDProperties old_id; - old_props.pNext = &old_id; + old_props.pNext = &old_driver; + old_driver.pNext = &old_id; devices[k].getProperties2(&old_props); - bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); - equals = equals || ( + bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); + same_uuid = same_uuid || ( old_id.deviceLUIDValid && new_id.deviceLUIDValid && std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID)) ); + bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk); - return equals; + return same_uuid && !both_molten_vk; } ); if (old_device == vk_instance.device_indices.end()) { @@ -5647,7 +5723,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->almost_ready_fence = ctx->device->device.createFence({}); ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); - ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); if (vk_perf_logger_enabled) { ctx->perf_logger = std::unique_ptr(new vk_perf_logger()); @@ -8083,8 +8158,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - - GGML_ASSERT(nei1 == 1); + const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int)); const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; @@ -8168,7 +8242,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -8226,7 +8300,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte uint32_t stride_batch_y = ne10*ne11; if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { - stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + stride_batch_y = src1->nb[2] / ggml_type_size(src1->type); } const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; @@ -8262,23 +8336,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1; } - // compute - const vk_mat_vec_id_push_constants pc = { - (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), - fusion_flags, - (uint32_t)nei0, (uint32_t)ne11, - }; - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { - d_X, - d_Y, - d_D, - d_F0, - d_F1, - d_ids, - }, - pc, { groups_x, (uint32_t)nei0, groups_z }); + // Loop over the batch dimension + for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) { + const vk_mat_vec_id_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), + fusion_flags, + (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1 + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { + d_X, + d_Y, + d_D, + d_F0, + d_F1, + d_ids, + }, + pc, { groups_x, (uint32_t)nei0, groups_z }); + } if (x_non_contig) { ctx->prealloc_x_need_sync = true; @@ -8292,7 +8368,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; ggml_tensor * src2 = dst->src[2]; - return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); + return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { @@ -8325,41 +8401,47 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); - const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = coopmat1_flash_attention_num_large_rows; - const uint32_t Bc = scalar_flash_attention_Bc; + const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false); + const uint32_t Br = rows_cols[0]; + const uint32_t Bc = rows_cols[1]; + + const uint32_t MatBr = 16, MatBc = 16; + + const uint32_t row_split = Bc / MatBc; const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; - const uint32_t tmpsh = wg_size * sizeof(float); - const uint32_t tmpshv4 = wg_size * 4 * acctype; - const uint32_t qstride = hsk_pad / 4 + 2; const uint32_t Qf = Br * qstride * f16vec4; + const uint32_t psh_stride = Br / 4 + 2; + const uint32_t Psh = Bc * psh_stride * f16vec4; + const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const uint32_t kshstride = hsk_pad / 4 + 2; - const uint32_t ksh = Bc * kshstride * f16vec4; + const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256; + const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2; + const uint32_t vsh_stride = MatBc / 4 * row_split; + const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4; - const uint32_t slope = Br * sizeof(float); + const uint32_t slope = Br * acctype; - const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const uint32_t total_size = Qf + Psh + sfsh + ksh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); return supported; } @@ -8383,6 +8465,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const uint32_t nem0 = mask ? mask->ne[0] : 0; const uint32_t nem1 = mask ? mask->ne[1] : 0; const uint32_t nem2 = mask ? mask->ne[2] : 0; const uint32_t nem3 = mask ? mask->ne[3] : 0; @@ -8419,11 +8502,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) { + // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 + path = FA_SCALAR; + } + if (path == FA_COOPMAT1) { const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); - const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type); if (!coopmat_shape_supported || !coopmat_shmem_supported) { path = FA_SCALAR; @@ -8454,14 +8542,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(0); } - if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. gqa_ratio = qk_ratio; N = gqa_ratio; - workgroups_y /= N; + workgroups_y /= gqa_ratio; } bool small_rows = N <= get_fa_num_small_rows(path); @@ -8507,7 +8595,26 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc); + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. + bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; + + uint32_t flags = (use_mask_opt ? 1 : 0) | + (mask != nullptr ? 2 : 0) | + (logit_softcap != 0 ? 4 : 0); + + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags); vk_pipeline pipeline = nullptr; @@ -8523,6 +8630,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } assert(pipeline); + // Compile early to initialize wg_denoms. + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); uint32_t split_kv = KV; uint32_t split_k = 1; @@ -8530,22 +8639,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; - // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0) { + // Try to use split_k when KV is large enough to be worth the overhead. + // Must either be a single batch or be using gqa, we can't mix the two. + if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) { // Try to run two workgroups per SM. - split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); + split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); split_k = CEIL_DIV(KV, split_kv); - workgroups_x = split_k; } } // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. - const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; + // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3]. + // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3]. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0; if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { GGML_ABORT("Requested preallocation size is too large"); } @@ -8554,24 +8665,33 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_preallocate_buffers(ctx, subctx); } - { - // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); - if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); - } - } + auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache); + const uint32_t Br = rows_cols[0]; + const uint32_t Bc = rows_cols[1]; - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; + const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc); + const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3; - memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + vk_pipeline pipeline_fa_mask_opt = nullptr; + if (use_mask_opt) { + std::lock_guard guard(ctx->device->mutex); + auto &pipelines = ctx->device->pipeline_fa_mask_opt; + auto it = pipelines.find({Br, Bc}); + if (it != pipelines.end()) { + pipeline_fa_mask_opt = it->second; + } else { + pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared(); + } + assert(pipeline_fa_mask_opt); + ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1); - if (logit_softcap != 0) { - scale /= logit_softcap; + if (ctx->prealloc_size_y < mask_opt_size) { + ctx->prealloc_size_y = mask_opt_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } } const uint32_t n_head_kv = neq2; @@ -8585,8 +8705,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf; vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf; + vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf; - uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; + uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2; + + if (use_mask_opt) + { + const vk_op_flash_attn_mask_opt_push_constants opt_pc = { + nem0, + nem1, + nem2, + (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)), + (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)), + (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)), + mask_opt_num_dwords, + mask_opt_num_dwords * CEIL_DIV(nem1, Br), + mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2, + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt, + { mask_buf, mask_opt_buf }, opt_pc, + { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 }); + ggml_vk_sync_buffers(ctx, subctx); + } const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, @@ -8602,28 +8743,34 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx gqa_ratio, split_kv, split_k }; if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - + workgroups_x *= pipeline->wg_denoms[0]; vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf}, // We only use split_k when group query attention is enabled, which means // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to // cancel out the divide by wg_denoms[0]. - pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + pc, { split_k * workgroups_x, workgroups_y, workgroups_z }); ggml_vk_sync_buffers(ctx, subctx); - const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; + const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, {split_k_buf, sinks_buf, dst_buf}, - pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) }); ctx->prealloc_split_k_need_sync = true; } else { + if (gqa_ratio > 1) { + // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms + workgroups_x *= pipeline->wg_denoms[0]; + } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf}, pc, { workgroups_x, workgroups_y, workgroups_z }); } } @@ -10335,12 +10482,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type); uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); + uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type); + + uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type); + uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type); + uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type); vk_op_rope_push_constants rope { - (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], - freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, - has_ff, (uint32_t)src0->ne[2], nb01, nb02, + (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff, { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, + + (uint32_t)src0->ne[0], + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + nb01, nb02, nb03, + nb11, nb12, nb13, }; return rope; @@ -11560,7 +11717,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t free(d_chk); ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); ggml_vk_destroy_buffer(d_X); ggml_vk_destroy_buffer(d_Y); @@ -11909,7 +12065,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, } } if (mmq) { - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); + vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it); } ggml_pipeline_allocate_descriptor_sets(ctx); @@ -12145,7 +12302,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_submit(subctx, {}); ctx->submit_pending = true; ggml_vk_synchronize(ctx); + GGML_ASSERT(ctx->compute_ctx.expired()); ggml_vk_ctx_begin(ctx->device, subctx); + ctx->compute_ctx = subctx; } if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { @@ -12163,6 +12322,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_destroy_buffer(ctx->prealloc_y); } ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + ctx->prealloc_y_last_tensor_used = nullptr; } if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); @@ -12743,7 +12903,6 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -12772,7 +12931,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); // discard any unsubmitted command buffers - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); // wait for any pending command buffers to finish ggml_vk_synchronize(ctx); @@ -12805,7 +12964,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->descriptor_sets.clear(); ctx->compute_cmd_pool.destroy(ctx->device->device); - ctx->transfer_cmd_pool.destroy(ctx->device->device); if (vk_perf_logger_enabled) { ctx->perf_logger->print_timings(true); } @@ -13077,34 +13235,34 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer buf = buf_ctx->dev_buffer; auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size); + bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size); if (!ret) { ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, transfer_ctx); + ggml_vk_sync_buffers(nullptr, compute_ctx); vk::BufferCopy buffer_cpy; buffer_cpy.srcOffset = 0; buffer_cpy.dstOffset = dst_offset; buffer_cpy.size = size; - transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); - deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys); + compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); + deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys); ggml_vk_synchronize(ctx); } } @@ -13116,34 +13274,34 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer buf = buf_ctx->dev_buffer; auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size); + bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size); // If that failed, copy synchronously through a staging buffer if (!ret) { ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, transfer_ctx); + ggml_vk_sync_buffers(nullptr, compute_ctx); vk::BufferCopy buffer_cpy; buffer_cpy.srcOffset = src_offset; buffer_cpy.dstOffset = 0; buffer_cpy.size = size; - transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); - deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys); + compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); + deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys); ggml_vk_synchronize(ctx); } } @@ -13155,21 +13313,21 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer src_buf = src_buf_ctx->dev_buffer; vk_buffer dst_buf = dst_buf_ctx->dev_buffer; - ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); return true; } @@ -13179,19 +13337,19 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_synchronize()"); - bool do_transfer = !ctx->transfer_ctx.expired(); + bool do_transfer = !ctx->compute_ctx.expired(); - vk_context transfer_ctx; + vk_context compute_ctx; if (do_transfer) { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); - ggml_vk_ctx_end(transfer_ctx); + ggml_vk_ctx_end(compute_ctx); - for (auto& cpy : transfer_ctx->in_memcpys) { + for (auto& cpy : compute_ctx->in_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(transfer_ctx, {}); + ggml_vk_submit(compute_ctx, {}); ctx->submit_pending = true; } @@ -13205,10 +13363,10 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { } if (do_transfer) { - for (auto& cpy : transfer_ctx->out_memcpys) { + for (auto& cpy : compute_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); } } @@ -13877,6 +14035,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_submit(compute_ctx, ctx->device->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); ctx->device->device.resetFences({ ctx->device->fence }); + ctx->compute_ctx.reset(); // Get the results and pass them to the logger std::vector timestamps(cgraph->n_nodes + 1); @@ -14163,15 +14322,15 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } // the backend interface doesn't have an explicit reset, so reset it here @@ -14179,13 +14338,13 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ctx->device->device.resetEvent(vkev->event); ctx->device->device.resetFences({ vkev->fence }); - ggml_vk_set_event(transfer_ctx, vkev->event); + ggml_vk_set_event(compute_ctx, vkev->event); - ggml_vk_ctx_end(transfer_ctx); + ggml_vk_ctx_end(compute_ctx); - ggml_vk_submit(transfer_ctx, {vkev->fence}); + ggml_vk_submit(compute_ctx, {vkev->fence}); ctx->submit_pending = true; - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); } static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { @@ -14193,20 +14352,20 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } - ggml_vk_wait_events(transfer_ctx, {vkev->event}); - ggml_vk_ctx_end(transfer_ctx); - ctx->transfer_ctx.reset(); + ggml_vk_wait_events(compute_ctx, {vkev->event}); + ggml_vk_ctx_end(compute_ctx); + ctx->compute_ctx.reset(); } // TODO: enable async and synchronize @@ -14726,6 +14885,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_REPEAT_BACK: return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: + return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ROPE_BACK: case GGML_OP_NONE: case GGML_OP_RESHAPE: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0379e5d5024..914f131c965 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -53,7 +53,7 @@ void main() { const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -94,6 +94,10 @@ void main() { } } + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + #if BLOCK_SIZE > 1 uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; @@ -101,18 +105,31 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - float max_mask = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; @@ -120,25 +137,12 @@ void main() { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); masksh[c][r] = m; - max_mask = max(max_mask, m); } else { masksh[c][r] = float(0); } } } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); - } - if (max_mask <= NEG_FLT_MAX_OVER_2) { - continue; - } } float Sf[Br][cols_per_thread]; @@ -177,7 +181,7 @@ void main() { } } - if (p.logit_softcap != 0.0f) { + if (LOGIT_SOFTCAP) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); @@ -185,7 +189,7 @@ void main() { } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { float mvf = masksh[c * cols_per_iter + col_tid][r]; @@ -256,9 +260,6 @@ void main() { barrier(); } - // prevent race on tmpsh - barrier(); - // reduce across threads [[unroll]] for (uint32_t r = 0; r < Br; ++r) { @@ -320,7 +321,8 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { @@ -332,7 +334,7 @@ void main() { } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -378,7 +380,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index eb93903c468..4142c1e6eaa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -8,6 +8,13 @@ layout (constant_id = 3) const uint32_t HSK = 32; layout (constant_id = 4) const uint32_t HSV = 32; layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; +layout (constant_id = 7) const uint32_t SubGroupSize = 32; +layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; +layout (constant_id = 9) const uint32_t Flags = 0; + +const bool USE_MASK_OPT = (Flags & 1) != 0; +const bool MASK_ENABLE = (Flags & 2) != 0; +const bool LOGIT_SOFTCAP = (Flags & 4) != 0; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -57,13 +64,17 @@ layout (push_constant) uniform parameter { } p; #define SINK_ENABLE_BIT (1<<24) -#define MASK_ENABLE_BIT (1<<16) #define N_LOG2_MASK 0xFFFF layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; +layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; + +#define MASK_OPT_ALL_NEG_INF 1 +#define MASK_OPT_ALL_ZERO 2 + #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 #if defined(DATA_A_F32) @@ -74,6 +85,10 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif +#ifndef BLOCK_SIZE +#define BLOCK_SIZE 1 +#endif + #if defined(DATA_A_F32) #undef BLOCK_SIZE #define BLOCK_SIZE 4 @@ -165,7 +180,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC } uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, - iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, q_stride, k_stride, v_stride, m_stride; void init_indices() @@ -173,12 +188,19 @@ void init_indices() N = p.N; KV = p.KV; - i = gl_WorkGroupID.x; - split_k_index = 0; - if (p.k_num > 1) { i = 0; - split_k_index = gl_WorkGroupID.x; + // batch and split_k share gl_WorkGroupID.x + gqa_iq1 = gl_WorkGroupID.x / p.k_num; + split_k_index = gl_WorkGroupID.x % p.k_num; + } else if (p.gqa_ratio > 1) { + i = 0; + gqa_iq1 = gl_WorkGroupID.x; + split_k_index = 0; + } else { + i = gl_WorkGroupID.x; + gqa_iq1 = 0; + split_k_index = 0; } Tr = CEIL_DIV(N, Br); @@ -218,3 +240,7 @@ void init_indices() // and breaking the alignment detection. m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; } + +// Bias applied to softmax to stay in fp16 range. +// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606 +const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index c995ab140ee..b3177738234 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -7,6 +7,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_shader_subgroup_vote : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable @@ -14,12 +15,13 @@ #include "types.glsl" #include "flash_attn_base.glsl" -const uint32_t HSK_per_thread = HSK / D_split; -const uint32_t HSV_per_thread = HSV / D_split; +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; -const uint32_t row_split = 4; +const uint32_t row_split = Bc / MatBc; const uint32_t rows_per_thread = Br / row_split; -const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; @@ -40,24 +42,22 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd -const uint32_t MatBr = 16; -const uint32_t MatBc = 16; - -shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; -shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; - const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; +const uint psh_stride = Br / 4 + 2; +shared f16vec4 Psh[Bc * psh_stride]; + // Avoid padding for hsk==256 to make it fit in 48KB shmem. -const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; -shared ACC_TYPE sfsh[Bc * sfshstride]; +const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; +shared ACC_TYPEV4 sfsh[Bc * sfshstride]; -const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4 -shared f16vec4 ksh[Bc * kshstride]; +const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups +const uint vsh_stride = v_cols; +shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)]; -shared float slope[Br]; +shared ACC_TYPE slope[Br]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -69,9 +69,9 @@ void main() { const uint32_t tid = gl_LocalInvocationIndex; const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup; const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; - const uint32_t d_tid = gl_LocalInvocationIndex % D_split; - const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup; #define tile_row(r) (row_tid * rows_per_thread + (r)) @@ -90,7 +90,7 @@ void main() { barrier(); } - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -102,9 +102,9 @@ void main() { } barrier(); - ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + ACC_TYPEV4 Of[rows_per_thread][d_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { Of[r][d] = ACC_TYPEV4(0.0); } } @@ -125,15 +125,17 @@ void main() { uint r = tid; slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); } - barrier(); } else { if (tid < Br) { uint r = tid; - slope[r] = 1.0; + slope[r] = ACC_TYPE(1.0); } - barrier(); } + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + #if BLOCK_SIZE > 1 uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; @@ -141,65 +143,101 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - float mask_cache[Bc * Br / WorkGroupSize]; - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); - mask_cache[idx / WorkGroupSize] = m; - max_mask = max(max_mask, m); - } - } - } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } - barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); + f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; + [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) { + mask_cache[idx] = f16vec4(0); + } + + if (MASK_ENABLE) { + + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; } - if (max_mask <= NEG_FLT_MAX_OVER_2) { + uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block continue; } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + float max_mask = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if ((!KV_bounds_check || j * Bc + c < KV)) { + f16vec4 m; + if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]); + max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3])); + } else if (i * Br + r * 4 + 2 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + 0.0); + max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])); + } else if (i * Br + r * 4 + 1 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + 0.0, + 0.0); + max_mask = max(max(max_mask, float(m[0])), float(m[1])); + } else if (i * Br + r * 4 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + 0.0, + 0.0, + 0.0); + max_mask = max(max_mask, float(m[0])); + } else { + m = f16vec4(0.0); + } + mask_cache[idx / WorkGroupSize] = m; + } + } + } + } } - [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (HSK / 4); - uint32_t c = (idx + tid) / (HSK / 4); - if (c < Bc && d < HSK / 4) { - f16vec4 K_Tf = f16vec4(0); - if (!KV_bounds_check || j * Bc + c < KV) { + if (K_LOAD_SHMEM != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { + f16vec4 K_Tf = f16vec4(0); + if (!KV_bounds_check || j * Bc + c < KV) { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); #else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif - } + } - ksh[c * kshstride + d] = K_Tf; + ksh[c * kshstride + d] = K_Tf; + } } + barrier(); } - barrier(); // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 @@ -208,11 +246,55 @@ void main() { coopmat KMat; coopmat QMat; - for (uint32_t d = 0; d < HSK_pad / 16; ++d) { - coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { + if (K_LOAD_SHMEM == 0) { +#if BLOCK_SIZE == 1 + if (KV_bounds_check || d * 16 + 16 > HSK) { +#endif + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { + uint32_t col_vec = (idx + tid) % (MatBr / 4); + uint32_t row = (idx + tid) / (MatBr / 4); + if (idx + tid < Bc * MatBr / 4) { + f16vec4 K_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); +#endif + } + + ksh[row * kshstride + col_vec] = K_Tf; + } + } + barrier(); +#if BLOCK_SIZE == 1 + } +#endif + +#if BLOCK_SIZE == 1 + if (KV_bounds_check || d * 16 + 16 > HSK) +#endif + { + uint coord = (gl_SubgroupID * MatBc) * kshstride; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + } +#if BLOCK_SIZE == 1 + else { + const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4; + coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } +#endif + } else { + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + } - uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); SfMat = coopMatMulAdd(KMat, QMat, SfMat); } @@ -221,27 +303,27 @@ void main() { coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); barrier(); - if (p.logit_softcap != 0.0f) { - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) / Br; - uint32_t r = (idx + tid) % Br; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + if (LOGIT_SOFTCAP) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); } } barrier(); } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float f = mask_cache[idx / WorkGroupSize]; - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f); + if (MASK_ENABLE) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if (!KV_bounds_check || j * Bc + c < KV) { + // Mask nem1 bounds check is handled when loading masks + ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]); + ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]); + sfsh[c * sfshstride + r] += slopes * masks; } } } @@ -250,139 +332,176 @@ void main() { float eMf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint r_vec = tile_row(r) / 4; + const uint r_comp = tile_row(r) % 4; + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp])); } float Moldf = Mf[r]; + // Compute max across the row + rowmaxf = subgroupMax(rowmaxf); + // M = max(rowmax, Mold) // P = e^(S - M) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); eMf[r] = exp(Moldf - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; + Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local]; } } - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Lf[r] = eMf[r]*Lf[r]; - } + // Calculate and store Pf in Psh [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - float Pf[rows_per_thread]; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); - Lf[r] += Pf[r]; + const uint col = c * cols_per_iter + col_tid; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) { + const uint row = tile_row(r); + if (KV_bounds_check && j * Bc + col >= KV) { + Psh[col * psh_stride + row / 4] = f16vec4(0.0f); + } else { + const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]); + const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); + [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) { + Lf[r + vec_idx] += Pf[vec_idx]; + } + Psh[col * psh_stride + row / 4] = Pf; + } } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + } + + const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up + + // Each subgroup handles HSV/4 columns + [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { + const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; + + SfMat = coopmat(0); + + // Preload V tiles for [Bc, 16 * num subgroups] + const uint v_rows = Bc; + const uint v_total = v_rows * v_cols; + const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; + +#if BLOCK_SIZE == 1 + // For f16, only preload if not aligned + if (KV_bounds_check) { +#endif + [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { + const uint idx = i * gl_WorkGroupSize.x + tid; + const uint row = idx / v_cols; + const uint col = idx % v_cols; + + const uint v_row = j * Bc + row; + const uint v_col = hsv_tile * MatBc * row_split + col * 4; + + const uint coord = v_row * v_stride * BLOCK_SIZE + v_col; + const uint ib = coord / BLOCK_SIZE; + const uint iqs = coord % BLOCK_SIZE; + + if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); #else - vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); + ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; #endif - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf); + } else { + ksh[row * vsh_stride + col] = f16vec4(0.0f); } } - } +#if BLOCK_SIZE == 1 + } +#endif - barrier(); - } + barrier(); - // prevent race on tmpsh - barrier(); + [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { + coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); + +#if BLOCK_SIZE == 1 + if (!KV_bounds_check) { + // F16 values can be loaded directly from global memory + const uint v_tile_row = j * Bc + bc_chunk * MatBc; + const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; + coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } else +#endif + { + const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); + coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } - // reduce across threads + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + // Store SfMat to sfsh and load into Of + const uint osh_stride = row_split * MatBc / 4; + const uint o_offset = gl_SubgroupID * MatBc / 4; + coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); - float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE M = Mf[r]; - tmpsh[tid] = M; - // Compute max across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - M = max(M, tmpsh[tid ^ s]); - barrier(); - tmpsh[tid] = M; barrier(); - } - rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; - barrier(); - } - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Moldf[r] = Mf[r]; + const uint hsv_per_tile = row_split * MatBc; + const uint hsv_base = hsv_tile * hsv_per_tile; + const uint d_values_per_tile = hsv_per_tile / 4; - // M = max(rowmax, Mold) - // eM = e^(Mold - M) - Mf[r] = max(rowmaxf[r], Moldf[r]); - eMf[r] = exp(Moldf[r] - Mf[r]); + const uint d_start = hsv_tile * d_values_per_tile; + const uint d_end = min(d_start + d_values_per_tile, HSV / 4); - Lf[r] = eMf[r]*Lf[r]; - } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE L = Lf[r]; - tmpsh[tid] = L; - // Compute sum across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - L += tmpsh[tid ^ s]; - barrier(); - tmpsh[tid] = L; - barrier(); + [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) { + const uint d = d_local * threads_per_rowgroup + col_tid; + const uint hsv_col = 4 * d; + + if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) { + const uint local_hsv = (hsv_col - hsv_base) / 4; + Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]); + } + } + } } - Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - - Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; - tmpshv4[tid] = Of[r][d]; - - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - Of[r][d] += tmpshv4[tid ^ s]; - barrier(); - tmpshv4[tid] = Of[r][d]; - barrier(); - } - Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; - barrier(); - } + Lf[r] = subgroupAdd(Lf[r]); } // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); } } } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -403,8 +522,9 @@ void main() { if (sink > Mf[r]) { ms = exp(Mf[r] - sink); - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] *= ACC_TYPE(ms); + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; + Of[r][d_local] *= ACC_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -419,23 +539,27 @@ void main() { Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] *= ACC_TYPE(Lfrcp[r]); + Of[r][d_local] *= ACC_TYPE(Lfrcp[r]); #if defined(ACC_TYPE_MAX) - Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); + Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX); #endif } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV / 4) break; + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); } } } @@ -443,9 +567,12 @@ void main() { } else { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (i * Br + tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV / 4) break; + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 9a71996383d..39f0c4d23b9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -55,7 +55,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele return max(elem0, elem1); } -#if defined(BLOCK_SIZE) +#if BLOCK_SIZE > 1 #define DECODEFUNC , DEQUANTFUNC #else #define DECODEFUNC @@ -85,7 +85,7 @@ void main() { tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -#if defined(BLOCK_SIZE) +#if BLOCK_SIZE > 1 tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); #endif @@ -98,7 +98,7 @@ void main() { if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { q_stride &= ~7; -#if !defined(BLOCK_SIZE) +#if BLOCK_SIZE == 1 k_stride &= ~7; v_stride &= ~7; #endif @@ -111,13 +111,13 @@ void main() { coopmat Q; coopmat Qf16; - uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); Qf16 = coopmat(Q); Qf16 *= float16_t(p.scale); - coopmat O = coopmat(0); + coopmat O = coopmat(0); coopmat L, M; @@ -138,48 +138,53 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } - uint32_t m_offset = 0; + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + + uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - coopmat mv; - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - if (nem1_bounds_check) { - tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); - tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); - tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t - - coopmat mvmax; - - coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - - // skip the block if the mask is entirely -inf - coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); - if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { - continue; - } - } else { - tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); - // Don't clamp against nem1 when GQA is enabled - uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; - tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); - tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); - - coopmat mvmax; + coopmat mv = coopmat(0); + if (MASK_ENABLE) { - coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - - // skip the block if the mask is entirely -inf - coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); - if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { - continue; + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + if (nem1_bounds_check) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + } else { + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); + // Don't clamp against nem1 when GQA is enabled + uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); } } } @@ -192,14 +197,14 @@ void main() { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); S = coopMatMulAdd(Qf16, K_T, S); - if (p.logit_softcap != 0.0f) { + if (LOGIT_SOFTCAP) { [[unroll]] for (int k = 0; k < S.length(); ++k) { S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { S += slopeMat*coopmat(mv); } @@ -218,6 +223,8 @@ void main() { coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); + rowmax += coopmat(FATTN_KQ_MAX_OFFSET); + coopmat Mold = M; // M = max(rowmax, Mold) @@ -260,11 +267,8 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - // multiply with fp16 accumulation, then add to O. - coopmat PV = coopmat(0); - PV = coopMatMulAdd(P_A, V, PV); - - O = eMdiag * O + coopmat(PV); + O *= coopmat(eMdiag); + O = coopMatMulAdd(P_A, V, O); } // If there is split_k, then the split_k resolve shader does the final @@ -272,10 +276,11 @@ void main() { if (p.k_num > 1) { coopmat O_D = coopmat(O); - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); return; @@ -305,7 +310,7 @@ void main() { if (sink > Mr[i]) { ms = exp(Mr[i] - sink); - O[i] *= ms; + O[i] *= float16_t(ms); } else { vs = exp(sink - Mr[i]); } @@ -319,15 +324,16 @@ void main() { Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]); } - O = Ldiag*O; + coopmat O_D = coopmat(O); + + O_D = coopmat(Ldiag)*O_D; #if defined(ACC_TYPE_MAX) - [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } + [[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); } #endif - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; - coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); } else { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp new file mode 100644 index 00000000000..8c92c1adcda --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp @@ -0,0 +1,142 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint NUM_SUBGROUPS = 4; +layout (constant_id = 2) const uint Br = 32; +layout (constant_id = 3) const uint Bc = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float16_t data_a[];}; +layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];}; +layout (binding = 1) writeonly buffer D {uint data_d[];}; + +layout (push_constant) uniform parameter { + uint nem0; + uint nem1; + uint nem2; + uint nbm1; + uint nbm2; + uint nbm3; + uint nbd1; + uint nbd2; + uint nbd3; +}; + +#define MASK_OPT_ALL_NEG_INF 1 +#define MASK_OPT_ALL_ZERO 2 + +shared float minsh[NUM_SUBGROUPS]; +shared float maxsh[NUM_SUBGROUPS]; + +// For each Br x Bc block of the mask (input) buffer, read all values and check +// if it's all -inf or all zero. Write out a two-bit code indicating which it is +// (or zero for neither). Each workgroup processes 16 tiles and writes out a +// 32-bit result mask. +// +// TODO: This is a lot of work per workgroup, might make sense to split this into +// more workgroups in the future. +void main() { + // Each workgroup handles a row + const uint tid = gl_LocalInvocationIndex; + const uint i0 = gl_WorkGroupID.x; + const uint i1 = gl_WorkGroupID.y; + const uint i2 = gl_WorkGroupID.z % nem2; + const uint i3 = gl_WorkGroupID.z / nem2; + + float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); + + uint result = 0; + + // Fast path for fully in-bounds blocks where we can do f16vec4 loads + if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 && + ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) { + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) { + uint j0 = (i + tid) % (Bc / 4); + uint j1 = (i + tid) / (Bc / 4); + + j0 *= 4; + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); + [[unroll]] for (int c = 0; c < 4; ++c) { + min_v = min(min_v, f[c]); + max_v = max(max_v, f[c]); + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } + } else { + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) { + if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) { + continue; + } + uint j0 = (i + tid) % Bc; + uint j1 = (i + tid) / Bc; + + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + if (j0 < nem0 && j1 < nem1) { + float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]); + min_v = min(min_v, f); + max_v = max(max_v, f); + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } + } + + if (tid == 0) { + data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 4eaddd31a8f..68917fc0bb0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; - uint N; + uint ne1; + uint ne2; uint ne3; uint k_num; uint sinks; @@ -24,15 +25,15 @@ void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const uint iq3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.z % p.ne2; + const uint i3 = gl_WorkGroupID.z / p.ne2; uint D = p.D; - uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; - uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; - uint lm_stride = N * 2; + uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n; + uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n; + uint lm_stride = p.ne1 * 2; // Compute the max m value for the row float m_max = -1.0/0.0; @@ -99,7 +100,7 @@ void main() { if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; + uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } @@ -115,6 +116,6 @@ void main() { const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); O = clamp(O, -FLT_MAX, FLT_MAX); - data_d[iq3 * D * N + D * n + d] = O; + data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index dfb78659362..4f2c7003065 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -29,6 +29,8 @@ layout (push_constant) uniform parameter #ifdef MUL_MAT_ID uint nei0; uint ne11; + uint expert_i1; + uint nbi1; #else uint ne02; uint ne12; @@ -43,7 +45,7 @@ uint expert_id; void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.y; + const uint expert_i0 = gl_GlobalInvocationID.y; #else const uint batch_idx = gl_GlobalInvocationID.y; #endif @@ -60,7 +62,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { batch_idx_a = i03 * p.ne02 + i02; } #else - expert_id = data_ids[expert_idx]; + expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1]; #endif a_offset = @@ -71,13 +73,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #endif b_offset = #ifdef MUL_MAT_ID - (expert_idx % p.ne11) * p.stride_b; + (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b; #else batch_idx * p.batch_stride_b; #endif d_offset = #ifdef MUL_MAT_ID - expert_idx * p.stride_d; + expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d; #else batch_idx * p.batch_stride_d; #endif @@ -103,12 +105,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -158,12 +160,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -203,12 +205,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index 9d6d3665427..55b89f19a7a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -112,12 +112,11 @@ void rms_norm(uint num_iters) { #if RMS_NORM_ROPE_FUSION barrier(); rope_params rp = p.rope; - uint rope_row = (samp*nchannels + channel)*nrows + row; for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) { if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) { - rope_neox(t, rope_row, rp); + rope_neox(t, row, channel, samp, rp); } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) { - rope_norm(t, rope_row, rp); + rope_norm(t, row, channel, samp, rp); } } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index aacec984696..2e53459909d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) { return 1.0f - min(1.0f, max(0.0f, y)); } -uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) { +uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) { #if RMS_NORM_ROPE_FUSION // Per-row offset in shared memory const uint ix = i0; #else - const uint ix = i02*p.nb02 + i01*p.nb01 + i0; + const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; #endif return ix; } @@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out sin_theta = sin(theta) * mscale; } -void rope_norm(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { +void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0; - const uint ix = rope_a_coord(i0, i01, i02, p); + uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0; + idst += rope_data_i[i2].x * p.set_rows_stride; } if (i0 >= p.n_dims) { @@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) { return; } - const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; @@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) { rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); } -void rope_neox(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { +void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0/2; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0/2; + idst += rope_data_i[i2].x * p.set_rows_stride; } if (i0 >= p.n_dims) { @@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) { return; } - const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; @@ -120,26 +107,19 @@ void rope_neox(const uint i0, const uint i1, rope_params p) { } -void rope_multi(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { +void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0/2; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0/2; + idst += rope_data_i[i2].x * p.set_rows_stride; } if (i0 >= p.n_dims) { @@ -156,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { float theta_base = 0.0; if (p.is_imrope != 0) { if (sector % 3 == 1 && sector < 3 * p.sections[1]) { - theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f); } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f); } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { - theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f); } else { - theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f); } } else { if (sector < p.sections[0]) { - theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f); } else if (sector >= p.sections[0] && sector < sec_w) { - theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f); } else if (sector >= sec_w && sector < sec_w + p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f); } else if (sector >= sec_w + p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f); } } @@ -191,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); } -void rope_vision(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { +void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - const uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); const int sect_dims = p.sections[0] + p.sections[1]; const int sec_w = p.sections[1] + p.sections[0]; @@ -213,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) { float theta_base = 0.0; if (sector < p.sections[0]) { const uint p0 = sector; - theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0); } else if (sector >= p.sections[0] && sector < sec_w) { const uint p0 = sector - p.sections[0]; - theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0); + theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0); } const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index f7587468a81..1528fbeeaec 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_multi(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_multi(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index acb8ed78155..ad0896095db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_neox(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_neox(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 0033cdb224f..11220817df0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_norm(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_norm(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 939cf3c51cd..ec6ceaca9bd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -5,24 +5,29 @@ struct rope_params { uint rope_mode; - uint ncols; uint nrows; uint n_dims; float freq_scale; - uint p_delta_rows; float freq_base; float ext_factor; float attn_factor; float corr_dims[2]; float theta_scale; uint has_ff; - uint ne02; - uint nb01; - uint nb02; int sections[4]; uint is_imrope; uint is_back; uint set_rows_stride; + + uint ne00; + uint ne01; + uint ne02; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; }; #endif // !defined(GGML_ROPE_PARAMS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index d93800b5e76..ca71efb2f55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_vision(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_vision(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bbdbf9dcaaa..42ebc21e2a6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -330,7 +330,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path}; #endif - // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 + // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734 // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860 if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) { @@ -790,6 +790,8 @@ void process_shaders() { string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("fa_mask_opt", "flash_attn_mask_opt.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}}); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 84d88e81d45..6997f6bdd31 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -465,4 +465,73 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader( return result; } +/** Binary **/ + +struct ggml_webgpu_binary_pipeline_key { + int type; + int op; + bool inplace; + bool overlap; + + bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + } +}; + +struct ggml_webgpu_binary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + return seed; + } +}; + +struct ggml_webgpu_binary_shader_lib_context { + ggml_webgpu_binary_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_binary_shader_lib_context & context) { + std::vector defines; + std::string op_name = ggml_op_name((ggml_op) context.key.op); + std::string variant = op_name; + + defines.push_back(std::string("OP_") + op_name); + + switch (context.key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for binary shader"); + } + + if (context.key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (context.key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; + return result; +} #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 584cea7698b..f7ceca11212 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -47,7 +47,6 @@ double cpu_total_time_##id = \ std::chrono::duration(cpu_total_end_##id - cpu_total_start_##id).count(); \ (ctx)->cpu_time_ms[#id] += cpu_total_time_##id; - // fine-grained timing (not included in totals) # define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now(); @@ -74,13 +73,13 @@ #define WEBGPU_MAX_WG_SIZE 288 #define WEBGPU_MUL_MAT_WG_SIZE 256 -#define WEBGPU_NUM_PARAM_BUFS 32u +#define WEBGPU_NUM_PARAM_BUFS 16u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters -#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32 +#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 @@ -147,8 +146,13 @@ struct webgpu_submission_futures { struct webgpu_buf_pool { std::vector free; - std::mutex mutex; - + // The pool must be synchronized because + // 1. The memset pool is shared globally by every ggml buffer, + // since allocating a pool per ggml buffer would consume too much memory. + // 2. For the per-thread buffer pools in webgpu_context, + // buffers are allocated and freed in Dawn callbacks, + // which can run on a different thread than the calling thread. + std::mutex mutex; std::condition_variable cv; void init(wgpu::Device device, @@ -267,30 +271,67 @@ struct webgpu_command { #endif }; -// All the base objects needed to run operations on a WebGPU device -struct webgpu_context_struct { +struct webgpu_capabilities { + wgpu::Limits limits; + bool supports_subgroup_matrix = false; + + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + + uint32_t subgroup_size = 0; + uint32_t max_subgroup_size = 0; + size_t memset_bytes_per_thread; +}; + +// Stores global webgpu members +struct webgpu_global_context_struct { wgpu::Instance instance; wgpu::Adapter adapter; wgpu::Device device; wgpu::Queue queue; - wgpu::Limits limits; - uint32_t max_subgroup_size; + webgpu_capabilities capabilities; + // Shared buffer to move data from device to host + wgpu::Buffer get_tensor_staging_buf; + // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. + std::recursive_mutex mutex; - bool supports_subgroup_matrix = false; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; + webgpu_buf_pool memset_buf_pool; + std::map memset_pipelines; // variant or type index + std::atomic_uint inflight_threads = 0; - std::recursive_mutex mutex; - std::atomic_uint inflight_threads = 0; +#ifdef GGML_WEBGPU_CPU_PROFILE + // Profiling: labeled CPU time in ms (total) + std::unordered_map cpu_time_ms; + // Profiling: detailed CPU time in ms + std::unordered_map cpu_detail_ms; +#endif - webgpu_buf_pool param_buf_pool; - webgpu_buf_pool set_rows_error_buf_pool; +#ifdef GGML_WEBGPU_GPU_PROFILE + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + // Profiling: pool of timestamp query buffers (one per operation) + webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; +#endif + +#ifdef GGML_WEBGPU_DEBUG + wgpu::Buffer debug_host_buf; + wgpu::Buffer debug_dev_buf; +#endif +}; + +typedef std::shared_ptr webgpu_global_context; + +// All the base objects needed to run operations on a WebGPU device +struct webgpu_context_struct { + // Points to global instances owned by ggml_backend_webgpu_reg_context + webgpu_global_context global_ctx; pre_wgsl::Preprocessor p; - std::map memset_pipelines; // variant or type index + webgpu_buf_pool param_buf_pool; + webgpu_buf_pool set_rows_error_buf_pool; std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized std::map>> @@ -307,13 +348,12 @@ struct webgpu_context_struct { std::unordered_map set_rows_pipelines; - std::map> get_rows_pipelines; // src_type, vectorized + std::map> get_rows_pipelines; // src_type, vectorized + + std::map> cpy_pipelines; // src_type, dst_type - std::map> cpy_pipelines; // src_type, dst_type - std::map> add_pipelines; // type, inplace - std::map> sub_pipelines; // type, inplace - std::map> mul_pipelines; // type, inplace - std::map> div_pipelines; // type, inplace + std::unordered_map + binary_pipelines; std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace @@ -325,58 +365,41 @@ struct webgpu_context_struct { std::unordered_map pad_pipelines; size_t memset_bytes_per_thread; - - // Staging buffer for reading data from the GPU - wgpu::Buffer get_tensor_staging_buf; - -#ifdef GGML_WEBGPU_DEBUG - wgpu::Buffer debug_host_buf; - wgpu::Buffer debug_dev_buf; -#endif - -#ifdef GGML_WEBGPU_CPU_PROFILE - // Profiling: labeled CPU time in ms (total) - std::unordered_map cpu_time_ms; - // Profiling: detailed CPU time in ms - std::unordered_map cpu_detail_ms; -#endif - -#ifdef GGML_WEBGPU_GPU_PROFILE - // Profiling: per-shader GPU time in ms - std::unordered_map shader_gpu_time_ms; - // Profiling: pool of timestamp query buffers (one per operation) - webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; -#endif }; typedef std::shared_ptr webgpu_context; +// Metadata required for the ggml backend registration/discovery interface struct ggml_backend_webgpu_reg_context { - webgpu_context webgpu_ctx; - size_t device_count; - const char * name; + // Since the Instance is a global entrypoint into the WebGPU API, it lives here + webgpu_global_context webgpu_global_ctx; + size_t device_count; + const char * name; }; +// Per-device struct for the global logical device interface struct ggml_backend_webgpu_device_context { - webgpu_context webgpu_ctx; - std::string device_name; - std::string device_desc; + webgpu_global_context webgpu_global_ctx; + std::string device_name; + std::string device_desc; }; +// Per-thread data required to actually run WebGPU operations in a backend instance struct ggml_backend_webgpu_context { webgpu_context webgpu_ctx; std::string name; }; +// Per-thread data related to buffers struct ggml_backend_webgpu_buffer_context { - webgpu_context webgpu_ctx; - wgpu::Buffer buffer; - std::string label; + wgpu::Buffer buffer; + std::string label; + webgpu_global_context global_ctx; - ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) : - webgpu_ctx(std::move(ctx)), + ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) : buffer(std::move(buf)), - label(std::move(lbl)) {} + label(std::move(lbl)), + global_ctx(std::move(global_ctx_)) {} }; /* WebGPU object initializations */ @@ -444,7 +467,7 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ // Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_context & ctx, +static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & futures, bool block = true) { // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, @@ -476,11 +499,11 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct } } -static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, - wgpu::Buffer & buffer, - wgpu::MapMode mode, - size_t offset, - size_t size) { +static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, + wgpu::Buffer & buffer, + wgpu::MapMode mode, + size_t offset, + size_t size) { ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, [](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { @@ -495,7 +518,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, // This function adds debugging information to shaders, as WebGPU does not support printing directly. // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and // debug statements in the shader, and then call this function after encoding the commands and submitting them. -static void ggml_backend_webgpu_debug(webgpu_context & ctx) { +static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); @@ -507,7 +530,10 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { } #endif -static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector commands) { +static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx, + std::vector commands, + webgpu_buf_pool & param_buf_pool, + webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { std::vector command_buffers; std::vector params_bufs; std::vector set_rows_error_bufs; @@ -528,19 +554,19 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( wgpu::CallbackMode::AllowSpontaneous, - [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } // Free the staged buffers - ctx->param_buf_pool.free_bufs(params_bufs); + param_buf_pool.free_bufs(params_bufs); }); futures.push_back({ p_f }); for (const auto & bufs : set_rows_error_bufs) { wgpu::Future f = bufs.host_buf.MapAsync( wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { + [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); } else { @@ -549,7 +575,9 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); } // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->set_rows_error_buf_pool.free_bufs({ bufs }); + if (set_rows_error_buf_pool) { + set_rows_error_buf_pool->free_bufs({ bufs }); + } } }); futures.push_back({ f }); @@ -581,7 +609,8 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, } static webgpu_command ggml_backend_webgpu_build_multi( - webgpu_context & ctx, + webgpu_global_context & ctx, + webgpu_buf_pool & param_buf_pool, const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, @@ -595,7 +624,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( std::vector bind_groups; for (size_t i = 0; i < pipelines.size(); i++) { - webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); + webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); @@ -672,34 +701,37 @@ static webgpu_command ggml_backend_webgpu_build_multi( return result; } -static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx, +static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx, + webgpu_buf_pool & param_buf_pool, webgpu_pipeline & pipeline, std::vector params, std::vector bind_group_entries, uint32_t wg_x, uint32_t wg_y = 1, std::optional set_rows_error_bufs = std::nullopt) { - return ggml_backend_webgpu_build_multi(ctx, + return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, { pipeline }, { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs); } -static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, - wgpu::Buffer & buf, - uint32_t value, - size_t offset, - size_t size) { +static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, + wgpu::Buffer & buf, + uint32_t value, + size_t offset, + size_t size) { std::vector params = { (uint32_t) offset, (uint32_t) size, value }; std::vector entries = { { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread; + size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); - webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }) }; + webgpu_command command = + ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); + std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }, + ctx->memset_buf_pool) }; ggml_backend_webgpu_wait(ctx, futures); } @@ -720,19 +752,19 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_CPU_PROFILE std::cout << "\n[ggml_webgpu cpu profiling summary]\n"; double total_cpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) { total_cpu += kv.second; } std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n"; std::cout << "ggml_webgpu: cpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) { double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; } - if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) { + if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) { std::cout << "ggml_webgpu: cpu detailed breakdown:\n"; } - for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) { double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; } @@ -741,12 +773,12 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_GPU_PROFILE std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; double total_gpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { total_gpu += kv.second; } std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; std::cout << "\nggml_webgpu: gpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; } @@ -772,12 +804,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); - return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); } static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); - return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); } static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { @@ -790,6 +822,28 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); } +// Used to determine if two tensors share the same buffer and their byte ranges overlap, +static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && + ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); +} + +struct binary_overlap_flags { + bool inplace; // src0 == dst + bool overlap; // src1 == dst +}; + +static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = {}; + flags.inplace = ggml_webgpu_tensor_equal(src0, dst); + flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + + return flags; +} + static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -818,31 +872,28 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type], + params, entries, wg_x); } static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { const bool circular = ggml_get_op_params_i32(dst, 8) != 0; ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular }; - ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { .key = pipeline_key, - .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup }; + ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { + .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; webgpu_pipeline pipeline; - { - // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); - auto it = ctx->pad_pipelines.find(pipeline_key); - if (it != ctx->pad_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->pad_pipelines.emplace(pipeline_key, pipeline); - } + auto it = ctx->pad_pipelines.find(pipeline_key); + if (it != ctx->pad_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->pad_pipelines.emplace(pipeline_key, pipeline); } ggml_webgpu_generic_shader_decisions decisions = @@ -891,7 +942,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, @@ -907,24 +958,21 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, .vec4 = src->ne[0] % 4 == 0, .i64_idx = idx->type == GGML_TYPE_I64 }; - ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { .key = key, - .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup }; + ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { + .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; webgpu_pipeline pipeline; - // TODO: remove guard once pipeline caches are per-thread - { - std::lock_guard lock(ctx->mutex); - auto it = ctx->set_rows_pipelines.find(key); - if (it != ctx->set_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->set_rows_pipelines.emplace(key, pipeline); - } + auto it = ctx->set_rows_pipelines.find(key); + if (it != ctx->set_rows_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->set_rows_pipelines.emplace(key, pipeline); } ggml_webgpu_generic_shader_decisions decisions = @@ -981,7 +1029,8 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1, + error_bufs); } static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, @@ -1023,7 +1072,7 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized]; - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, @@ -1098,19 +1147,21 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; - wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension); + wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); } else { pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; uint32_t wg_n; #ifndef __EMSCRIPTEN__ - if (ctx->supports_subgroup_matrix) { + if (ctx->global_ctx->capabilities.supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m; + uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * + ctx->global_ctx->capabilities.sg_mat_m; wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); + uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * + ctx->global_ctx->capabilities.sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); } else { #endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; @@ -1124,9 +1175,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; } } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); } +#ifndef __EMSCRIPTEN__ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * Q, ggml_tensor * K, @@ -1210,8 +1262,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - bool kv_direct = - (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, @@ -1223,38 +1275,36 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .uses_logit_softcap = logit_softcap != 0.0f, }; - webgpu_pipeline pipeline; - // TODO: remove guard once pipeline caches are per-thread - { - std::lock_guard lock(ctx->mutex); - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .key = key, - .sg_mat_m = ctx->sg_mat_m, - .sg_mat_n = ctx->sg_mat_n, - .sg_mat_k = ctx->sg_mat_k, - .wg_mem_limit_bytes = - ctx->limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->max_subgroup_size }; - - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->flash_attn_pipelines.emplace(key, pipeline); - } + webgpu_pipeline pipeline; + auto it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .key = key, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size + }; + + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->flash_attn_pipelines.emplace(key, pipeline); } ggml_webgpu_flash_attn_shader_decisions decisions = *static_cast(pipeline.context); - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +#endif static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -1264,24 +1314,21 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s ggml_webgpu_unary_pipeline_key pipeline_key = { .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace }; - ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { .key = pipeline_key, - .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup }; + ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { + .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; webgpu_pipeline pipeline; - { - // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); - auto it = ctx->unary_pipelines.find(pipeline_key); - if (it != ctx->unary_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->unary_pipelines.emplace(pipeline_key, pipeline); - } + auto it = ctx->unary_pipelines.find(pipeline_key); + if (it != ctx->unary_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->unary_pipelines.emplace(pipeline_key, pipeline); } ggml_webgpu_generic_shader_decisions decisions = @@ -1346,17 +1393,45 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s } uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst, - webgpu_pipeline & pipeline, - bool inplace) { +static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); + + ggml_webgpu_binary_pipeline_key pipeline_key = { + .type = dst->type, + .op = dst->op, + .inplace = flags.inplace, + .overlap = flags.overlap, + }; + ggml_webgpu_binary_shader_lib_context shader_lib_ctx = { + .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; + + webgpu_pipeline pipeline; + auto it = ctx->binary_pipelines.find(pipeline_key); + if (it != ctx->binary_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->binary_pipelines.emplace(pipeline_key, pipeline); + } + + ggml_webgpu_generic_shader_decisions decisions = + *static_cast(pipeline.context); + + uint32_t ne = (uint32_t) ggml_nelements(dst); + std::vector params = { - (uint32_t) ggml_nelements(dst), + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1373,25 +1448,31 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, (uint32_t) src1->ne[3], }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - if (!inplace) { + std::vector entries; + + entries.push_back({ + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0), + }); + + entries.push_back({ + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1), + }); + + if (!flags.inplace && !flags.overlap) { entries.push_back({ .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1426,7 +1507,8 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params, + entries, ggml_nrows(src)); } static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, @@ -1513,7 +1595,7 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { @@ -1565,7 +1647,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1602,7 +1684,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params, + entries, wg_x); } static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, @@ -1674,7 +1757,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, + ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, ggml_nrows(dst)); } @@ -1696,25 +1780,22 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { .vec4 = src->ne[0] % 4 == 0, - .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; - { - // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); - auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); - if (it != ctx->argmax_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); - } + auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); + if (it != ctx->argmax_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); } uint32_t wg_x = ggml_nelements(dst); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1722,21 +1803,21 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr // ascending order is 0, descending order is 1 const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0); - ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = - ctx->limits.maxComputeWorkgroupStorageSize, - .order = order }; + ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .order = order + }; - std::lock_guard lock(ctx->mutex); - webgpu_pipeline argsort_pipeline; - auto it = ctx->argsort_pipelines.find(order); + webgpu_pipeline argsort_pipeline; + auto it = ctx->argsort_pipelines.find(order); if (it != ctx->argsort_pipelines.end()) { argsort_pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx); - argsort_pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + argsort_pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); argsort_pipeline.context = processed.decisions; ctx->argsort_pipelines.emplace(order, argsort_pipeline); } @@ -1751,7 +1832,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx); argsort_merge_pipeline = - ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); argsort_merge_pipeline.context = processed.decisions; ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline); } @@ -1780,9 +1861,10 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr const bool start_in_tmp = (merge_passes % 2) == 1; - const size_t dst_offset = ggml_webgpu_tensor_offset(dst); - const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t); - const size_t tmp_offset = ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->limits.minStorageBufferOffsetAlignment); + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t); + const size_t tmp_offset = + ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT); const size_t dst_binding_size = ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT); @@ -1813,10 +1895,10 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr }; const uint32_t total_wg_init = npr * nrows; - const uint32_t max_wg = ctx->limits.maxComputeWorkgroupsPerDimension; - const uint32_t wg_x_init = std::min(total_wg_init, max_wg); - const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); - std::vector init_entries = { + const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + const uint32_t wg_x_init = std::min(total_wg_init, max_wg); + const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); + std::vector init_entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), .offset = ggml_webgpu_tensor_align_offset(ctx, src), @@ -1830,7 +1912,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr workgroups_list.push_back({ wg_x_init, wg_y_init }); if (merge_passes == 0) { - return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + entries_list, workgroups_list); } bool in_is_tmp = start_in_tmp; @@ -1891,7 +1974,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr in_is_tmp = !in_is_tmp; } - return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, + workgroups_list); } static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1912,24 +1996,21 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { .vec4 = false, - .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; - // TODO: remove guard once pipeline caches are per-thread - { - std::lock_guard lock(ctx->mutex); - auto it = ctx->cumsum_pipelines.find(1); - if (it != ctx->cumsum_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->cumsum_pipelines.emplace(1, pipeline); - } + auto it = ctx->cumsum_pipelines.find(1); + if (it != ctx->cumsum_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->cumsum_pipelines.emplace(1, pipeline); } uint32_t wg_x = ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1956,25 +2037,22 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { .vec4 = false, - .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; - { - // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); - auto it = ctx->sum_rows_pipelines.find(1); - if (it != ctx->sum_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->sum_rows_pipelines.emplace(1, pipeline); - } + auto it = ctx->sum_rows_pipelines.find(1); + if (it != ctx->sum_rows_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->sum_rows_pipelines.emplace(1, pipeline); } uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } // Returns the encoded command, or std::nullopt if the operation is a no-op @@ -2009,27 +2087,16 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, src0, src1, node); case GGML_OP_FLASH_ATTN_EXT: +#ifndef __EMSCRIPTEN__ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); +#else + return std::nullopt; +#endif case GGML_OP_ADD: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace); - } case GGML_OP_SUB: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace); - } case GGML_OP_MUL: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace); - } case GGML_OP_DIV: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace); - } + return ggml_webgpu_binary_op(ctx, src0, src1, node); case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2070,12 +2137,12 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); - ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; webgpu_context ctx = backend_ctx->webgpu_ctx; WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - ctx->inflight_threads++; + ctx->global_ctx->inflight_threads++; std::vector commands; std::vector futures; @@ -2084,25 +2151,27 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str commands.push_back(*cmd); } // compute the batch size based on the number of inflight threads - uint32_t inflight_threads = ctx->inflight_threads; + uint32_t inflight_threads = ctx->global_ctx->inflight_threads; uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); if (commands.size() >= batch_size) { - futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); + futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, + &ctx->set_rows_error_buf_pool)); // Process events and check for completed submissions - ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx, futures, false); + ctx->global_ctx->instance.ProcessEvents(); + ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); commands.clear(); } } if (!commands.empty()) { - webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands); + webgpu_submission_futures new_futures = + ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); futures.push_back(new_futures); } - ggml_backend_webgpu_wait(ctx, futures); - ctx->inflight_threads--; - WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); + ggml_backend_webgpu_wait(ctx->global_ctx, futures); + ctx->global_ctx->inflight_threads--; + WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -2159,8 +2228,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe // This is a trick to set all bytes of a u32 to the same 1 byte value. uint32_t val32 = (uint32_t) value * 0x01010101; - ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); - WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx); + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size); + WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx); } static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -2169,15 +2238,14 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, size_t offset, size_t size) { WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; - webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); + buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); if (size % 4 != 0) { // If size is not a multiple of 4, we need to memset the remaining bytes @@ -2190,21 +2258,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; } // memset the remaining bytes - ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), - remaining_size); + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, + total_offset + (size - remaining_size), remaining_size); } else { // wait for WriteBuffer to complete - webgpu_ctx->instance.WaitAny( - webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } }), - UINT64_MAX); + UINT64_MAX); } - WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx); + WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx); } static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, @@ -2216,8 +2284,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; - wgpu::Device device = webgpu_ctx->device; + wgpu::Device device = buf_ctx->global_ctx->device; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; @@ -2227,42 +2294,45 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, final_size = size + (4 - (size % 4)); } - std::lock_guard lock(webgpu_ctx->mutex); + std::lock_guard lock(buf_ctx->global_ctx->mutex); - if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) { + if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr || + buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) { // Create a new staging buffer if it doesn't exist or is too small - if (webgpu_ctx->get_tensor_staging_buf) { - webgpu_ctx->get_tensor_staging_buf.Destroy(); + if (buf_ctx->global_ctx->get_tensor_staging_buf) { + buf_ctx->global_ctx->get_tensor_staging_buf.Destroy(); } - ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, + ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); } // Copy the data from the buffer to the staging buffer wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size); + encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0, + final_size); wgpu::CommandBuffer commands = encoder.Finish(); // Submit the command buffer to the queue - webgpu_ctx->queue.Submit(1, &commands); + buf_ctx->global_ctx->queue.Submit(1, &commands); // Map the staging buffer to read the data - ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size); + ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf, + wgpu::MapMode::Read, 0, final_size); // Must specify size here since the staging buffer might be larger than the tensor size - const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); + const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); // Copy the data from the mapped range to the output buffer std::memcpy(data, mapped_range, size); - webgpu_ctx->get_tensor_staging_buf.Unmap(); - WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx); + buf_ctx->global_ctx->get_tensor_staging_buf.Unmap(); + WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx); } static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); WEBGPU_CPU_PROFILE_TOTAL_START(clear); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); - WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx); + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size); + WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx); } static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { @@ -2292,28 +2362,30 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b int buffer_id = buffer_count++; std::string buf_name = "tensor_buf" + std::to_string(buffer_id); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes"); - ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); - wgpu::Buffer buf; - ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), + ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); + wgpu::Buffer buf; + ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, buf_name.c_str()); ggml_backend_webgpu_buffer_context * buf_ctx = - new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name); + new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx); return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size); } static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); - return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment; + ggml_backend_webgpu_device_context * dev_ctx = + static_cast(buft->device->context); + return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; } // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding. static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { - ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); - return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize; + ggml_backend_webgpu_device_context * dev_ctx = + static_cast(buft->device->context); + return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize; } static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, @@ -2322,7 +2394,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer size_t res = ggml_nbytes(tensor); switch (tensor->op) { case GGML_OP_ARGSORT: - res = ROUNDUP_POW2(res * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment, + res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, WEBGPU_STORAGE_BUF_BINDING_MULT); break; case GGML_OP_TOP_K: @@ -2330,8 +2402,9 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const ggml_tensor * src0 = tensor->src[0]; if (src0) { const size_t full = sizeof(int32_t) * ggml_nelements(src0); - res = ROUNDUP_POW2(full * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment, - WEBGPU_STORAGE_BUF_BINDING_MULT); + res = ROUNDUP_POW2( + full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); } } break; @@ -2359,7 +2432,7 @@ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); // TODO: for now, return maxBufferSize as both free and total memory // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates. - uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize; + uint64_t max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize; // If we're on a 32-bit system, clamp to UINTPTR_MAX #if UINTPTR_MAX < UINT64_MAX uint64_t max_ptr_size = static_cast(UINTPTR_MAX); @@ -2402,66 +2475,67 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si return constants; } -static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { +static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled - webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads); + ctx->capabilities.memset_bytes_per_thread = + CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); - constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; - constants[1].key = "bytes_per_thread"; - constants[1].value = webgpu_ctx->memset_bytes_per_thread; - webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants); + constants[0].key = "wg_size"; + constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[1].key = "bytes_per_thread"; + constants[1].value = ctx->capabilities.memset_bytes_per_thread; + ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { // Q4/Q5/Q8 classic quantizations webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); // K-quantizations webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); // IQ quantizations (2-, 3-, 4-bit variants) webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); // 1-bit and 4-bit IQ variants webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); std::string proc_mul_mat_f32_f32; std::string proc_mul_mat_f32_f32_vec; @@ -2474,18 +2548,18 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { std::vector mul_mat_constants; #ifndef __EMSCRIPTEN__ - if (webgpu_ctx->supports_subgroup_matrix) { + if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) { std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size); + sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = + std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size); sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k); - + sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m); + sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n); + sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k); proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); @@ -2522,21 +2596,21 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { #endif webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); std::vector mul_mat_vec_constants(3); mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; @@ -2547,171 +2621,119 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); -} - -static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); -} - -static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); -} - -static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); -} - -static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); webgpu_ctx->rms_norm_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants); - webgpu_ctx->rms_norm_pipelines[1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants); + webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); } static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); } static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { @@ -2719,68 +2741,68 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { // REGLU webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); // GEGLU webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); // SWIGLU webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); // SWIGLU_OAI webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); // GEGLU_ERF webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); // GEGLU_QUICK webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->scale_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants); - webgpu_ctx->scale_pipelines[1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants); + webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace, + "scale_f32_inplace", constants); } static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { @@ -2788,56 +2810,239 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { // f32 (no mask) webgpu_ctx->soft_max_pipelines[2][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); - webgpu_ctx->soft_max_pipelines[2][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[2][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); + webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); + webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); // f32 mask (mask_type = 0) - webgpu_ctx->soft_max_pipelines[0][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); + webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); + webgpu_ctx->soft_max_pipelines[0][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, + "soft_max_f32_mask_f32_sink_inplace", constants); // f16 mask (mask_type = 1) - webgpu_ctx->soft_max_pipelines[1][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); + webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); + webgpu_ctx->soft_max_pipelines[1][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, + "soft_max_f32_mask_f16_sink_inplace", constants); +} + +static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { + wgpu::RequestAdapterOptions options = {}; + +#ifndef __EMSCRIPTEN__ + // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 + const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; + wgpu::DawnTogglesDescriptor adapterTogglesDesc; + adapterTogglesDesc.enabledToggles = adapterEnabledToggles; + adapterTogglesDesc.enabledToggleCount = 2; + options.nextInChain = &adapterTogglesDesc; +#endif + + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + ctx->webgpu_global_ctx->adapter = std::move(adapter); + }), + UINT64_MAX); + GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr); + + ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits); + + wgpu::AdapterInfo info{}; +#ifndef __EMSCRIPTEN__ + wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; + if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + info.nextInChain = &subgroup_matrix_configs; + } +#endif + ctx->webgpu_global_ctx->adapter.GetInfo(&info); + wgpu::SupportedFeatures features; + ctx->webgpu_global_ctx->adapter.GetFeatures(&features); + // we require f16 support + GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + +#ifndef __EMSCRIPTEN__ + // Only support square f16 matrices of size 8 or 16 for now + bool valid_subgroup_matrix_config = false; + if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && + config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { + ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M; + ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N; + ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K; + valid_subgroup_matrix_config = true; + break; + } + } + } + ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; +#endif + + // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. + // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. + ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; + // Initialize device + std::vector required_features = { wgpu::FeatureName::ShaderF16 }; + +#ifndef __EMSCRIPTEN__ + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); + if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + required_features.push_back(wgpu::FeatureName::Subgroups); + required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + } +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + required_features.push_back(wgpu::FeatureName::TimestampQuery); +#endif + + wgpu::DeviceDescriptor dev_desc; + dev_desc.requiredLimits = &ctx->webgpu_global_ctx->capabilities.limits; + dev_desc.requiredFeatures = required_features.data(); + dev_desc.requiredFeatureCount = required_features.size(); + dev_desc.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_UNUSED(reason); + GGML_UNUSED(message); + //TODO: uncomment once proper free logic is in place + //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), + //std::string(message).c_str()); + }); + dev_desc.SetUncapturedErrorCallback( + [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); + }); + +#ifndef __EMSCRIPTEN__ + // Enable Dawn-specific toggles to increase native performance + // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, + // only for native performance? + const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = deviceEnabledToggles; + deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.disabledToggles = deviceDisabledToggles; + deviceTogglesDesc.disabledToggleCount = 1; + + dev_desc.nextInChain = &deviceTogglesDesc; +#endif + + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->adapter.RequestDevice( + &dev_desc, wgpu::CallbackMode::AllowSpontaneous, + [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { + if (status != wgpu::RequestDeviceStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); + return; + } + ctx->webgpu_global_ctx->device = std::move(device); + }), + UINT64_MAX); + GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr); + + ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx); + ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Initialize buffer pool for timestamp queries, used for profiling + ctx->webgpu_global_ctx->timestamp_query_buf_pool.init( + ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, + wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); +#endif + + GGML_LOG_INFO( + "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " + "device_desc: %s\n", + info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, + std::string(info.device).c_str(), std::string(info.description).c_str()); + return true; +} + +static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; + webgpu_context webgpu_ctx = std::make_shared(); + webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + + ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); + ggml_webgpu_init_get_rows_pipeline(webgpu_ctx); + ggml_webgpu_init_cpy_pipeline(webgpu_ctx); + ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); + ggml_webgpu_init_rope_pipeline(webgpu_ctx); + ggml_webgpu_init_glu_pipeline(webgpu_ctx); + ggml_webgpu_init_scale_pipeline(webgpu_ctx); + ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); +#ifdef GGML_WEBGPU_DEBUG + // Initialize debug buffers + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf"); +#endif + return webgpu_ctx; } -// TODO: move most initialization logic here -static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { +static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()"); - ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); - webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; + ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); - static ggml_backend_webgpu_context backend_ctx; - backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; - backend_ctx.webgpu_ctx = webgpu_ctx; + auto * backend_ctx = new ggml_backend_webgpu_context(); + backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; + backend_ctx->webgpu_ctx = initialize_webgpu_context(dev); // See GGML Backend Interface section - static ggml_backend backend = { + auto * backend = new ggml_backend(); + *backend = { /* .guid = */ ggml_backend_webgpu_guid(), /* .interface = */ ggml_backend_webgpu_i, /* .device = */ dev, - /* .context = */ &backend_ctx, + /* .context = */ backend_ctx, }; - return &backend; + return backend; } static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -2854,7 +3059,8 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm }, /* .device = */ dev, - /* .context = */ NULL, + /* .context = */ + NULL }; return &ggml_backend_webgpu_buffer_type; @@ -2895,16 +3101,16 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); - webgpu_context webgpu_ctx = ctx->webgpu_ctx; - ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; ggml_tensor * src2 = op->src[2]; // on smaller devices (or CI), tensors may be larger than the max storage buffer size - if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || - (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize || + (src0 != nullptr && + ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src1 != nullptr && + ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) { return false; } @@ -2984,17 +3190,19 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } case GGML_OP_FLASH_ATTN_EXT: { - if (!webgpu_ctx->supports_subgroup_matrix) { +#ifndef __EMSCRIPTEN__ + if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; + size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 && + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], - has_mask, kv_direct); + ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, + (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); if (min_bytes > limit_bytes) { break; } @@ -3003,6 +3211,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && src2->type == src1->type && op->type == GGML_TYPE_F32; +#endif break; } case GGML_OP_RMS_NORM: @@ -3099,10 +3308,13 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const default: break; } - if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || - (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize || + (src0 != nullptr && + ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src1 != nullptr && + ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src2 != nullptr && + ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) { supports_op = false; WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: "); } @@ -3127,7 +3339,7 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { /* .get_memory = */ ggml_backend_webgpu_device_get_memory, /* .get_type = */ ggml_backend_webgpu_device_get_type, /* .get_props = */ ggml_backend_webgpu_device_get_props, - /* .init_backend = */ ggml_backend_webgpu_device_init, + /* .init_backend = */ ggml_backend_webgpu_backend_init, /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, /* .buffer_from_host_ptr = */ NULL, @@ -3156,6 +3368,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { // TODO: Does this need to be thread safe? Is it only called once? // TODO: move most logic to device_init function so backend can be freed/initialized properly // Only one device is supported for now + static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); @@ -3164,189 +3377,12 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_backend_webgpu_reg_context * reg_ctx = static_cast(reg->context); - webgpu_context ctx = reg_ctx->webgpu_ctx; - - wgpu::RequestAdapterOptions options = {}; - -#ifndef __EMSCRIPTEN__ - // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 - const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; - wgpu::DawnTogglesDescriptor adapterTogglesDesc; - adapterTogglesDesc.enabledToggles = adapterEnabledToggles; - adapterTogglesDesc.enabledToggleCount = 2; - options.nextInChain = &adapterTogglesDesc; -#endif - - ctx->instance.WaitAny(ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - ctx->adapter = std::move(adapter); - }), - UINT64_MAX); - GGML_ASSERT(ctx->adapter != nullptr); - - ctx->adapter.GetLimits(&ctx->limits); - - wgpu::AdapterInfo info{}; -#ifndef __EMSCRIPTEN__ - wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; - if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { - info.nextInChain = &subgroup_matrix_configs; - } -#endif - ctx->adapter.GetInfo(&info); - - wgpu::SupportedFeatures features; - ctx->adapter.GetFeatures(&features); - // we require f16 support - GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); - -#ifndef __EMSCRIPTEN__ - // Only support square f16 matrices of size 8 or 16 for now - bool valid_subgroup_matrix_config = false; - if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { - for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { - const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; - if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && - config.componentType == wgpu::SubgroupMatrixComponentType::F16 && - config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { - ctx->sg_mat_m = config.M; - ctx->sg_mat_n = config.N; - ctx->sg_mat_k = config.K; - valid_subgroup_matrix_config = true; - break; - } - } - } - - ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; -#endif - // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. - // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->max_subgroup_size = info.subgroupMaxSize; - - // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16 }; - -#ifndef __EMSCRIPTEN__ - required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); - if (ctx->supports_subgroup_matrix) { - required_features.push_back(wgpu::FeatureName::Subgroups); - required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); - } -#endif - -#ifdef GGML_WEBGPU_GPU_PROFILE - required_features.push_back(wgpu::FeatureName::TimestampQuery); -#endif - - wgpu::DeviceDescriptor dev_desc; - dev_desc.requiredLimits = &ctx->limits; - dev_desc.requiredFeatures = required_features.data(); - dev_desc.requiredFeatureCount = required_features.size(); - dev_desc.SetDeviceLostCallback( - wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_UNUSED(reason); - GGML_UNUSED(message); - //TODO: uncomment once proper free logic is in place - //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), - //std::string(message).c_str()); - }); - dev_desc.SetUncapturedErrorCallback( - [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), - std::string(message).c_str()); - }); - -#ifndef __EMSCRIPTEN__ - // Enable Dawn-specific toggles to increase native performance - // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, - // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; - wgpu::DawnTogglesDescriptor deviceTogglesDesc; - deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; - deviceTogglesDesc.disabledToggles = deviceDisabledToggles; - deviceTogglesDesc.disabledToggleCount = 1; - - dev_desc.nextInChain = &deviceTogglesDesc; -#endif - - ctx->instance.WaitAny(ctx->adapter.RequestDevice( - &dev_desc, wgpu::CallbackMode::AllowSpontaneous, - [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { - if (status != wgpu::RequestDeviceStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", - std::string(message).c_str()); - return; - } - ctx->device = std::move(device); - }), - UINT64_MAX); - GGML_ASSERT(ctx->device != nullptr); - - // Initialize (compute) queue - ctx->queue = ctx->device.GetQueue(); - - // Create buffer pool for shader parameters - ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); - -#ifdef GGML_WEBGPU_GPU_PROFILE - // Initialize buffer pool for timestamp queries (profiling) - ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, - WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, - wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, - wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); -#endif - - ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - - ggml_webgpu_init_memset_pipeline(ctx); - ggml_webgpu_init_mul_mat_pipeline(ctx); - ggml_webgpu_init_get_rows_pipeline(ctx); - ggml_webgpu_init_cpy_pipeline(ctx); - ggml_webgpu_init_add_pipeline(ctx); - ggml_webgpu_init_sub_pipeline(ctx); - ggml_webgpu_init_mul_pipeline(ctx); - ggml_webgpu_init_div_pipeline(ctx); - ggml_webgpu_init_rms_norm_pipeline(ctx); - ggml_webgpu_init_rope_pipeline(ctx); - ggml_webgpu_init_glu_pipeline(ctx); - ggml_webgpu_init_scale_pipeline(ctx); - ggml_webgpu_init_soft_max_pipeline(ctx); - -#ifdef GGML_WEBGPU_DEBUG - // Initialize debug buffers - ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf"); - ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf"); -#endif + create_webgpu_device(reg_ctx); static ggml_backend_webgpu_device_context device_ctx; - device_ctx.webgpu_ctx = ctx; - device_ctx.device_name = GGML_WEBGPU_NAME; - device_ctx.device_desc = info.description; - - GGML_LOG_INFO( - "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " - "device_desc: %s\n", - info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, - std::string(info.device).c_str(), std::string(info.description).c_str()); - + device_ctx.device_name = GGML_WEBGPU_NAME; + device_ctx.device_desc = GGML_WEBGPU_NAME; + device_ctx.webgpu_global_ctx = reg_ctx->webgpu_global_ctx; // See GGML Backend Device Interface section static ggml_backend_device device = { /* .iface = */ ggml_backend_webgpu_device_i, @@ -3354,7 +3390,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t /* .context = */ &device_ctx, }; - WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx); + WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx); return &device; } @@ -3370,10 +3406,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); - webgpu_context webgpu_ctx = std::make_shared(); - static ggml_backend_webgpu_reg_context ctx; - ctx.webgpu_ctx = webgpu_ctx; ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; @@ -3390,15 +3423,17 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.nextInChain = &instanceTogglesDesc; #endif - webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); + ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); + ctx.webgpu_global_ctx->instance = std::move(inst); #ifdef __EMSCRIPTEN__ - if (webgpu_ctx->instance == nullptr) { + if (ctx.webgpu_global_ctx->instance == nullptr) { GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n"); return nullptr; } #endif - GGML_ASSERT(webgpu_ctx->instance != nullptr); + GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr); static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION, @@ -3411,7 +3446,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ggml_backend_t ggml_backend_webgpu_init(void) { ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); - return ggml_backend_webgpu_device_init(dev, nullptr); + return ggml_backend_webgpu_backend_init(dev, nullptr); } GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl deleted file mode 100644 index 1ce4d83fa8e..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +++ /dev/null @@ -1,188 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "add_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "+" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "add_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "+" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "add_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "+" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "add_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "+" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "mul_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "*" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "mul_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "*" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "mul_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "*" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "mul_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "*" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sub_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "-" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sub_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "-" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sub_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "-" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sub_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "-" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "div_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "/" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "div_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "/" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "div_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "/" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "div_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "/" - }, - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) - -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -#enddecl(NOT_INPLACE) - -#decl(INPLACE) - -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} - -@group(0) @binding(2) -var params: Params; - -#enddecl(INPLACE) - -#end(DECLS) - - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl new file mode 100644 index 00000000000..55dd66408a3 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -0,0 +1,107 @@ +enable f16; + +struct Params { + ne: u32, + + // offsets in elements + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + + b_ne0: u32, + b_ne1: u32, + b_ne2: u32, + b_ne3: u32, +}; + +fn src1_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + // handle repetition of b + // index loops back to the beginning and repeats after elements are exhausted = modulo + let b_i0 = a_i0 % params.b_ne0; + let b_i1 = a_i1 % params.b_ne1; + let b_i2 = a_i2 % params.b_ne2; + let b_i3 = a_i3 % params.b_ne3; + + // compute index for position in b's flat array + return b_i0 * params.stride_src1_0 + + b_i1 * params.stride_src1_1 + + b_i2 * params.stride_src1_2 + + b_i3 * params.stride_src1_3; +} + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1 : array; + +#ifdef INPLACE +@group(0) @binding(2) +var params: Params; + +#elif defined(OVERLAP) +@group(0) @binding(2) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#endif + +fn op(a: DataType, b: DataType) -> DataType { +#ifdef OP_ADD + return a + b; +#elif defined(OP_SUB) + return a - b; +#elif defined(OP_MUL) + return a * b; +#elif defined(OP_DIV) + return a / b; +#endif +} + +fn update(dst_i: u32, src0_i: u32, src1_i: u32){ + let result = op(src0[src0_i], src1[src1_i]); + +#ifdef INPLACE + src0[dst_i] = result; +#elif defined(OVERLAP) + src1[dst_i] = result; +#else + dst[dst_i] = result; +#endif +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl deleted file mode 100644 index 4b254f468d6..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +++ /dev/null @@ -1,45 +0,0 @@ -struct Params { - ne: u32, - - // offsets in elements - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - stride_src1_0: u32, - stride_src1_1: u32, - stride_src1_2: u32, - stride_src1_3: u32, - - a_ne0: u32, - a_ne1: u32, - a_ne2: u32, - - b_ne0: u32, - b_ne1: u32, - b_ne2: u32, - b_ne3: u32, -}; - -fn src1_index(_i: u32) -> u32 { - var i = _i; - let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); - i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); - let a_i2 = i / (params.a_ne1 * params.a_ne0); - i = i % (params.a_ne1 * params.a_ne0); - let a_i1 = i / params.a_ne0; - let a_i0 = i % params.a_ne0; - - // handle repetition of b - // index loops back to the beginning and repeats after elements are exhausted = modulo - let b_i0 = a_i0 % params.b_ne0; - let b_i1 = a_i1 % params.b_ne1; - let b_i2 = a_i2 % params.b_ne2; - let b_i3 = a_i3 % params.b_ne3; - - // compute index for position in b's flat array - return b_i0 * params.stride_src1_0 + - b_i1 * params.stride_src1_1 + - b_i2 * params.stride_src1_2 + - b_i3 * params.stride_src1_3; -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index de7c132a624..b6822161464 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -114,7 +114,7 @@ struct Params { #define PARAMS_BINDING 4 #endif -@group(0) @binding(DST_BINDING) var dst: array; +@group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; // Just a very small float value. @@ -160,14 +160,21 @@ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { return v; } +fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} + +fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32, - @builtin(subgroup_size) subgroup_size: u32, - @builtin(num_subgroups) num_subgroups: u32, - @builtin(subgroup_invocation_id) sg_inv_id: u32) { + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { // initialize row max for online softmax for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { @@ -231,9 +238,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { // clear inter_shmem to ensure zero-initialized accumulators - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - inter_shmem[elem_idx] = 0.0; - } + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = 0.0; + } // load k tile into shared memory #if defined(KV_Q4_0) @@ -309,48 +316,77 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // accumulate q block * k block into registers across the entire KV tile // TODO: this loop seems to be the current largest bottleneck - for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { - let inter_offset = kv_block * SG_MAT_N; - var acc: subgroup_matrix_result = subgroupMatrixLoad< - subgroup_matrix_result>(&inter_shmem, inter_offset, false, KV_TILE); + // this bracket exists to scope the lifetime of variables, reducing register pressure + { #ifdef KV_DIRECT - let k_block_row = kv_tile + kv_block * SG_MAT_N; - let k_global_offset = k_head_offset + k_block_row * params.stride_k1; + let k_block_row = kv_tile + subgroup_id * SG_MAT_N; + var k_global_offset = k_head_offset + k_block_row * params.stride_k1; #else - let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK; + var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK; #endif - for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) { - // load q submatrix from shared memory - var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( - &q_shmem, - head_dim_block, - false, - HEAD_DIM_QK - ); + for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { + let inter_offset = kv_block * SG_MAT_N; + var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE); + + var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK); - // load k submatrix from device or shared memory #ifdef KV_DIRECT - var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( - &K, - k_global_offset + head_dim_block, - true, - params.stride_k1 - ); + var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1); #else - var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( - &kv_shmem, - k_block_offset + head_dim_block, - true, - HEAD_DIM_QK - ); + var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK); #endif - acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); - } - // store acc to shared memory for softmax (S matrix from paper) - subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + var t: u32 = 1u; + for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) { + let h0 = t * SG_MAT_K; + var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK); +#ifdef KV_DIRECT + var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1); +#else + var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + q_cur = q0; + k_cur = k0; + + let h1 = (t + 1u) * SG_MAT_K; + var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK); +#ifdef KV_DIRECT + var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1); +#else + var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + q_cur = q1g; + k_cur = k1g; + } + + // handle odd tail + if (t < HEAD_DIM_QK / SG_MAT_K) { + let h = t * SG_MAT_K; + var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK); +#ifdef KV_DIRECT + var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1); +#else + var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + q_cur = qn; + k_cur = kn; + } + + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + +#ifdef KV_DIRECT + k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1; +#else + k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK; +#endif + subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + } } + #ifdef MASK // load mask tile into shared memory for this KV block // TODO: optimize and skip if mask is -INF for the entire tile @@ -495,7 +531,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, false, HEAD_DIM_V ); - for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { let p_offset = kv_block * SG_MAT_N; var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( @@ -527,11 +562,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // O += P * V o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat); } - // store O back to shared memory subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V); } - workgroupBarrier(); } @@ -566,26 +599,38 @@ fn main(@builtin(workgroup_id) wg_id: vec3, o_shmem[idx] = f16(val); } } - workgroupBarrier(); #endif - - // write output back to global memory for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { - let exp_sum = exp_sum_shmem[q_tile_row]; - let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0); + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { break; } - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx]; - let scaled = f32(o_val) * scale; - dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled; - } + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + + let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let v = vec4( + f32(o_shmem[i0]) * scale, + f32(o_shmem[i1]) * scale, + f32(o_shmem[i2]) * scale, + f32(o_shmem[i3]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = v; + } } } diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 906d25417e4..9b6938abf7e 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -372,7 +372,8 @@ static size_t ggml_backend_zdnn_buffer_type_get_alignment(ggml_backend_buffer_ty } static bool ggml_backend_zdnn_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; + /* while it resides in host memory, additional transformation is needed */ + return false; GGML_UNUSED(buft); } diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index bdbfc74369f..f5cf6eedd3a 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -21,7 +21,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG zendnnl + GIT_TAG 21ce8f7879c86bf3637f707fae6f29e0951db5fe PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index afbecde7a5a..551c15bb4ae 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -2,7 +2,6 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" -#include "ggml-cpu.h" #include "zendnnl.hpp" #include @@ -122,8 +121,8 @@ static void ggml_zendnn_compute_forward_mul_mat( GGML_TENSOR_BINARY_OP_LOCALS - ggml_type const vec_dot_type = ggml_get_type_traits_cpu(src0->type)->vec_dot_type; - ggml_from_float_t const from_float = ggml_get_type_traits_cpu(vec_dot_type)->from_float; + ggml_type const vec_dot_type = src0->type; + ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref; GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 1725ad16545..500cb6b72f9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6562,7 +6562,7 @@ static void ggml_compute_backward( case GGML_OP_DIAG_MASK_INF: { if (src0_needs_grads) { /* ggml_diag_mask_inf_impl() shouldn't be here */ - /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ + /* ref: https://github.com/ggml-org/llama.cpp/pull/4203#discussion_r1412377992 */ const int n_past = ((const int32_t *) tensor->op_params)[0]; ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); } @@ -7517,8 +7517,11 @@ void ggml_quantize_free(void) { iq2xs_free_impl(GGML_TYPE_IQ2_XXS); iq2xs_free_impl(GGML_TYPE_IQ2_XS); + iq2xs_free_impl(GGML_TYPE_IQ2_S); iq2xs_free_impl(GGML_TYPE_IQ1_S); + iq2xs_free_impl(GGML_TYPE_IQ1_M); iq3xs_free_impl(256); + iq3xs_free_impl(512); ggml_critical_section_end(); } diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index b165d8bdc62..ed0d7f2cae1 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -585,6 +585,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par break; } + // check that the size of the tensor in bytes is representable + if (ok && uint64_t(ggml_nelements(&info.t)/ggml_blck_size(info.t.type)) > SIZE_MAX/ggml_type_size(info.t.type)) { + GGML_LOG_ERROR("%s: tensor '%s' with shape (%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") has a size in bytes > %zu\n", + __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], SIZE_MAX); + ok = false; + break; + } + // calculate byte offsets given the tensor shape and type info.t.nb[0] = type_size; info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size); @@ -734,7 +742,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p FILE * file = ggml_fopen(fname, "rb"); if (!file) { - GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno)); return nullptr; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 31273b2b5a7..f27d8ca4cba 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -146,6 +146,8 @@ class LLM: ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx" ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs" EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input" + SWIGLU_CLAMP_EXP = "{arch}.swiglu_clamp_exp" + SWIGLU_CLAMP_SHEXP = "{arch}.swiglu_clamp_shexp" DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in" DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out" @@ -179,20 +181,20 @@ class Attention: TEMPERATURE_SCALE = "{arch}.attention.temperature_scale" class Rope: - DIMENSION_COUNT = "{arch}.rope.dimension_count" - DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" - FREQ_BASE = "{arch}.rope.freq_base" - FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" - SCALING_TYPE = "{arch}.rope.scaling.type" - SCALING_FACTOR = "{arch}.rope.scaling.factor" - SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" - SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" - SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" - SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" - SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor" - SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor" - SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast" - SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow" + DIMENSION_COUNT = "{arch}.rope.dimension_count" + DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" + FREQ_BASE = "{arch}.rope.freq_base" + FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" + SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" + SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor" + SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor" + SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast" + SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow" class Split: LLM_KV_SPLIT_NO = "split.no" @@ -207,6 +209,9 @@ class SSM: GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" + class KDA: + HEAD_DIM = "{arch}.kda.head_dim" + class WKV: HEAD_SIZE = "{arch}.wkv.head_size" @@ -284,6 +289,8 @@ class Clip: class ClipVision: PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models IMAGE_SIZE = "clip.vision.image_size" + IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" + IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" PATCH_SIZE = "clip.vision.patch_size" EMBEDDING_LENGTH = "clip.vision.embedding_length" @@ -375,6 +382,8 @@ class MODEL_ARCH(IntEnum): QWEN3 = auto() QWEN3MOE = auto() QWEN3NEXT = auto() + QWEN3_5 = auto() + QWEN3_5_MOE = auto() QWEN3VL = auto() QWEN3VLMOE = auto() PHI2 = auto() @@ -457,8 +466,10 @@ class MODEL_ARCH(IntEnum): PANGU_EMBED = auto() MISTRAL3 = auto() MIMO2 = auto() + STEP35 = auto() LLAMA_EMBED = auto() MAINCODER = auto() + KIMI_LINEAR = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -549,6 +560,14 @@ class MODEL_TENSOR(IntEnum): SSM_NORM = auto() SSM_OUT = auto() SSM_BETA_ALPHA = auto() # qwen3next + SSM_CONV1D_Q = auto() # Kimi Linear + SSM_CONV1D_K = auto() # Kimi Linear + SSM_CONV1D_V = auto() # Kimi Linear + SSM_F_A = auto() # Kimi Linear + SSM_F_B = auto() # Kimi Linear + SSM_BETA = auto() # Kimi Linear + SSM_G_A = auto() # Kimi Linear + SSM_G_B = auto() # Kimi Linear TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() TIME_MIX_W2 = auto() @@ -668,9 +687,13 @@ class MODEL_TENSOR(IntEnum): V_ENC_ATTN_O = auto() V_ENC_ATTN_O_NORM = auto() V_ENC_POST_ATTN_NORM = auto() + V_ENC_ATTN_LN = auto() V_ENC_FFN_UP = auto() V_ENC_FFN_GATE = auto() V_ENC_FFN_DOWN = auto() + V_ENC_FFN_NORM = auto() + V_ENC_ATTN_Q_BIAS = auto() + V_ENC_ATTN_V_BIAS = auto() V_LAYER_SCALE_1 = auto() V_LAYER_SCALE_2 = auto() V_PRE_NORM = auto() @@ -795,6 +818,8 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.QWEN3: "qwen3", MODEL_ARCH.QWEN3MOE: "qwen3moe", MODEL_ARCH.QWEN3NEXT: "qwen3next", + MODEL_ARCH.QWEN3_5: "qwen3_5", + MODEL_ARCH.QWEN3_5_MOE: "qwen3_5moe", MODEL_ARCH.QWEN3VL: "qwen3vl", MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe", MODEL_ARCH.PHI2: "phi2", @@ -878,8 +903,10 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.PANGU_EMBED: "pangu-embedded", MODEL_ARCH.MISTRAL3: "mistral3", MODEL_ARCH.MIMO2: "mimo2", + MODEL_ARCH.STEP35: "step35", MODEL_ARCH.LLAMA_EMBED: "llama-embed", MODEL_ARCH.MAINCODER: "maincoder", + MODEL_ARCH.KIMI_LINEAR: "kimi-linear", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -967,6 +994,14 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba", + MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", # Kimi Linear + MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", # Kimi Linear + MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", # Kimi Linear + MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", # Kimi Linear + MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", # Kimi Linear + MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear + MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", # Kimi Linear + MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", # Kimi Linear MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", @@ -1086,9 +1121,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", + MODEL_TENSOR.V_ENC_ATTN_LN: "v.blk.{bid}.attn_ln", MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", + MODEL_TENSOR.V_ENC_FFN_NORM: "v.blk.{bid}.ffn_norm", + MODEL_TENSOR.V_ENC_ATTN_Q_BIAS: "v.blk.{bid}.attn_q.bias", + MODEL_TENSOR.V_ENC_ATTN_V_BIAS: "v.blk.{bid}.attn_v.bias", MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1", MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2", MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", @@ -1204,9 +1243,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_ENC_ATTN_O, MODEL_TENSOR.V_ENC_ATTN_O_NORM, MODEL_TENSOR.V_ENC_POST_ATTN_NORM, + MODEL_TENSOR.V_ENC_ATTN_LN, MODEL_TENSOR.V_ENC_FFN_UP, MODEL_TENSOR.V_ENC_FFN_GATE, MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_ENC_FFN_NORM, + MODEL_TENSOR.V_ENC_ATTN_Q_BIAS, + MODEL_TENSOR.V_ENC_ATTN_V_BIAS, MODEL_TENSOR.V_LAYER_SCALE_1, MODEL_TENSOR.V_LAYER_SCALE_2, MODEL_TENSOR.V_PRE_NORM, @@ -1757,6 +1800,61 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_BETA_ALPHA, MODEL_TENSOR.SSM_OUT ], + MODEL_ARCH.QWEN3_5: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_BETA_ALPHA, + MODEL_TENSOR.SSM_OUT, + ], + MODEL_ARCH.QWEN3_5_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_BETA_ALPHA, + MODEL_TENSOR.SSM_OUT, + ], MODEL_ARCH.QWEN3VL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3341,6 +3439,32 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_EXP_PROBS_B, ], + MODEL_ARCH.STEP35: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.LLAMA_EMBED: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3377,6 +3501,47 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.KIMI_LINEAR: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.SSM_CONV1D_Q, + MODEL_TENSOR.SSM_CONV1D_K, + MODEL_TENSOR.SSM_CONV1D_V, + MODEL_TENSOR.SSM_F_A, + MODEL_TENSOR.SSM_F_B, + MODEL_TENSOR.SSM_BETA, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_G_A, + MODEL_TENSOR.SSM_G_B, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], # TODO } @@ -3604,6 +3769,7 @@ class VisionProjectorType: QWEN3VL = "qwen3vl_merger" ULTRAVOX = "ultravox" INTERNVL = "internvl" + JINACLIP2 = "jinaclip2" QWEN2A = "qwen2a" # audio GLMA = "glma" # audio QWEN25O = "qwen2.5o" # omni @@ -3689,12 +3855,12 @@ class VisionProjectorType: KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS # RoPE -KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT -KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE -KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE -KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR -KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN -KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED +KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT +KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE +KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE +KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR +KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN +KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED # SSM KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL @@ -3704,6 +3870,9 @@ class VisionProjectorType: KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS +# KDA +KEY_KDA_HEAD_DIM = Keys.KDA.HEAD_DIM + # tokenization KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 7fbb78866bc..62172b24c38 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -824,6 +824,12 @@ def add_expert_weights_norm(self, value: bool) -> None: def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value) + def add_swiglu_clamp_exp(self, values: Sequence[float]) -> None: + self.add_array(Keys.LLM.SWIGLU_CLAMP_EXP.format(arch=self.arch), values) + + def add_swiglu_clamp_shexp(self, values: Sequence[float]) -> None: + self.add_array(Keys.LLM.SWIGLU_CLAMP_SHEXP.format(arch=self.arch), values) + def add_expert_group_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value) @@ -980,6 +986,9 @@ def add_ssm_group_count(self, value: int) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) + def add_kda_head_dim(self, value: int) -> None: + self.add_uint32(Keys.KDA.HEAD_DIM.format(arch=self.arch), value) + def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) @@ -1113,6 +1122,12 @@ def add_vision_attention_layernorm_eps(self, value: float) -> None: def add_vision_image_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value) + def add_vision_max_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_MAX_PIXELS, value) + + def add_vision_min_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value) + def add_vision_preproc_image_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 84aa8688092..00ac85cd484 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -228,6 +228,7 @@ class TensorNameMap: "transformer_encoder.{bid}.qkv", # neobert "layers.{bid}.attn.Wqkv", # modern-bert "model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm + "model.layers.{bid}.linear_attn.in_proj_qkv", # qwen3.5 ), # Attention query @@ -358,7 +359,9 @@ class TensorNameMap: ), MODEL_TENSOR.ATTN_GATE: ( - "model.layers.{bid}.self_attn.gate_proj", # afmoe + "model.layers.{bid}.self_attn.gate_proj", # afmoe + "model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate + "model.layers.{bid}.linear_attn.in_proj_z", # qwen3.5 ), # Feed-forward norm @@ -423,6 +426,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.router.gate", # afmoe "layers.{bid}.gate", # mistral-large "backbone.layers.{bid}.mixer.gate", # nemotron-h-moe + "model.layers.{bid}.moe.gate", # step3.5 ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -438,6 +442,8 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2 "backbone.layers.{bid}.mixer.gate.e_score_correction", # nemotron-h-moe "model.layers.{bid}.mlp.e_score_correction", # exaone-moe + "model.layers.{bid}.block_sparse_moe.gate.e_score_correction", # kimi + "model.layers.{bid}.moe.router_bias", # step3.5 expert selection bias ), # Feed-forward up @@ -492,6 +498,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe "model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker + "model.layers.{bid}.moe.up_proj", # step3.5 ), MODEL_TENSOR.FFN_UP_SHEXP: ( @@ -502,6 +509,8 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan "layers.{bid}.shared_experts.w3", # mistral-large "backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe + "model.layers.{bid}.block_sparse_moe.shared_experts.up_proj", # kimi + "model.layers.{bid}.share_expert.up_proj", # step3.5 ), MODEL_TENSOR.FFN_UP_CHEXP: ( @@ -541,6 +550,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 "model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker + "model.layers.{bid}.moe.gate_proj", # step3.5 ), MODEL_TENSOR.FFN_GATE_SHEXP: ( @@ -549,6 +559,8 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan "layers.{bid}.shared_experts.w1", # mistral-large + "model.layers.{bid}.block_sparse_moe.shared_experts.gate_proj", # kimi + "model.layers.{bid}.share_expert.gate_proj", # step3.5 ), MODEL_TENSOR.FFN_GATE_CHEXP: ( @@ -603,6 +615,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe "model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker + "model.layers.{bid}.moe.down_proj", # step3.5 ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( @@ -613,6 +626,8 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan "layers.{bid}.shared_experts.w2", # mistral-large "backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe + "model.layers.{bid}.block_sparse_moe.shared_experts.down_proj", # kimi + "model.layers.{bid}.share_expert.down_proj", # step3.5 ), MODEL_TENSOR.FFN_DOWN_CHEXP: ( @@ -759,6 +774,7 @@ class TensorNameMap: "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 "model.layers.{bid}.linear_attn.dt_proj", # qwen3next "backbone.layers.{bid}.mixer.dt", # nemotron-h-moe + "model.layers.{bid}.self_attn.dt_proj", # kimi ), MODEL_TENSOR.SSM_DT_NORM: ( @@ -772,6 +788,7 @@ class TensorNameMap: "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.A_log", # plamo2 "model.layers.{bid}.linear_attn.A_log", # qwen3next + "model.layers.{bid}.self_attn.A_log", # kimi ), MODEL_TENSOR.SSM_B_NORM: ( @@ -797,6 +814,7 @@ class TensorNameMap: "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid "model.layers.{bid}.linear_attn.norm", # qwen3next "backbone.layers.{bid}.mixer.norm", # mamba2 + "model.layers.{bid}.self_attn.o_norm", # kimi ), MODEL_TENSOR.SSM_OUT: ( @@ -811,6 +829,31 @@ class TensorNameMap: "model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next ), + # Kimi Linear KDA (using SSM_ prefix for consistency) + MODEL_TENSOR.SSM_CONV1D_Q: ( + "model.layers.{bid}.self_attn.q_conv1d", + ), + MODEL_TENSOR.SSM_CONV1D_K: ( + "model.layers.{bid}.self_attn.k_conv1d", + ), + MODEL_TENSOR.SSM_CONV1D_V: ( + "model.layers.{bid}.self_attn.v_conv1d", + ), + MODEL_TENSOR.SSM_F_A: ( + "model.layers.{bid}.self_attn.f_a_proj", + ), + MODEL_TENSOR.SSM_F_B: ( + "model.layers.{bid}.self_attn.f_b_proj", + ), + MODEL_TENSOR.SSM_BETA: ( + "model.layers.{bid}.self_attn.b_proj", + ), + MODEL_TENSOR.SSM_G_A: ( + "model.layers.{bid}.self_attn.g_a_proj", + ), + MODEL_TENSOR.SSM_G_B: ( + "model.layers.{bid}.self_attn.g_b_proj", + ), MODEL_TENSOR.TIME_MIX_W0: ( "model.layers.{bid}.attention.w0", # rwkv7 ), @@ -1281,6 +1324,7 @@ class TensorNameMap: "model.vision_tower.embeddings.cls_token", # Intern-S1 "vision_model.class_embedding", # llama 4 "model.vision.patch_embedding.cls_embedding", # cogvlm + "cls_token", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( @@ -1295,6 +1339,7 @@ class TensorNameMap: "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm "siglip2.vision_model.embeddings.patch_embedding", + "patch_embed.proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( @@ -1329,6 +1374,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl + "blocks.{bid}.attn.q_proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1347,6 +1393,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj", + "blocks.{bid}.attn.k_proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1365,6 +1412,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj", + "blocks.{bid}.attn.v_proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1380,6 +1428,7 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.layer_norm1", + "blocks.{bid}.norm1", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1396,6 +1445,7 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl + "blocks.{bid}.attn.proj", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1411,6 +1461,11 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.layer_norm2", + "blocks.{bid}.norm2", # JinaCLIP v2 vision + ), + + MODEL_TENSOR.V_ENC_ATTN_LN: ( + "blocks.{bid}.attn.inner_attn_ln", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1427,12 +1482,14 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1", + "blocks.{bid}.mlp.w2", # JinaCLIP v2 vision (up) ), MODEL_TENSOR.V_ENC_FFN_GATE: ( "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf "vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl + "blocks.{bid}.mlp.w1", # JinaCLIP v2 vision ), MODEL_TENSOR.V_ENC_FFN_DOWN: ( @@ -1449,6 +1506,11 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2", + "blocks.{bid}.mlp.w3", # JinaCLIP v2 vision (down) + ), + + MODEL_TENSOR.V_ENC_FFN_NORM: ( + "blocks.{bid}.mlp.ffn_ln", # JinaCLIP v2 vision ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1461,6 +1523,14 @@ class TensorNameMap: "model.vision_tower.encoder.layer.{bid}.lambda_2", # Intern-S1 ), + MODEL_TENSOR.V_ENC_ATTN_Q_BIAS: ( + "blocks.{bid}.attn.q_bias", # JinaCLIP v2 vision + ), + + MODEL_TENSOR.V_ENC_ATTN_V_BIAS: ( + "blocks.{bid}.attn.v_bias", # JinaCLIP v2 vision + ), + MODEL_TENSOR.V_PRE_NORM: ( "vision_tower.vision_model.pre_layrnorm", "vision_tower.ln_pre", # pixtral-hf @@ -1474,6 +1544,7 @@ class TensorNameMap: "vision_model.layernorm_post", # llama4 "visual.merger.ln_q", # qwen2vl "vision_tower.encoder.final_layernorm", # kimi-vl + "norm", # JinaCLIP v2 vision "visual.post_layernorm", # glm4v "siglip2.vision_model.post_layernorm", ), diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index f6c4cd14e74..48693ae3e3a 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -23,7 +23,7 @@ numpy = ">=1.17" tqdm = ">=4.27" pyyaml = ">=5.1" requests = ">=2.25" -sentencepiece = { version = ">=0.1.98,<=0.2.0", optional = true } +sentencepiece = { version = ">=0.1.98,<0.3.0", optional = true } PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true } [tool.poetry.dev-dependencies] diff --git a/include/llama.h b/include/llama.h index 280745713e5..bf4e28a8be1 100644 --- a/include/llama.h +++ b/include/llama.h @@ -309,7 +309,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible - bool use_direct_io; // use direct io, takes precedence over use_mmap + bool use_direct_io; // use direct io, takes precedence over use_mmap when supported bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) @@ -489,6 +489,7 @@ extern "C" { // - returns true if the parameters could be successfully modified to fit device memory // - this function is NOT thread safe because it modifies the global llama logger state // - only parameters that have the same value as in llama_default_model_params are modified + // with the exception of the context size which is modified if and only if equal to 0 LLAMA_API enum llama_params_fit_status llama_params_fit( const char * path_model, struct llama_model_params * mparams, @@ -1475,12 +1476,12 @@ extern "C" { /// @details Build a split GGUF final path for this chunk. /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" // Returns the split_path length. - LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); + LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count); /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" // Returns the split_prefix length. - LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); + LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count); // Print system information LLAMA_API const char * llama_print_system_info(void); diff --git a/models/templates/upstage-Solar-Open-100B.jinja b/models/templates/upstage-Solar-Open-100B.jinja new file mode 100644 index 00000000000..13268c1a841 --- /dev/null +++ b/models/templates/upstage-Solar-Open-100B.jinja @@ -0,0 +1,156 @@ +{#- ======== Template Parameters ======== #} +{%- set add_generation_prompt = add_generation_prompt if add_generation_prompt is defined else true %} +{%- set default_system_prompt = default_system_prompt if default_system_prompt is defined else true %} +{%- set reasoning_effort = reasoning_effort if reasoning_effort is defined else "high" %} +{%- set think_render_option = think_render_option if think_render_option is defined else "lastthink" %} + +{#- ======== System Block State ======== #} +{%- set sys_ns = namespace(is_first_block=true) -%} + +{#- ======== Find last user message index ======== #} +{%- set last_user_idx = namespace(value=-1) -%} +{%- for message in messages -%} + {%- if message.role == 'user' -%} + {%- set last_user_idx.value = loop.index0 -%} + {%- endif -%} +{%- endfor -%} + +{#- ======== System messages renderers ======== #} +{%- macro render_system_message(user_system_messages) %} + {%- if default_system_prompt %} + {%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %} + {%- set sys_ns.is_first_block = false %} + {{- "## Provider System Prompt\n\nYou are Solar Open 100B, a large language model trained by Upstage AI, a Korean startup. Your knowledge cutoff is 2025-07. The current date is " + strftime_now("%Y-%m-%d") + "." }} + {%- endif -%} + {%- if user_system_messages %} + {%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %} + {%- set sys_ns.is_first_block = false %} + {{- "## System Prompt" }} + {%- for system_message in user_system_messages %} + {{- "\n\n" }} + {{- system_message }} + {%- endfor %} + {%- endif -%} +{%- endmacro %} + +{%- macro render_tool_instruction(tools) %} + {%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %} + {%- set sys_ns.is_first_block = false %} + {{- "## Tools\n\n### Tool Call Instruction" }} + {{- "\nYou may invoke one or more tools to assist with the user's query. Available tools are provided in JSON Schema format: <|tools:begin|><|tool:begin|><|tool:end|>...<|tools:end|>\n" }} + {{- "\n### Available Tools\n" }} + {{- "<|tools:begin|>" }} + {%- for tool in tools %} + {{- "<|tool:begin|>" }} + {{- tool.function | tojson }} + {{- "<|tool:end|>" }} + {%- endfor %} + {{- "<|tools:end|>\n" }} + {{- "\n### Tool Call Format\n" }} + {{- "For each tool call, return a JSON object with the following structure, enclosed within <|tool_call:begin|> and <|tool_call:end|> tags: \n<|tool_call:begin|><|tool_call:name|><|tool_call:args|><|tool_call:end|>\n" }} + {{- "- The must be a randomly generated string consisting of 10 lowercase letters (a-z) and/or digits (0-9) (e.g., a1b2c3d4e5)\n" }} + {{- "\n### Tool Response Format\n" }} + {{- "Each tool is responded by `tool` with the following structure:\n<|tool_response:id|><|tool_response:name|><|tool_response:result|><|tool_response:end|>\n" }} + {{- "- Ensure the matches the corresponding tool call" -}} +{%- endmacro %} + +{%- macro render_json_response_format_instruction(response_format) %} + {%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %} + {%- set sys_ns.is_first_block = false %} + {{- "## Output Format Constraint" }} + {{- "\n\nYour final response should follow the JSON schema: \n[Start of schema]" }} + {{- response_format }} + {{- "\n[End of schema]\nPlease ensure your answers adhere to this format and do not contain any unnecessary text." }} +{%- endmacro %} + +{%- macro get_tool_name(messages, tool_call_id) %} + {%- for msg in messages -%} + {%- if msg.role == 'assistant' and msg.tool_calls -%} + {%- for tool_call in msg.tool_calls -%} + {%- if tool_call.id == tool_call_id -%} + {{- tool_call.function.name }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- endfor -%} +{%- endmacro %} + +{%- macro render_tool_arguments(tool_arguments) %} + {%- if tool_arguments is mapping -%} + {{- tool_arguments | tojson }} + {%- else -%} + {{- tool_arguments }} + {%- endif -%} +{%- endmacro %} + +{#- ======== Render system message ======== #} +{%- set ns = namespace(system_messages=[]) -%} +{%- for message in messages -%} + {%- if message.role == 'system' -%} + {%- set ns.system_messages = ns.system_messages + [message.content] -%} + {%- endif -%} +{%- endfor -%} + +{%- if ns.system_messages or default_system_prompt or tools or response_format -%} + {{- "<|begin|>system<|content|>" }} + {{- render_system_message(ns.system_messages) }} + {%- if tools -%} + {{- render_tool_instruction(tools) }} + {%- endif %} + {%- if response_format -%} + {{- render_json_response_format_instruction(response_format) }} + {%- endif %} + {{- "<|end|>" }} +{%- endif -%} + +{#- ======== Render main messages ======== #} +{%- for message in messages -%} + {%- if message.role == 'user' -%} + {{- "<|begin|>user<|content|>" + message.content + "<|end|>" }} + {%- elif message.role == 'tool' -%} + {%- set prev_is_tool = loop.index0 > 0 and messages[loop.index0 - 1].role == 'tool' -%} + {%- set next_is_tool = loop.index0 < (messages | length - 1) and messages[loop.index0 + 1].role == 'tool' -%} + {%- if not prev_is_tool -%} + {{- "<|begin|>tool<|tool_response|>" }} + {%- endif -%} + {{- "<|tool_response:begin|>" + message.tool_call_id + "<|tool_response:name|>" }} + {{- get_tool_name(messages, message.tool_call_id) }} + {{- "<|tool_response:result|>" }} + {{- message.content }} + {{- "<|tool_response:end|>" }} + {%- if not next_is_tool -%} + {{- "<|end|>" }} + {%- endif -%} + {%- elif message.role == 'assistant' -%} + {#- ======== Assistant Thinking ======== #} + {%- if think_render_option == "all" -%} + {%- if message.reasoning -%} + {{- "<|begin|>assistant<|think|>" + message.reasoning + "<|end|>" }} + {%- endif -%} + {%- elif think_render_option == "lastthink" -%} + {%- if message.reasoning and loop.index0 > last_user_idx.value -%} + {{- "<|begin|>assistant<|think|>" + message.reasoning + "<|end|>" }} + {%- endif -%} + {%- endif -%} + + {#- ======== Assistant Messages ======== #} + {%- if message.tool_calls -%} + {{- "<|begin|>assistant<|tool_calls|>" }} + {%- for tool_call in message.tool_calls -%} + {{- "<|tool_call:begin|>" + tool_call.id +"<|tool_call:name|>" + tool_call.function.name + "<|tool_call:args|>" }} + {{- render_tool_arguments(tool_call.function.arguments) }} + {{- "<|tool_call:end|>" }} + {%- endfor -%} + {{- "<|calls|>" }} + {%- else -%} + {{- "<|begin|>assistant<|content|>" + message.content + "<|end|>" }} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {%- if reasoning_effort in ["low", "minimal"] -%} + {{- "<|begin|>assistant<|think|><|end|>" }} + {%- endif -%} + {{- "<|begin|>assistant" }} +{%- endif -%} diff --git a/pyproject.toml b/pyproject.toml index 3d71b055a8d..422f53c7c72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.9" numpy = "^1.25.0" -sentencepiece = ">=0.1.98,<=0.2.0" +sentencepiece = ">=0.1.98,<0.3.0" transformers = ">=4.35.2,<5.0.0" protobuf = ">=4.21.0,<5.0.0" gguf = { path = "./gguf-py" } diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index dbab3b9508f..4898bf7ee29 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -1,5 +1,5 @@ numpy~=1.26.4 -sentencepiece~=0.2.0 +sentencepiece>=0.1.98,<0.3.0 transformers>=4.57.1,<5.0.0 diff --git a/requirements/requirements-tool_bench.txt b/requirements/requirements-tool_bench.txt index f7912aff724..3bb74fb9d01 100644 --- a/requirements/requirements-tool_bench.txt +++ b/requirements/requirements-tool_bench.txt @@ -3,7 +3,7 @@ pytest~=8.3.3 huggingface_hub>=0.34.0,<1.0 matplotlib~=3.10.0 numpy~=1.26.4 -openai~=1.55.3 +openai~=2.14.0 pandas~=2.2.3 prometheus-client~=0.20.0 requests~=2.32.3 diff --git a/scripts/bench-models.sh b/scripts/bench-models.sh old mode 100644 new mode 100755 index 744b0de359c..c241013040f --- a/scripts/bench-models.sh +++ b/scripts/bench-models.sh @@ -7,47 +7,54 @@ ARGS_BB="-c 270336 -npp 512,4096,8192 -npl 1,2,4,8,16,32 -ntg 32" ARGS_B="-d 0,4096,8192,16384,32768 -p 2048 -n 32" QUICK=0 +DIO=0 while (( "$#" )); do - case "$1" in - --quick) QUICK=1; shift ;; - *) shift ;; - esac + case "$1" in + --quick) QUICK=1; shift ;; + --dio) DIO=1; shift ;; + *) shift ;; + esac done if (( QUICK )); then - ARGS_BB="-c 20480 -npp 512,4096 -npl 1,2,4 -ntg 32" - ARGS_B="-d 0 -p 2048 -n 32" + ARGS_BB="-c 20480 -npp 512,4096 -npl 1,2,4 -ntg 32" + ARGS_B="-d 0 -p 2048 -n 32" +fi + +if (( DIO )); then + ARGS_BB="${ARGS_BB} --no-mmap --direct-io" + ARGS_B="${ARGS_B} -mmp 0 -dio 1" fi run_model() { - local HFR=$1 - local HFF=$2 + local HFR=$1 + local HFF=$2 - printf "## ${HFR}\n" | tee -a "$RESULTS" - printf "\n" | tee -a "$RESULTS" - printf "Model: https://huggingface.co/${HFR}\n" | tee -a "$RESULTS" - printf "\n" | tee -a "$RESULTS" + printf "## ${HFR}\n" | tee -a "$RESULTS" + printf "\n" | tee -a "$RESULTS" + printf "Model: https://huggingface.co/${HFR}\n" | tee -a "$RESULTS" + printf "\n" | tee -a "$RESULTS" - printf -- "- \`llama-batched-bench\`\n" | tee -a "$RESULTS" - printf "\n" | tee -a "$RESULTS" + printf -- "- \`llama-batched-bench\`\n" | tee -a "$RESULTS" + printf "\n" | tee -a "$RESULTS" - ./bin/llama-batched-bench \ - -hfr "${HFR}" -hff "${HFF}" \ - -m "${HFF}" -fa 1 -ub 2048 --no-mmap \ - ${ARGS_BB} | tee -a "$RESULTS" + ./bin/llama-batched-bench \ + -hfr "${HFR}" -hff "${HFF}" \ + -m "${HFF}" -fa 1 -ub 2048 \ + ${ARGS_BB} | tee -a "$RESULTS" - printf "\n" | tee -a "$RESULTS" + printf "\n" | tee -a "$RESULTS" - printf -- "- \`llama-bench\`\n" | tee -a "$RESULTS" - printf "\n" | tee -a "$RESULTS" + printf -- "- \`llama-bench\`\n" | tee -a "$RESULTS" + printf "\n" | tee -a "$RESULTS" - ./bin/llama-bench \ - -m "${HFF}" -fa 1 -ub 2048 -mmp 0 \ - ${ARGS_B} | tee -a "$RESULTS" + ./bin/llama-bench \ + -m "${HFF}" -fa 1 -ub 2048 \ + ${ARGS_B} | tee -a "$RESULTS" - printf "\n" | tee -a "$RESULTS" + printf "\n" | tee -a "$RESULTS" - printf "\n" + printf "\n" } run_model "ggml-org/gpt-oss-20b-GGUF" "gpt-oss-20b-mxfp4.gguf" @@ -55,6 +62,7 @@ run_model "ggml-org/gpt-oss-120b-GGUF" "gpt-oss-120b-mxfp4- run_model "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF" "qwen3-coder-30b-a3b-instruct-q8_0.gguf" run_model "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF" "qwen2.5-coder-7b-q8_0.gguf" run_model "ggml-org/gemma-3-4b-it-qat-GGUF" "gemma-3-4b-it-qat-Q4_0.gguf" +run_model "ggml-org/GLM-4.7-Flash-GGUF" "GLM-4.7-Flash-Q8_0.gguf" if [[ -f models-extra.txt ]]; then while read -r HFR HFF; do diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index c45c83fdb55..9541b89eb91 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -29,7 +29,7 @@ "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides", "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth", - "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", + "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "n_cpu_moe" ] LLAMA_BENCH_DB_TYPES = [ @@ -38,7 +38,7 @@ "TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER", "TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", - "TEXT", "INTEGER", "INTEGER", "REAL", "REAL", + "TEXT", "INTEGER", "INTEGER", "REAL", "REAL", "INTEGER", ] # All test-backend-ops SQL fields @@ -59,7 +59,7 @@ # Properties by which to differentiate results per commit for llama-bench: LLAMA_BENCH_KEY_PROPERTIES = [ - "cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type", + "cpu_info", "gpu_info", "backends", "n_gpu_layers", "n_cpu_moe", "tensor_buft_overrides", "model_filename", "model_type", "n_batch", "n_ubatch", "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v", "use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen", "n_depth" ] diff --git a/scripts/snapdragon/windows/run-bench.ps1 b/scripts/snapdragon/windows/run-bench.ps1 new file mode 100644 index 00000000000..21fd063ebe3 --- /dev/null +++ b/scripts/snapdragon/windows/run-bench.ps1 @@ -0,0 +1,40 @@ + +#!/usr/bin/env pwsh + +# Basedir on device +$basedir=".\pkg-snapdragon" + +$cli_opts=$args + +$model="Llama-3.2-3B-Instruct-Q4_0.gguf" +if ($null -ne $env:M) { + $model=$env:M +} + +$device="HTP0" +if ($null -ne $env:D) { + $device=$env:D +} + +if ($null -ne $env:V) { + $env:GGML_HEXAGON_VERBOSE=$env:V +} + +if ($null -ne $env:OPMASK) { + $env:GGML_HEXAGON_OPMASK=$env:OPMASK +} + +if ($null -ne $env:NHVX) { + $env:GGML_HEXAGON_NHVX=$env:NHVX +} + +if ($null -ne $env:NDEV) { + $env:GGML_HEXAGON_NDEV=$env:NDEV +} + +$env:ADSP_LIBRARY_PATH="$basedir\lib" + +& "$basedir\bin\llama-bench.exe" ` + --mmap 0 -m $basedir\..\..\gguf\$model ` + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 ` + --batch-size 128 -ngl 99 --device $device $cli_opts diff --git a/scripts/snapdragon/windows/run-cli.ps1 b/scripts/snapdragon/windows/run-cli.ps1 new file mode 100644 index 00000000000..b13161aa631 --- /dev/null +++ b/scripts/snapdragon/windows/run-cli.ps1 @@ -0,0 +1,53 @@ + +#!/usr/bin/env pwsh + +# Basedir on device +$basedir=".\pkg-snapdragon" + +$cli_opts=$args + +$model="Llama-3.2-3B-Instruct-Q4_0.gguf" +if ($null -ne $env:M) { + $model=$env:M +} + +$device="HTP0" +if ($null -ne $env:D) { + $device=$env:D +} + +if ($null -ne $env:V) { + $env:GGML_HEXAGON_VERBOSE=$env:V +} + +if ($null -ne $env:E) { + $env:GGML_HEXAGON_EXPERIMENTAL=$env:E +} + +if ($null -ne $env:SCHED) { + $env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v" +} + +if ($null -ne $env:PROF) { + $env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1 +} + +if ($null -ne $env:OPMASK) { + $env:GGML_HEXAGON_OPMASK=$env:OPMASK +} + +if ($null -ne $env:NHVX) { + $env:GGML_HEXAGON_NHVX=$env:NHVX +} + +if ($null -ne $env:NDEV) { + $env:GGML_HEXAGON_NDEV=$env:NDEV +} + +$env:ADSP_LIBRARY_PATH="$basedir\lib" + +& "$basedir\bin\llama-completion.exe" ` + --no-mmap -no-cnv -m $basedir\..\..\gguf\$model ` + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 ` + --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on ` + -ngl 99 --device $device $cli_opts diff --git a/scripts/snapdragon/windows/run-tool.ps1 b/scripts/snapdragon/windows/run-tool.ps1 new file mode 100644 index 00000000000..70094af9bc9 --- /dev/null +++ b/scripts/snapdragon/windows/run-tool.ps1 @@ -0,0 +1,56 @@ + +#!/usr/bin/env pwsh + +# Basedir on device +$basedir=".\pkg-snapdragon" + +if ($args.Count -eq 0) { + Write-Host "No arguments provided.Expected the tool and argument to run." + exit -1 +} + +$tool=$args[0] +$cli_opts=@() + +if ($args.Count -gt 1) { + $cli_opts=$args[1..($args.Count - 1)] + $remainingArgs = $args[1..($args.Count - 1)] +} + +$device="HTP0" +if ($null -ne $env:D) { + $device=$env:D +} + +if ($null -ne $env:V) { + $env:GGML_HEXAGON_VERBOSE=$env:V +} + +if ($null -ne $env:E) { + $env:GGML_HEXAGON_EXPERIMENTAL=$env:E +} + +if ($null -ne $env:SCHED) { + $env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v" +} + +if ($null -ne $env:PROF) { + $env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1 +} + +if ($null -ne $env:OPMASK) { + $env:GGML_HEXAGON_OPMASK=$env:OPMASK +} + +if ($null -ne $env:NHVX) { + $env:GGML_HEXAGON_NHVX=$env:NHVX +} + +if ($null -ne $env:NDEV) { + $env:GGML_HEXAGON_NDEV=$env:NDEV +} + +$env:ADSP_LIBRARY_PATH="$basedir\lib" + +& "$basedir\bin\$tool" ` + $cli_opts diff --git a/scripts/snapdragon/windows/setup-build.ps1 b/scripts/snapdragon/windows/setup-build.ps1 new file mode 100644 index 00000000000..0f3244cc9d2 --- /dev/null +++ b/scripts/snapdragon/windows/setup-build.ps1 @@ -0,0 +1,105 @@ +# Requires Run as Administrator is NOT strictly necessary for User-scope env vars, +# but recommended for creating directories in C:\ root if permissions are restricted. + +$ErrorActionPreference = "Stop" + +# --- Configuration --- +$BaseDir = "C:\Qualcomm" + +# SDK 1: Hexagon +$HexagonUrl = "https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v6.4.0.2/hexagon-sdk-v6.4.0.2-arm64-wos.tar.xz" +$HexagonParent = Join-Path $BaseDir "Hexagon_SDK" +$HexagonSdkVersion = "6.4.0.2" +$HexagonToolsVersion = "19.0.04" +$HexagonSdkTarget = Join-Path $HexagonParent $HexagonSdkVersion +$HexagonToolsTarget = Join-Path $HexagonSdkTarget "\tools\HEXAGON_Tools\$HexagonToolsVersion" + +# SDK 2: OpenCL +$OpenCLUrl = "https://github.com/snapdragon-toolchain/opencl-sdk/releases/download/v2.3.2/adreno-opencl-sdk-v2.3.2-arm64-wos.tar.xz" +$OpenCLParent = Join-Path $BaseDir "OpenCL_SDK" +$OpenCLVersion = "2.3.2" +$OpenCLTarget = Join-Path $OpenCLParent $OpenCLVersion + +# --- Helper Function --- +function Install-QualcommSDK { + param ( + [string]$Url, + [string]$ParentDir, + [string]$TargetDir, + [string]$Name + ) + + # 1. Create Parent Directory + if (-not (Test-Path -Path $ParentDir)) { + Write-Host "Creating directory: $ParentDir" -ForegroundColor Cyan + New-Item -Path $ParentDir -ItemType Directory -Force | Out-Null + } + + # 2. Check for Specific Version Directory + if (Test-Path -Path $TargetDir) { + Write-Host "$Name ($TargetDir) already exists. Skipping download." -ForegroundColor Green + } + else { + Write-Host "$Name not found. preparing to download..." -ForegroundColor Yellow + + # Create the target directory to extract into + New-Item -Path $TargetDir -ItemType Directory -Force | Out-Null + + # Define temporary archive path + $TempFile = Join-Path $ParentDir "temp_sdk.tar.xz" + + try { + # Download + Write-Host "Downloading from: $Url" + Invoke-WebRequest -Uri $Url -OutFile $TempFile + + # Untar + # Note: We assume Windows includes tar.exe (Win 10 build 17063+) + Write-Host "Extracting archive to $TargetDir..." + + # We use -C to extract contents INTO the target directory created above + tar -xJvf $TempFile -C $TargetDir\.. + + Write-Host "Extraction complete." -ForegroundColor Green + } + catch { + Write-Error "Failed to download or extract $Name. Error: $_" + # Cleanup target dir if failed so script tries again next time + Remove-Item -Path $TargetDir -Recurse -Force -ErrorAction SilentlyContinue + } + finally { + # Cleanup Archive + if (Test-Path $TempFile) { Remove-Item $TempFile -Force } + } + } +} + +# --- Execution --- + +# 1. Ensure Base C:\Qualcomm exists +if (-not (Test-Path $BaseDir)) { + New-Item -Path $BaseDir -ItemType Directory -Force | Out-Null +} + +# 2. Run Install Logic +Install-QualcommSDK -Url $HexagonUrl -ParentDir $HexagonParent -TargetDir $HexagonSdkTarget -Name "Hexagon SDK" +Install-QualcommSDK -Url $OpenCLUrl -ParentDir $OpenCLParent -TargetDir $OpenCLTarget -Name "OpenCL SDK" + +# --- Environment Variables --- + +Write-Host "`nSetting Environment Variables..." -ForegroundColor Cyan + +# Set OPENCL_SDK_ROOT +[System.Environment]::SetEnvironmentVariable('OPENCL_SDK_ROOT', $OpenCLTarget, [System.EnvironmentVariableTarget]::User) +$env:OPENCL_SDK_ROOT = $OpenCLTarget # Set for current session as well +Write-Host "OPENCL_SDK_ROOT set to: $OpenCLTarget" + +# Set HEXAGON_SDK_ROOT +[System.Environment]::SetEnvironmentVariable('HEXAGON_SDK_ROOT', $HexagonSdkTarget, [System.EnvironmentVariableTarget]::User) +$env:HEXAGON_SDK_ROOT = $HexagonSdkTarget # Set for current session as well +Write-Host "HEXAGON_SDK_ROOT set to: $HexagonSdkTarget" + +# Set HEXAGON_SDK_ROOT +[System.Environment]::SetEnvironmentVariable('HEXAGON_TOOLS_ROOT', $HexagonToolsTarget, [System.EnvironmentVariableTarget]::User) +$env:HEXAGON_TOOLS_ROOT = $HexagonToolsTarget # Set for current session as well +Write-Host "HEXAGON_TOOLS_ROOT set to: $HexagonToolsTarget" diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index c8382761582..81e79a94707 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -ebc3a0f4a56be1c9424a89fbec09962ac34fde85 +a8db410a252c8c8f2d120c6f2e7133ebe032f35d diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 0771942d493..1ff6a9a40fd 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -12,8 +12,8 @@ # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.1/httplib.h": "vendor/cpp-httplib/httplib.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.1/LICENSE": "vendor/cpp-httplib/LICENSE", + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/httplib.h": "vendor/cpp-httplib/httplib.h", + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/LICENSE": "vendor/cpp-httplib/LICENSE", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c15c281a5e6..0c164617a12 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,13 +24,14 @@ add_library(llama llama-kv-cache-iswa.cpp llama-memory.cpp llama-memory-hybrid.cpp + llama-memory-hybrid-iswa.cpp llama-memory-recurrent.cpp llama-mmap.cpp llama-model-loader.cpp llama-model-saver.cpp llama-model.cpp llama-quant.cpp - llama-sampling.cpp + llama-sampler.cpp llama-vocab.cpp unicode-data.cpp unicode.cpp @@ -56,6 +57,7 @@ add_library(llama models/deci.cpp models/deepseek.cpp models/deepseek2.cpp + models/delta.cpp models/dots1.cpp models/dream.cpp models/ernie4-5-moe.cpp @@ -83,6 +85,7 @@ add_library(llama models/internlm2.cpp models/jais.cpp models/jamba.cpp + models/kimi-linear.cpp models/lfm2.cpp models/llada-moe.cpp models/llada.cpp @@ -120,6 +123,8 @@ add_library(llama models/qwen3vl-moe.cpp models/qwen3moe.cpp models/qwen3next.cpp + models/qwen3-5.cpp + models/qwen3-5moe.cpp models/refact.cpp models/rnd1.cpp models/rwkv6-base.cpp @@ -133,6 +138,7 @@ add_library(llama models/stablelm.cpp models/starcoder.cpp models/starcoder2.cpp + models/step35-iswa.cpp models/t5-dec.cpp models/t5-enc.cpp models/wavtokenizer-dec.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a54bc1956ae..fce46772d7e 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -35,6 +35,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3, "qwen3" }, { LLM_ARCH_QWEN3MOE, "qwen3moe" }, { LLM_ARCH_QWEN3NEXT, "qwen3next" }, + { LLM_ARCH_QWEN3_5, "qwen3_5" }, + { LLM_ARCH_QWEN3_5_MOE, "qwen3_5moe" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_PHI2, "phi2" }, @@ -117,9 +119,11 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, - { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_STEP35, "step35" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, + { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -161,6 +165,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" }, + { LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" }, + { LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, @@ -219,21 +225,21 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, - { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, - { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -246,6 +252,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_KDA_HEAD_DIM, "%s.kda.head_dim" }, + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" }, @@ -371,6 +379,15 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_CONV1D_Q, "blk.%d.ssm_conv1d_q" }, + { LLM_TENSOR_SSM_CONV1D_K, "blk.%d.ssm_conv1d_k" }, + { LLM_TENSOR_SSM_CONV1D_V, "blk.%d.ssm_conv1d_v" }, + { LLM_TENSOR_SSM_F_A, "blk.%d.ssm_f_a" }, + { LLM_TENSOR_SSM_F_B, "blk.%d.ssm_f_b" }, + { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" }, + { LLM_TENSOR_SSM_G_A, "blk.%d.ssm_g_a" }, + { LLM_TENSOR_SSM_G_B, "blk.%d.ssm_g_b" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, @@ -970,6 +987,63 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, }; + case LLM_ARCH_QWEN3_5: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA_ALPHA, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; + case LLM_ARCH_QWEN3_5_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA_ALPHA, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; case LLM_ARCH_QWEN3VL: case LLM_ARCH_CHAMELEON: case LLM_ARCH_HUNYUAN_DENSE: @@ -2267,6 +2341,35 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_EXPS, LLM_TENSOR_FFN_EXP_PROBS_B, }; + case LLM_ARCH_STEP35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; case LLM_ARCH_GPTJ: case LLM_ARCH_UNKNOWN: return { @@ -2289,6 +2392,54 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, }; + case LLM_ARCH_KIMI_LINEAR: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + // Dense FFN (layer 0 only) + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + // MoE FFN (layers 1+) + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + // Shared experts + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + // KDA (using SSM_ enum prefix, keeping GGUF names for backward compat) + LLM_TENSOR_SSM_CONV1D_Q, + LLM_TENSOR_SSM_CONV1D_K, + LLM_TENSOR_SSM_CONV1D_V, + LLM_TENSOR_SSM_F_A, + LLM_TENSOR_SSM_F_B, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_G_A, + LLM_TENSOR_SSM_G_B, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_NORM, + // MLA + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_KV_A_NORM, + }; default: GGML_ABORT("unknown architecture for tensor mapping"); } @@ -2392,6 +2543,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // Kimi KDA - Conv tensors are 4D [d_conv, 1, d_inner, 1], reshaped to 2D at runtime + {LLM_TENSOR_SSM_CONV1D_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_F_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_F_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2573,6 +2733,9 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_QWEN3_5: + case LLM_ARCH_QWEN3_5_MOE: + case LLM_ARCH_KIMI_LINEAR: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 270d28b16a4..a392ecce2b4 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -39,6 +39,8 @@ enum llm_arch { LLM_ARCH_QWEN3, LLM_ARCH_QWEN3MOE, LLM_ARCH_QWEN3NEXT, + LLM_ARCH_QWEN3_5, + LLM_ARCH_QWEN3_5_MOE, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, LLM_ARCH_PHI2, @@ -122,8 +124,10 @@ enum llm_arch { LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, LLM_ARCH_MIMO2, + LLM_ARCH_STEP35, LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, + LLM_ARCH_KIMI_LINEAR, LLM_ARCH_UNKNOWN, }; @@ -165,6 +169,8 @@ enum llm_kv { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, + LLM_KV_SWIGLU_CLAMP_EXP, + LLM_KV_SWIGLU_CLAMP_SHEXP, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -250,6 +256,8 @@ enum llm_kv { LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_KDA_HEAD_DIM, + LLM_KV_WKV_HEAD_SIZE, LLM_KV_TOKENIZER_MODEL, @@ -398,6 +406,15 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + // Kimi Linear KDA (using SSM_ prefix for consistency) + LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight + LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight + LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight + LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A + LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient + LLM_TENSOR_SSM_G_A, // kimi: output gate projection A + LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 3c7e0afdae8..c415a998f33 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -233,7 +233,7 @@ int32_t llm_chat_apply_template( llm_chat_template tmpl, const std::vector & chat, std::string & dest, bool add_ass) { - // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + // Taken from the research: https://github.com/ggml-org/llama.cpp/issues/5527 std::stringstream ss; if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { // chatml template diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a6d5ddfa330..80b9a7d46a6 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -253,11 +253,7 @@ llama_context::llama_context( // graph outputs buffer { - // resized during inference when a batch uses more outputs - // Create a dummy batch for initialization. - llama_batch dummy_batch = {}; - dummy_batch.n_tokens = 0; - if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) { + if (output_reserve(params.n_seq_max) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -321,6 +317,7 @@ llama_context::llama_context( auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { // ignore CPU backend + // TODO: should we ignore ACCEL types too? continue; } auto * dev = ggml_backend_get_device(backend.get()); @@ -793,7 +790,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); } - const uint32_t n_embd_out = model.hparams.get_n_embd_out(); + const uint32_t n_embd_out = model.hparams.n_embd_out(); return embd + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); @@ -1030,11 +1027,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { llama_sampler_chain_n(sampler) > 0; if (sampler && can_offload) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); - auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); - if (host_buft) { - buft = host_buft; - } + auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); sampler->iface->backend_init(sampler, buft); @@ -1225,7 +1218,7 @@ int llama_context::encode(const llama_batch & batch_inp) { n_queued_tokens += n_tokens; // reserve output buffer - if (output_reserve(n_tokens, batch_inp) < n_tokens) { + if (output_reserve(n_tokens) < n_tokens) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); return -2; }; @@ -1279,7 +1272,7 @@ int llama_context::encode(const llama_batch & batch_inp) { { // extract token embeddings GGML_ASSERT(embd != nullptr); - const uint32_t n_embd_out = hparams.get_n_embd_out(); + const uint32_t n_embd_out = hparams.n_embd_out(); GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); @@ -1456,6 +1449,23 @@ static void copy_tensor_async_candidates( } } +static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map & samplers) { + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (!ubatch.output[i]) { + continue; + } + + // Check if the output token has at least one sequence without a backend sampler. + for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) { + llama_seq_id seq_id = ubatch.seq_id[i][j]; + if (samplers.find(seq_id) == samplers.end()) { + return true; + } + } + } + return false; // all sequences use backend sampling +} + int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT @@ -1588,7 +1598,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // reserve output buffer - if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { + if (output_reserve(n_outputs_all) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); return -2; }; @@ -1661,10 +1671,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - // For multi-sequence batches that mix backend samplers and CPU sampler - // this is currently inefficient as we copy all logits even for the - // backend sampled tokens. - if (logits && t_logits && n_outputs > 0) { + if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); @@ -1688,7 +1695,7 @@ int llama_context::decode(const llama_batch & batch_inp) { { // extract token embeddings GGML_ASSERT(embd != nullptr); - const uint32_t n_embd_out = hparams.get_n_embd_out(); + const uint32_t n_embd_out = hparams.n_embd_out(); float * embd_out = embd + n_outputs_prev*n_embd_out; if (n_outputs) { @@ -1734,11 +1741,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // This flag indicates whether a backend sampler has actually sampled a specific - // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings. - const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); - - if (has_samplers && has_sampled) { + // Copy backend sampling output if this ubatch produced any sampling tensors. + if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); const auto stride = n_vocab; @@ -1813,7 +1817,8 @@ int llama_context::decode(const llama_batch & batch_inp) { // output // -uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) { +uint32_t llama_context::output_reserve(int32_t n_outputs) { + const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1821,7 +1826,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); - const auto n_embd_out = hparams.get_n_embd_out(); + const auto n_embd_out = hparams.n_embd_out(); bool has_logits = true; bool has_embd = cparams.embeddings; @@ -1832,45 +1837,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba has_embd = true; } - // Check which sampling modes are needed for the current batch. - // TODO: avoid this branching by working with the worst-case - bool has_sampling = false; - bool cpu_logits = false; - - if (batch.logits) { - for (int32_t i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { - continue; - } - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - llama_seq_id seq_id = batch.seq_id[i][j]; - if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { - has_sampling = true; - } else { - cpu_logits = true; - } - } - } - } else { - // When batch.logits is nullptr (when loading state with a dummy batch), - // allocate CPU logits. - cpu_logits = true; - } size_t backend_float_count = 0; size_t backend_token_count = 0; - // Allocate CPU logits buffer only if needed by sequences in this batch - logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0; + logits_size = has_logits ? n_vocab*n_outputs_max : 0; embd_size = has_embd ? n_embd_out*n_outputs_max : 0; - // TODO: avoid this branching by working with the worst-case - if (!has_sampling) { - sampling.logits_size = 0; - sampling.probs_size = 0; - sampling.sampled_size = 0; - sampling.candidates_size = 0; - } else { + // Allocate backend sampling output buffers if there are backend samplers configured. + const bool has_sampling = !sampling.samplers.empty(); + if (has_sampling) { sampling.logits_size = n_vocab*n_outputs_max; sampling.probs_size = n_vocab*n_outputs_max; sampling.sampled_size = n_outputs_max; @@ -1928,7 +1904,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba size_t offset = 0; uint8_t * base = (uint8_t *) output_base; - logits = (has_logits && cpu_logits) ? output_base : nullptr; + logits = has_logits ? output_base : nullptr; offset += logits_size * sizeof(float); embd = has_embd ? (float *) (base + offset) : nullptr; @@ -2037,7 +2013,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN3_5 || model.arch == LLM_ARCH_QWEN3_5_MOE || model.arch == LLM_ARCH_KIMI_LINEAR) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max(1024u, 8u*model.n_tensors()); @@ -2173,13 +2149,6 @@ llm_graph_cb llama_context::graph_get_cb() const { ggml_set_name(cur, name); } - if (!cparams.offload_kqv) { - if (strcmp(name, "kqv_merged_cont") == 0) { - // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); - } - } - // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; @@ -2559,6 +2528,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } + // [TAG_CONTEXT_STATE_LOGITS] // write logits { LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); @@ -2620,10 +2590,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { auto n_outputs = this->n_outputs; io.read_to(&n_outputs, sizeof(n_outputs)); - // Create a dummy batch for state loading. - llama_batch dummy_batch = {}; - dummy_batch.n_tokens = 0; - if (n_outputs > output_reserve(n_outputs, dummy_batch)) { + if (n_outputs > output_reserve(n_outputs)) { throw std::runtime_error("could not reserve outputs"); } @@ -2868,7 +2835,7 @@ void llama_context::opt_epoch_iter( } // reserve output buffer - if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { + if (output_reserve(n_outputs_all) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); GGML_ABORT("TODO: handle this error"); }; @@ -2903,7 +2870,7 @@ void llama_context::opt_epoch_iter( }; ctx_compute_opt = ggml_init(params); } - ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); res->set_inputs(&ubatch); diff --git a/src/llama-context.h b/src/llama-context.h index 86decc05fbc..8e71cdd1dc5 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -212,7 +212,7 @@ struct llama_context { // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. - uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch); + uint32_t output_reserve(int32_t n_outputs); void output_reorder(); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 64ea2fd00a9..2d55070cecc 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -2,7 +2,7 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-sampling.h" +#include "llama-sampler.h" #include #include diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 944c7e53bd2..bba747d37b5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -7,11 +7,14 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include #include #include +#include +#include #include void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { @@ -22,7 +25,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { } if (ubatch->embd) { - const int64_t n_embd = embd->ne[0]; + GGML_ASSERT(n_embd == embd->ne[0]); + const int64_t n_tokens = ubatch->n_tokens; ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd)); @@ -32,8 +36,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { bool res = true; - res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); - res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); return res; } @@ -405,6 +409,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) { + mctx->set_input_k_idxs(self_k_idxs, ubatch); + + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); +} + +bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= self_kq_mask->ne[0] == mctx->get_n_kv(); + res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + + return res; +} + void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); @@ -510,6 +535,120 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { return res; } +// TODO: Hybrid input classes are a bit redundant. +// Instead of creating a hybrid input, the graph can simply create 2 separate inputs. +// Refactoring is required in the future. +void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) { + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + +void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); + attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + + attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); + } + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv(); + res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + } + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { // set the inputs only for the active samplers in the current ubatch std::unordered_set active_samplers; @@ -563,7 +702,8 @@ int64_t llm_graph_result::get_max_nodes() const { } void llm_graph_result::reset() { - t_tokens = nullptr; + t_inp_tokens = nullptr; + t_inp_embd = nullptr; t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; @@ -876,6 +1016,26 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { + // Step35: HF clamps gate (after SiLU) and up before multiplication + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_shexp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_silu_clamped", il); + + tmp = ggml_clamp(ctx0, tmp, -limit, limit); + cb(tmp, "ffn_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, tmp); + cb(cur, "ffn_swiglu_limited", il); + type_gate = LLM_FFN_SEQ; + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, tmp); cb(cur, "ffn_swiglu", il); type_gate = LLM_FFN_SEQ; @@ -1178,6 +1338,25 @@ ggml_tensor * llm_graph_context::build_moe_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { + // Step35: per-layer clamp for routed experts + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_exp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_moe_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_moe_silu_clamped", il); + + up = ggml_clamp(ctx0, up, -limit, limit); + cb(up, "ffn_moe_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, up); + cb(cur, "ffn_moe_swiglu_limited", il); + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, up); cb(cur, "ffn_moe_swiglu", il); } else { @@ -1267,17 +1446,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { - const int64_t n_embd = hparams.n_embd_inp(); + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd = hparams.n_embd; + + assert(n_embd_inp >= n_embd); + + auto inp = std::make_unique(n_embd_inp); - auto inp = std::make_unique(); + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; - ggml_tensor * cur = nullptr; + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens); + cb(inp->embd, "inp_embd", -1); + ggml_set_input(inp->embd); - if (ubatch.token) { - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - //cb(inp->tokens, "inp_tokens", -1); - ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; + // select one of the 2 inputs, based on the batch contents + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 + std::array inps; + + // token embeddings path (ubatch.token != nullptr) + { + auto & cur = inps[0]; cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); @@ -1298,19 +1489,36 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { cur = ggml_add(ctx0, cur, inpL_delta); } - } else { - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); - ggml_set_input(inp->embd); + + if (n_embd_inp != n_embd) { + cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0); + } + } + + // vector embeddings path (ubatch.embd != nullptr) + { + auto & cur = inps[1]; cur = inp->embd; } + assert(ggml_are_same_shape (inps[0], inps[1])); + assert(ggml_are_same_stride(inps[0], inps[1])); + + ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1); + + if (n_embd_inp != n_embd) { + cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0); + } + + res->t_inp_embd = cur; + // For Granite architecture if (hparams.f_embedding_scale != 0.0f) { cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); } - cb(cur, "inp_embd", -1); + cb(cur, "embd", -1); res->add_input(std::move(inp)); @@ -1409,7 +1617,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const { //} const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp(); - const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; + const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc); ggml_set_input(cur); @@ -1716,9 +1924,11 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_b, ggml_tensor * sinks, - ggml_tensor * v_mla, + ggml_tensor * v_mla, // TODO: remove float kq_scale, int il) const { + GGML_ASSERT(v_mla == nullptr); + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -1761,6 +1971,93 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +static std::unique_ptr build_attn_inp_k_impl( + ggml_context * ctx0, + const llama_ubatch & ubatch, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx_cur) { + + auto inp = std::make_unique(hparams, cparams, mctx_cur); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); + + const auto n_kv = mctx_cur->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + return inp; +} + +llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + + return (llm_graph_input_attn_k *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx; + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, @@ -2056,6 +2353,58 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); + auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp)); +} + +llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + + // build iswa attention input + const auto * attn_ctx = mctx_cur->get_attn(); + + auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); + + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + { + const auto n_kv = attn_ctx->get_base()->get_n_kv(); + + inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp_attn->self_kq_mask); + + inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + } + + { + const auto n_kv = attn_ctx->get_swa()->get_n_kv(); + + inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp_attn->self_kq_mask_swa); + + inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + } + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_dense_out( ggml_tensor * dense_2, ggml_tensor * dense_3) const { @@ -2166,6 +2515,9 @@ void llm_graph_context::build_sampling() const { return; } + std::array outs; + outs[0] = res->t_logits; + auto inp_sampling = std::make_unique(samplers); res->add_input(std::move(inp_sampling)); @@ -2186,14 +2538,14 @@ void llm_graph_context::build_sampling() const { // add a dummy row of logits // this trick makes the graph static, regardless of which samplers are activated // this is important in order to minimize graph reallocations - // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550) ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); for (const auto & [seq_id, sampler] : samplers) { const auto it = seq_to_logit_row.find(seq_id); // inactive samplers always work on the first row - const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0; + const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0; + const int i_out = it != seq_to_logit_row.end() ? 1 : 0; ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); ggml_format_name(logits_seq, "logits_seq_%d", seq_id); @@ -2210,22 +2562,26 @@ void llm_graph_context::build_sampling() const { if (data.sampled != nullptr) { res->t_sampled[seq_id] = data.sampled; - ggml_build_forward_expand(gf, data.sampled); + outs[1] = data.sampled; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.probs != nullptr) { res->t_sampled_probs[seq_id] = data.probs; - ggml_build_forward_expand(gf, data.probs); + outs[1] = data.probs; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.logits != nullptr) { res->t_sampled_logits[seq_id] = data.logits; - ggml_build_forward_expand(gf, data.logits); + outs[1] = data.logits; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.candidates != nullptr) { res->t_candidates[seq_id] = data.candidates; - ggml_build_forward_expand(gf, data.candidates); + outs[1] = data.candidates; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } } diff --git a/src/llama-graph.h b/src/llama-graph.h index 503ffd695aa..1d69ff1a6fc 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -24,6 +24,7 @@ class llama_kv_cache_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; +class llama_memory_hybrid_iswa_context; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -105,7 +106,7 @@ using llm_graph_input_ptr = std::unique_ptr; class llm_graph_input_embd : public llm_graph_input_i { public: - llm_graph_input_embd() = default; + llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {} virtual ~llm_graph_input_embd() = default; void set_input(const llama_ubatch * ubatch) override; @@ -114,6 +115,8 @@ class llm_graph_input_embd : public llm_graph_input_i { ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; }; class llm_graph_input_pos : public llm_graph_input_i { @@ -314,6 +317,39 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { const llama_kv_cache_context * mctx; }; +// V-less input for the KV cache +// ref: https://github.com/ggml-org/llama.cpp/pull/19067 +class llm_graph_input_attn_k : public llm_graph_input_i { +public: + llm_graph_input_attn_k( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -397,6 +433,62 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_mem_hybrid_k : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_k( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_context * mctx; +}; + +class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_iswa( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_iswa_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_iswa_context * mctx; +}; + class llm_graph_input_sampling : public llm_graph_input_i { public: llm_graph_input_sampling(std::map samplers) : @@ -537,7 +629,7 @@ class llm_graph_result { virtual ~llm_graph_result() = default; - ggml_tensor * get_tokens() const { return t_tokens; } + ggml_tensor * get_inp_tokens() const { return t_inp_tokens; } ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } @@ -564,7 +656,8 @@ class llm_graph_result { void set_params(const llm_graph_params & params); // important graph nodes - ggml_tensor * t_tokens = nullptr; + ggml_tensor * t_inp_tokens = nullptr; + ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; @@ -801,6 +894,21 @@ struct llm_graph_context { ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove + float kq_scale, + int il) const; + + llm_graph_input_attn_k * build_attn_inp_k() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -880,6 +988,9 @@ struct llm_graph_context { // llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; + llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const; + + llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const; // // pooling diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 5f1df995f3a..756dda1a7ab 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -72,8 +72,8 @@ uint32_t llama_hparams::n_embd_inp() const { return n_embd_inp; } -uint32_t llama_hparams::get_n_embd_out() const { - return n_embd_out > 0 ? n_embd_out : n_embd; +uint32_t llama_hparams::n_embd_out() const { + return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd; } uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { @@ -139,6 +139,13 @@ uint32_t llama_hparams::n_embd_r() const { return n_embd * (n_shortconv_l_cache - 1); } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim + const uint32_t d_inner = n_head() * n_embd_head_kda; // 32 * 128 = 4096 + return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size @@ -151,6 +158,13 @@ uint32_t llama_hparams::n_embd_s() const { return n_embd * wkv_head_size; } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Full recurrent state: head_dim * head_dim * n_head + // h tensor shape for delta attention: [head_dim, head_dim, n_head] + return n_embd_head_kda * n_embd_head_kda * n_head(); // 128 * 128 * 32 = 524288 + } + // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -175,6 +189,21 @@ bool llama_hparams::is_swa(uint32_t il) const { GGML_ABORT("fatal error"); } +bool llama_hparams::is_mla() const { + assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) || + (n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0)); + + return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0; +} + +uint32_t llama_hparams::n_embd_head_k_mla() const { + return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k; +} + +uint32_t llama_hparams::n_embd_head_v_mla() const { + return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v; +} + bool llama_hparams::has_kv(uint32_t il) const { if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 2bf86655208..6c695bdbf66 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -53,8 +53,8 @@ struct llama_hparams { uint32_t n_rel_attn_bkts = 0; // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - uint32_t n_embd_head_k_mla = 0; - uint32_t n_embd_head_v_mla = 0; + uint32_t n_embd_head_k_mla_impl = 0; + uint32_t n_embd_head_v_mla_impl = 0; // for WavTokenizer struct llama_hparams_posnet posnet; @@ -137,6 +137,9 @@ struct llama_hparams { uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; + // for Kimi Linear KDA + uint32_t n_embd_head_kda = 0; + // for hybrid state space models std::array recurrent_layer_arr; @@ -164,7 +167,7 @@ struct llama_hparams { uint32_t n_cls_out = 1; // output embedding dimension (0 = use n_embd) - uint32_t n_embd_out = 0; + uint32_t n_embd_out_impl = 0; // llama4 smallthinker uint32_t n_moe_layer_step = 0; @@ -195,7 +198,7 @@ struct llama_hparams { uint32_t n_deepstack_layers = 0; // needed by encoder-decoder models (e.g. T5, FLAN-T5) - // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; uint32_t dec_n_layer = 0; @@ -203,6 +206,11 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + // Step35: optional per-layer clamps for (Swi)GLU + std::array swiglu_clamp_exp; // clamping for expert FFN + std::array swiglu_clamp_shexp; // shared expert + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) // dense_first means whether the pattern is start with a dense layer // note that if n_pattern == 0, all layers are SWA @@ -239,7 +247,7 @@ struct llama_hparams { uint32_t n_embd_inp() const; // dimension of output embeddings - uint32_t get_n_embd_out() const; + uint32_t n_embd_out() const; // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t il = 0) const; @@ -269,6 +277,12 @@ struct llama_hparams { bool is_swa(uint32_t il) const; + // note: currently only support if either all or none of the layers are MLA + bool is_mla() const; + + uint32_t n_embd_head_k_mla() const; + uint32_t n_embd_head_v_mla() const; + bool has_kv(uint32_t il) const; // number of layers for which has_kv() returns true diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index 3a34102a23d..26e2cb4270b 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -218,7 +218,9 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, } bool llama_kv_cache_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); + return kv_base->get_can_shift() && + kv_swa->get_can_shift() && + kv_base->get_size() == kv_swa->get_size(); } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index fd9f97d52e8..cb702b2a59f 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -97,6 +97,8 @@ llama_kv_cache::llama_kv_cache( __func__, hparams.n_embd_v_gqa_max()); } + const bool is_mla = hparams.is_mla(); + for (uint32_t il = 0; il < hparams.n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); @@ -130,18 +132,21 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + const bool has_k = true; + const bool has_v = !is_mla; + + ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr; + ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr; - ggml_format_name(k, "cache_k_l%d", il); - ggml_format_name(v, "cache_v_l%d", il); + has_k && ggml_format_name(k, "cache_k_l%d", il); + has_v && ggml_format_name(v, "cache_v_l%d", il); std::vector k_stream; std::vector v_stream; for (uint32_t s = 0; s < n_stream; ++s) { - k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); - v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr); + v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr); } map_layer_ids[il] = layers.size(); @@ -647,7 +652,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co const auto & layer = layers[il]; ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]); - ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + + if (layer.v_stream[ssrc]) { + ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + } } } } @@ -966,6 +974,10 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & } bool llama_kv_cache::get_can_shift() const { + // Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot. + if (model.arch == LLM_ARCH_STEP35) { + return false; + } return true; } @@ -1516,7 +1528,7 @@ size_t llama_kv_cache::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & layer : layers) { - size_v_bytes += ggml_nbytes(layer.v); + size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0; } return size_v_bytes; @@ -1594,6 +1606,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; + const auto & n_rot = hparams.n_rot; + + const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0; + auto inp = std::make_unique(this); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); @@ -1614,10 +1630,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, get_size()*n_stream, + n_rot, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); + ggml_row_size(layer.k->type, n_embd_nope)); ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); @@ -1760,8 +1776,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t io.write(&v_trans, sizeof(v_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (const auto & layer : layers) { @@ -1779,7 +1793,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Read each range of cells of k_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; @@ -1794,6 +1808,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1803,7 +1820,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); - // Read each range of cells of v_size length each into tmp_buf and write out + // Read each range of cells of v_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; @@ -1820,6 +1837,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1834,7 +1854,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size_el length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; @@ -2023,6 +2043,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; @@ -2064,6 +2087,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp new file mode 100644 index 00000000000..411769672af --- /dev/null +++ b/src/llama-memory-hybrid-iswa.cpp @@ -0,0 +1,275 @@ +#include "llama-memory-hybrid-iswa.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_memory_hybrid_iswa +// + +llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, + const layer_filter_cb & filter_recr) : + hparams(model.hparams), + mem_attn(new llama_kv_cache_iswa( + model, + type_k, + type_v, + v_trans, + offload, + swa_full, + unified, + kv_size, + n_seq_max, + n_ubatch, + n_pad, + filter_attn == nullptr ? + [&](int32_t il) { return !hparams.is_recurrent(il); } + : filter_attn, + nullptr + )), + mem_recr(new llama_memory_recurrent( + model, + type_r, + type_s, + offload, + rs_size, + n_seq_max, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recurrent(il); } + : filter_recr + )) {} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { + do { + balloc.split_reset(); + + // follow the recurrent pattern for creating the ubatch splits + std::vector ubatches; + + while (true) { + llama_ubatch ubatch; + + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + // prepare the recurrent batches first + if (!mem_recr->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined context at this point? + LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache (iswa version returns both base and swa slot infos) + auto sinfos_base = mem_attn->get_base()->prepare(ubatches); + if (sinfos_base.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches); + if (sinfos_swa.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); + } while(false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() { + return std::make_unique(this); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_memory_hybrid_iswa::get_can_shift() const { + // Shifting is trivially supported for recurrent + return mem_attn->get_can_shift(); +} + +void llama_memory_hybrid_iswa::clear(bool data) { + mem_attn->clear(data); + mem_recr->clear(data); +} + +bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. + if (!mem_recr->seq_rm(seq_id, p0, p1)) { + return false; + } + return mem_attn->seq_rm(seq_id, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1); + mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) { + mem_attn->seq_keep(seq_id); + mem_recr->seq_keep(seq_id); +} + +void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + mem_attn->seq_add(seq_id, p0, p1, shift); + mem_recr->seq_add(seq_id, p0, p1, shift); +} + +void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + mem_attn->seq_div(seq_id, p0, p1, d); + mem_recr->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id)); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); +} + +std::map llama_memory_hybrid_iswa::memory_breakdown() const { + std::map mb = mem_attn->memory_breakdown(); + for (const auto & buft_size : mem_recr->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + mem_attn->state_write(io, seq_id, flags); + mem_recr->state_write(io, seq_id, flags); +} + +void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + mem_attn->state_read(io, seq_id, flags); + mem_recr->state_read(io, seq_id, flags); +} + +llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const { + return mem_attn.get(); +} + +llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const { + return mem_recr.get(); +} + +// +// llama_memory_hybrid_iswa_context +// + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) : + ctx_attn(mem->get_mem_attn()->init_full()), + ctx_recr(mem->get_mem_recr()->init_full()), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize) : + ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), + ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. not sure if this is ideal + ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +bool llama_memory_hybrid_iswa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_attn->next(); + ctx_recr->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_memory_hybrid_iswa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_attn->apply(); + res = res & ctx_recr->apply(); + + return res; +} + +llama_memory_status llama_memory_hybrid_iswa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const { + return static_cast(ctx_attn.get()); +} + +const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const { + return static_cast(ctx_recr.get()); +} diff --git a/src/llama-memory-hybrid-iswa.h b/src/llama-memory-hybrid-iswa.h new file mode 100644 index 00000000000..807c8aac96c --- /dev/null +++ b/src/llama-memory-hybrid-iswa.h @@ -0,0 +1,140 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache-iswa.h" +#include "llama-memory.h" +#include "llama-memory-recurrent.h" + +#include +#include + +// +// llama_memory_hybrid_iswa +// + +// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to +// support models where each layer may be either attention-based (with SWA support) or recurrent + +class llama_memory_hybrid_iswa : public llama_memory_i { +public: + llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, + const layer_filter_cb & filter_recr = nullptr); + + ~llama_memory_hybrid_iswa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_memory_hybrid_iswa specific API + // + + llama_kv_cache_iswa * get_mem_attn() const; + llama_memory_recurrent * get_mem_recr() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr mem_attn; + const std::unique_ptr mem_recr; +}; + +class llama_memory_hybrid_iswa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // init failure + explicit llama_memory_hybrid_iswa_context(llama_memory_status status); + + // init full + explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem); + + // init update + explicit llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize); + + // init success + llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches); + + ~llama_memory_hybrid_iswa_context() = default; + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_memory_hybrid_iswa_context + // + + const llama_kv_cache_iswa_context * get_attn() const; + const llama_memory_recurrent_context * get_recr() const; + +private: + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + const llama_memory_context_ptr ctx_attn; + const llama_memory_context_ptr ctx_recr; + + const llama_memory_status status; +}; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 812bf253049..f0038036dcb 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -785,23 +785,21 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: io.write(&s_trans, sizeof(s_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell + // Iterate and write all the R tensors first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (r_l[il] == nullptr) continue; - // Write key type + // Write R tensor type const int32_t r_type_i = (int32_t)r_l[il]->type; io.write(&r_type_i, sizeof(r_type_i)); - // Write row size of key + // Write row size of R tensor const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Write each range of cells of r_size_row length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -814,15 +812,15 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (s_l[il] == nullptr) continue; - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); - // Write row size of value + // Write row size of S tensor const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Read each range of cells of s_size length each into tmp_buf and write out + // Write each range of S tensor rows for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -830,7 +828,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: } } } else { - // When v is transposed, we also need the element size and get the element ranges from each row + // When S tensor is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) @@ -838,7 +836,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_s = hparams.n_embd_s(); - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); @@ -851,7 +849,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 383b8dc7618..1501e392ca8 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -541,15 +541,15 @@ llama_model_loader::llama_model_loader( if (use_mmap && use_direct_io) { if (files.back()->has_direct_io()) { - // Disable mmap, as DirectIO is available - use_mmap = false; LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + use_mmap = false; } else { - // Disable DirectIO and reopen file using std::fopen for mmap + LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); use_direct_io = false; + + // reopen file using std::fopen for mmap files.pop_back(); files.emplace_back(new llama_file(fname.c_str(), "rb", false)); - LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); } } diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp index ae27c71ce23..36e353074e0 100644 --- a/src/llama-model-saver.cpp +++ b/src/llama-model-saver.cpp @@ -146,8 +146,8 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - if (hparams.n_embd_out > 0) { - add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out); + if (hparams.n_embd_out_impl > 0) { + add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl); } add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 94c47dc2480..8fc61aee372 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8,6 +8,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include "ggml-cpp.h" @@ -124,10 +125,12 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_48B_A3B: return "48B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; @@ -511,7 +514,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out, false); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); @@ -558,6 +561,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); + std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -1696,15 +1701,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); @@ -1713,7 +1719,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + // GLM 4.7 Lite + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } else { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } } if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { @@ -1730,6 +1741,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; + case 47: type = LLM_TYPE_30B_A3B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; default: type = LLM_TYPE_UNKNOWN; @@ -2400,6 +2412,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN3_5: + case LLM_ARCH_QWEN3_5_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); + } + } break; case LLM_ARCH_MISTRAL3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2442,6 +2473,66 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_KIMI_LINEAR: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_STEP35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + // MoE + SWA parameters + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Step35 uses sigmoid gating by default (if not set in GGUF) + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + + switch (hparams.n_layer) { + case 45: type = LLM_TYPE_196B_A11B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -4903,14 +4994,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); - - const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; - const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); const int64_t n_embd_head_qk_rope = hparams.n_rot; const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; @@ -4935,13 +5023,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); } layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); } else { @@ -6591,7 +6679,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // for LFM2-ColBert-350M - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_SMALLTHINKER: { @@ -6747,6 +6835,141 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); } } break; + case LLM_ARCH_KIMI_LINEAR: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; + const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; + const int64_t ssm_d_conv = hparams.ssm_d_conv; + + // Try loading KDA specific tensors (using SSM_ prefix) + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + } + + if (layer.ssm_q_conv) { + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a) { + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + + // o_proj + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED); + if (!layer.wkv_b) { // MLA KV cache enabled + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + } + } break; case LLM_ARCH_COGVLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6890,6 +7113,129 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); + } + } break; + case LLM_ARCH_QWEN3_5: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + const int64_t ba_dim = n_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Full attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, key_dim * 2 + value_dim * 2 }, TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + // Dense FFN for all layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + } + } break; + case LLM_ARCH_QWEN3_5_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + const int64_t ba_dim = n_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Full attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, key_dim * 2 + value_dim * 2 }, TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + // MoE FFN + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + // Shared experts layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); @@ -6935,6 +7281,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_STEP35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor + // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. + uint32_t n_rot_max = 0; + for (int i = 0; i < n_layer; ++i) { + n_rot_max = std::max(n_rot_max, hparams.n_rot); + } + if (n_rot_max == 0) { + n_rot_max = n_rot; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + // optional rope factors (llama3) / longrope tensors + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + + // head-wise attention gate (Step35 self_attn.g_proj) + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // dense MLP (leading dense blocks) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + // shared expert MLP + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_MAINCODER: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7275,6 +7687,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN3_5 || + arch == LLM_ARCH_QWEN3_5_MOE || arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); @@ -7310,8 +7724,8 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); @@ -7523,23 +7937,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } - res = new llama_memory_hybrid( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ 1, - /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + // Use hybrid-iswa for hybrid models with SWA + res = new llama_memory_hybrid_iswa( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_swa_full */ params.swa_full, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, + /* recurrent_type_r */ GGML_TYPE_F32, + /* recurrent_type_s */ GGML_TYPE_F32, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } else { + res = new llama_memory_hybrid( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; @@ -8052,6 +8487,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_QWEN3_5: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_QWEN3_5_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_MISTRAL3: { llm = std::make_unique(*this, params); @@ -8060,6 +8503,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_KIMI_LINEAR: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_STEP35: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -8099,7 +8550,7 @@ llama_model_params llama_model_default_params() { /*.kv_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, - /*.use_direct_io =*/ true, + /*.use_direct_io =*/ false, /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, @@ -8135,7 +8586,7 @@ int32_t llama_model_n_embd_inp(const llama_model * model) { } int32_t llama_model_n_embd_out(const llama_model * model) { - return model->hparams.get_n_embd_out(); + return model->hparams.n_embd_out(); } int32_t llama_model_n_layer(const llama_model * model) { @@ -8209,6 +8660,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_KIMI_LINEAR: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -8303,7 +8755,10 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PANGU_EMBED: case LLM_ARCH_AFMOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_QWEN3_5: + case LLM_ARCH_QWEN3_5_MOE: case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index d1de16e3f28..7b580043b33 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -118,10 +118,12 @@ enum llm_type { LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, + LLM_TYPE_48B_A3B, // Kimi Linear LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big @@ -411,6 +413,18 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; + // Kimi Linear KDA (using ssm_ prefix for consistency) + // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias + struct ggml_tensor * ssm_q_conv = nullptr; + struct ggml_tensor * ssm_k_conv = nullptr; + struct ggml_tensor * ssm_v_conv = nullptr; + struct ggml_tensor * ssm_f_a = nullptr; + struct ggml_tensor * ssm_f_b = nullptr; + struct ggml_tensor * ssm_beta = nullptr; + struct ggml_tensor * ssm_g_a = nullptr; + struct ggml_tensor * ssm_g_b = nullptr; + struct ggml_tensor * ssm_o_norm = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 048d65a75c2..a7891647c3d 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -422,57 +422,6 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t ++qs.i_ffn_up; } - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // IK: let's remove this, else Q2_K is almost the same as Q3_K_S - //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // This can be used to reduce the size of the Q5_K_S model. - // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - { - const int64_t nx = tensor->ne[0]; - const int64_t ny = tensor->ne[1]; - const int64_t qk_k = ggml_blck_size(new_type); - - if (nx % qk_k != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } - } - - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); - } - if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { - new_type = GGML_TYPE_F16; - } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); - ++qs.n_fallback; - } - return new_type; } @@ -596,7 +545,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ true, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -838,9 +787,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - // do not quantize Mamba's small yet 2D weights + // do not quantize Mamba /Kimi's small conv1d weights // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + quantize &= name.find("ssm_conv1d") == std::string::npos; quantize &= name.find("shortconv.conv.weight") == std::string::npos; // do not quantize RWKV's small yet 2D weights @@ -875,21 +824,69 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. if (!params->pure && ggml_is_quantized(default_type)) { - int fallback = qs.n_fallback; - new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type, and the tensor geometry will not require fallback quantisation - if (params->tensor_types && qs.n_fallback - fallback == 0) { + // if the user provided tensor types - use those + bool manual = false; + if (params->tensor_types) { const std::vector & tensor_types = *static_cast *>(params->tensor_types); const std::string tensor_name(tensor->name); for (const auto & [tname, qtype] : tensor_types) { if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { - LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); + LLAMA_LOG_WARN("(manual override: %s -> %s) ", ggml_type_name(new_type), ggml_type_name(qtype)); new_type = qtype; // if two or more types are specified for the same tensor, the last match wins + manual = true; + break; } } } } + + // if not manual - use the standard logic for choosing the quantization type based on the selected mixture + if (!manual) { + new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + } + + // incompatible tensor shapes are handled here - fallback to a compatible type + { + bool convert_incompatible_tensor = false; + + const int64_t nx = tensor->ne[0]; + const int64_t ny = tensor->ne[1]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (nx % qk_k != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } else { + ++qs.n_k_quantized; + } + + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { + new_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + ++qs.n_fallback; + } + } } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; diff --git a/src/llama-sampling.cpp b/src/llama-sampler.cpp similarity index 98% rename from src/llama-sampling.cpp rename to src/llama-sampler.cpp index 5dde513065b..9bbc5dbde24 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampler.cpp @@ -1,4 +1,4 @@ -#include "llama-sampling.h" +#include "llama-sampler.h" #include "llama-impl.h" #include "llama-vocab.h" @@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend { std::mt19937 rng; - // backend input - struct ggml_tensor * inp_uniform; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; + ggml_tensor * inp_uniform; }; static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { @@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init( ggml_backend_buffer_type_t buft) { auto * sctx = (llama_sampler_dist *) smpl->ctx; - // allocate inputs - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - // Create the uniform random scalar input tensor. This will be set by - // llama_sampler_dist_backend_set_input after this graph is built. - sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); - ggml_set_name (sctx->inp_uniform, "uniform"); - ggml_set_input(sctx->inp_uniform); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - } - const bool res = llama_sampler_backend_support(smpl, buft); sctx->init(res); - if (!res) { - sctx->inp_ctx.reset(nullptr); - sctx->inp_buf.reset(nullptr); - } - return res; } @@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply( struct ggml_cgraph * gf, struct llama_sampler_data * data) { GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); ggml_set_name(probs, "dist_probs"); @@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply( static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); // We sample in double precision and cast to float to match rnd numbers of @@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { /* .seed_cur = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), /* .inp_uniform = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } @@ -3461,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend { struct ggml_tensor * inp_logit_bias; struct ggml_tensor * inp_logit_idxs; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; }; static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { @@ -3526,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply( return; } + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); @@ -3562,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm static bool llama_sampler_logit_bias_backend_init( struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; sctx->init(true); @@ -3570,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init( return true; } - ggml_init_params params = { - /*.mem_size =*/ 2*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - const size_t n = sctx->logit_bias.size(); - - sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); - ggml_set_name(sctx->inp_logit_bias, "logit_bias"); - ggml_set_input(sctx->inp_logit_bias); - - sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); - ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); - ggml_set_input(sctx->inp_logit_idxs); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - return true; } @@ -3628,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias( /* .to_search = */ {}, /* .inp_logit_bias = */ nullptr, /* .inp_logit_idxs = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } diff --git a/src/llama-sampling.h b/src/llama-sampler.h similarity index 92% rename from src/llama-sampling.h rename to src/llama-sampler.h index 6a963c0bb73..b9bfc20d251 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampler.h @@ -1,7 +1,5 @@ #pragma once -// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? - #include "llama.h" #include diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a23950d007c..6d6bdfa090c 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -90,7 +90,7 @@ static_assert(std::is_trivially_copyable::value, "llm_symbol is not // // SPM tokenizer // original implementation: -// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 +// https://github.com/ggml-org/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 // struct llm_bigram_spm { @@ -285,7 +285,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { // original regex from tokenizer.json //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", - // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + // adapted: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2080233989 "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; @@ -1752,26 +1752,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + // Kimi-K2 uses custom tokenization without traditional BPE merges + const bool is_kimi_k2 = (tokenizer_pre == "kimi-k2"); + if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } + if (!is_kimi_k2) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + // Kimi-K2 doesn't need merges, skip + LLAMA_LOG_INFO("%s: Kimi-K2 tokenizer detected, skipping BPE merges\n", __func__); + } else { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + std::string first; + std::string second; - std::string first; - std::string second; + const size_t pos = word.find(' ', 1); - const size_t pos = word.find(' ', 1); + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); + bpe_ranks.emplace(std::make_pair(first, second), i); } - - bpe_ranks.emplace(std::make_pair(first, second), i); } // default special tokens @@ -2226,6 +2233,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" + || t.first == "[EOT]" // Kimi-K2 || t.first == "<|end▁of▁sentence|>" // DeepSeek || t.first == "" // smoldocling ) { @@ -2262,6 +2270,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "
"
                         || t.first == "▁
"          // CodeLlama
                         || t.first == "<|code_prefix|>" // GLM-4.5
+                        || t.first == "<|prefix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_pre_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2282,6 +2291,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
                         || t.first == "<|code_suffix|>" // GLM-4.5
+                        || t.first == "<|suffix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_suf_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2302,6 +2312,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
                         || t.first == "<|code_middle|>" // GLM-4.5
+                        || t.first == "<|middle|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_mid_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2319,6 +2330,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == ""   // Granite
                         || t.first == ""
+                        || t.first == "[PAD]" // Kimi-K2
                         ) {
                     special_fim_pad_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2390,7 +2402,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
         // maintain a list of tokens that cause end-of-generation
         // this is currently determined based on the token text, which is obviously not ideal
-        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        // ref: https://github.com/ggml-org/llama.cpp/issues/9606
         special_eog_ids.clear();
 
         if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
@@ -2421,6 +2433,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|eom_id|>"
                     || t.first == ""
                     || t.first == "_"
+                    || t.first == "[EOT]" // Kimi-K2
+                    || t.first == "[EOS]" // Kimi-K2
                     || t.first == "<|end_of_text|>"
                     || t.first == "" // smoldocling
                ) {
@@ -3079,7 +3093,7 @@ std::vector llama_vocab::impl::tokenize(
 }
 
 int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
-    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    // ref: https://github.com/ggml-org/llama.cpp/pull/7587#discussion_r1620983843
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
     const llama_token_attr attr = token_get_attr(token);
     if (!special && (attr & attr_special)) {
diff --git a/src/llama.cpp b/src/llama.cpp
index f1096d960e1..6da90d6f1f8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -311,8 +311,12 @@ static void llama_params_fit_impl(
                             __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
                     }
                 } else {
-                    LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
-                        __func__, hp_nct, n_ctx_min);
+                    if (n_ctx_min == UINT32_MAX) {
+                        LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct);
+                    } else {
+                        LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
+                            __func__, hp_nct, n_ctx_min);
+                    }
                 }
             } else {
                 LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx);
@@ -1091,25 +1095,55 @@ int32_t llama_chat_apply_template(
 // model split
 //
 
-int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
+int32_t llama_split_path(
+    char * split_path,
+    size_t maxlen,
+    const char * path_prefix,
+    int32_t split_no,
+    int32_t split_count) {
+
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
-    if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
-        return strlen(split_path);
+
+    const int written = snprintf(
+        split_path,
+        maxlen,
+        SPLIT_PATH_FORMAT,
+        path_prefix,
+        split_no + 1,
+        split_count
+    );
+
+    if (written < 0 || (size_t) written >= maxlen) {
+        return 0;
     }
-    return 0;
+
+    return (int32_t) written;
 }
 
-int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
-    std::string str_split_path(split_path);
+int32_t llama_split_prefix(
+    char * split_prefix,
+    size_t maxlen,
+    const char * split_path,
+    int32_t split_no,
+    int32_t split_count) {
+
+    const std::string str_split_path(split_path);
+
     char postfix[32];
-    snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
-    std::string str_postfix(postfix);
-
-    // check if split_prefix ends with postfix
-    int size_prefix = str_split_path.size() - str_postfix.size();
-    if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
-        return size_prefix;
+    snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count);
+
+    const std::string str_postfix(postfix);
+    if (str_split_path.size() <= str_postfix.size()) {
+        return 0;
+    }
+
+    const size_t size_prefix = str_split_path.size() - str_postfix.size();
+
+    if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) {
+        const size_t copy_len = std::min(size_prefix + 1, maxlen);
+        snprintf(split_prefix, copy_len, "%s", split_path);
+
+        return (int32_t) size_prefix;
     }
 
     return 0;
diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp
index ca63a62ad1b..987f449934c 100644
--- a/src/models/deepseek2.cpp
+++ b/src/models/deepseek2.cpp
@@ -2,14 +2,11 @@
 
 llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
-    bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
-
-    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+    const bool is_mla = hparams.is_mla();
 
     // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
-    const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
-    const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
+    const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
 
     const int64_t n_embd_head_qk_rope = hparams.n_rot;
     const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
@@ -17,7 +14,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation.
     // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
 
     // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
@@ -43,7 +40,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
-    auto * inp_attn = build_attn_inp_kv();
+    auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr;
+    auto * inp_attn_k  =  is_mla ? build_attn_inp_k()  : nullptr;
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -57,6 +55,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
         // self_attention
         {
             ggml_tensor * q = NULL;
+
+            const bool is_lite = model.layers[il].wq;
+
             if (!is_lite) {
                 q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
                 cb(q, "q", il);
@@ -124,14 +125,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
 
                 // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
                 // note: rope must go first for in-place context shifting in build_rope_shift()
-                ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+                ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
                 cb(Qcur, "Qcur", il);
 
                 kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
                 cb(kv_cmpr, "kv_cmpr_reshape", il);
 
                 // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
-                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
                 cb(Kcur, "Kcur", il);
 
                 // {kv_lora_rank, 1, n_tokens}
@@ -145,7 +146,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 }
 
                 // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
-                cur = build_attn(inp_attn,
+                cur = build_attn(inp_attn_k,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
             } else {
@@ -169,11 +170,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 Vcur = ggml_cont(ctx0, Vcur);
                 cb(Vcur, "Vcur_cont", il);
 
-                // note: rope must go first for in-place context shifting in build_rope_shift()
-                ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
+                ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
                 cb(Qcur, "Qcur", il);
 
-                ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(Kcur, "Kcur", il);
 
                 if (inp_attn_scale) {
@@ -183,7 +183,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 }
 
                 // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
-                cur = build_attn(inp_attn,
+                cur = build_attn(inp_attn_kv,
                             model.layers[il].wo, NULL,
                             Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
             }
diff --git a/src/models/delta.cpp b/src/models/delta.cpp
new file mode 100644
index 00000000000..d1d9837d092
--- /dev/null
+++ b/src/models/delta.cpp
@@ -0,0 +1,618 @@
+#include "models.h"
+#include "ggml.h"
+#include 
+#include 
+#include 
+
+llm_graph_context_delta::llm_graph_context_delta(const llm_graph_params & params) : llm_graph_context_mamba(params) {}
+
+/**
+ * Unified Delta Net implementation supporting both GDA and KDA modes.
+ *
+ * GDA (Gated Delta Attention): g has shape [H, T, B] in GGML (PyTorch: [B, T, H])
+ *   - Per-head gating, broadcasts over K dimension
+ *
+ * KDA (Key-wise Delta Attention): g has shape [K, H, T, B] in GGML (PyTorch: [B, T, H, K])
+ *   - Per-key gating
+ *
+ * The mode is auto-detected based on g's dimensionality.
+ *
+ * Tensor dimension convention:
+ *   GGML: ne[0] is innermost (fastest varying), ne[3] is outermost
+ *   PyTorch: dim 0 is outermost, dim -1 is innermost
+ *   So GGML [A, B, C, D] corresponds to PyTorch [D, C, B, A]
+ */
+
+// Helper to get a slice along dimension 2 (n_chunks dimension)
+static ggml_tensor * get_slice_2d(ggml_context * ctx, ggml_tensor * t, int64_t chunk) {
+    return ggml_view_4d(ctx, t,
+        t->ne[0], t->ne[1], 1, t->ne[3],
+        t->nb[1], t->nb[2], t->nb[3],
+        chunk * t->nb[2]);
+}
+
+/**
+ * Unified chunked Delta Net implementation.
+ *
+ * Input tensor format matches qwen3next conventions:
+ * @param q         Query tensor [S_k, H_k, n_tokens, n_seqs]
+ * @param k         Key tensor [S_k, H_k, n_tokens, n_seqs]
+ * @param v         Value tensor [S_v, H_v, n_tokens, n_seqs]
+ * @param g         Gate tensor:
+ *                    GDA: [H_v, n_tokens, n_seqs]
+ *                    KDA: [S_k, H_v, n_tokens, n_seqs]
+ * @param beta      Beta tensor [H_v, 1, n_tokens, n_seqs]
+ * @param state     State tensor [S_v, S_v * H_v, 1, n_seqs]
+ * @param causal_mask   Lower triangular mask [chunk_size, chunk_size]
+ * @param identity      Identity matrix [chunk_size, chunk_size]
+ * @param diag_mask     Diagonal mask [chunk_size, chunk_size]
+ * @param il            Layer index (for debugging callbacks)
+ * @param chunk_size    Chunk size for chunked processing
+ * @param eps_norm      Epsilon for L2 normalization
+ *
+ * @return Pair of (output_tokens, new_state)
+ */
+std::pair llm_graph_context_delta::build_delta_net_unified_chunking(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state_reshaped,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il,
+        int64_t       chunk_size,
+        float         eps_norm) {
+
+    // Input format: [S, H, n_tokens, n_seqs] (matching qwen3next convention)
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    // Detect KDA vs GDA based on g's shape
+    // GDA: g has shape [H_v, n_tokens, n_seqs]
+    // KDA: g has shape [S_k, H_v, n_tokens, n_seqs] (4D with ne[0]=S_k)
+    const bool is_kda = (g->ne[0] == S_k && g->ne[1] == H_v);
+
+    // Validate tensor shapes
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(state_reshaped->ne[0] == S_v && state_reshaped->ne[1] == S_v && state_reshaped->ne[2] == H_v && state_reshaped->ne[3] == n_seqs);
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(H_k == H_v);
+
+    if (is_kda) {
+        // KDA: g shape [S_k, H_v, n_tokens, n_seqs]
+        GGML_ASSERT(g->ne[0] == S_k && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    } else {
+        // GDA: g shape [H_v, n_tokens, n_seqs]
+        GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    }
+
+    // L2 normalize q and k
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);
+
+    const float scale = 1.0f / sqrtf((float)S_v);
+    q = ggml_scale(ctx0, q, scale);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(g, "g_in", il);
+
+    // Permute tensors to working format [S, n_tokens, H, n_seqs]
+    // Input: [S, H, n_tokens, n_seqs] -> permute(0, 2, 1, 3) -> [S, n_tokens, H, n_seqs]
+    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    if (is_kda) {
+        g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    } else {
+        g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
+    }
+    beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
+
+    cb(q, "q_perm", il);
+    cb(k, "k_perm", il);
+    cb(v, "v_perm", il);
+    cb(beta, "beta_perm", il);
+    cb(g, "g_perm", il);
+    cb(state_reshaped, "state_in", il);
+
+    // Padding for chunk processing
+    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
+    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
+
+
+    cb(q, "q_pad", il);
+    cb(k, "k_pad", il);
+    cb(v, "v_pad", il);
+    cb(beta, "beta_pad", il);
+    cb(g, "g_pad", il);
+
+    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
+
+    cb(v_beta, "v_beta", il);
+    cb(k_beta, "k_beta", il);
+
+    // Reshape to chunks
+    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
+    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
+    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
+    beta   = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
+
+    // Reshape g for chunks
+    ggml_tensor * g_cumsum;
+    ggml_tensor * g_cumsum_t;
+    if (is_kda) {
+        // KDA: g [S_k, n_tokens+pad, H_k, n_seqs] -> [S_k, chunk_size, n_chunks, H_k * n_seqs]
+        g = ggml_reshape_4d(ctx0, g, S_k, chunk_size, n_chunks, H_k * n_seqs);
+        // Cumsum along chunk_size dimension (ne[1])
+        // GGML cumsum operates on ne[0], so we need to transpose, cumsum, transpose back
+        g = ggml_cont(ctx0, ggml_transpose(ctx0, g));  // [chunk_size, S_k, n_chunks, H_k * n_seqs]
+        g_cumsum_t = ggml_cumsum(ctx0, g);
+        g_cumsum = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum_t));  // [S_k, chunk_size, n_chunks, H_k * n_seqs]
+    } else {
+        // GDA: g [n_tokens+pad, 1, H_k, n_seqs] -> [chunk_size, 1, n_chunks, H_k * n_seqs]
+        g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
+        g_cumsum = ggml_cumsum(ctx0, g);
+        g_cumsum_t = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_k * n_seqs);
+    }
+
+    cb(g_cumsum, "g_cumsum", il);
+
+    // Build attention matrix A for the WY representation solve
+    // For GDA: A[j,i] = sum_k(k[j,k] * exp(g[j] - g[i]) * k[i,k]) = (k @ k^T) * exp(g[j] - g[i])
+    // For KDA: A[j,i] = sum_k(k_beta[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
+    // KDA uses decay mask with S_k packed into batch to compute exp(g[j,k] - g[i,k]) per-key
+
+    ggml_tensor * k_decay;
+    ggml_tensor * decay_mask = nullptr;
+    ggml_tensor * g_exp_pos = nullptr;
+
+    if (is_kda) {
+        // KDA: Use decay mask with S_k in leading dimension for efficient mul_mat reduction
+        // A[j,i] = sum_k(k_beta[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
+        // By putting S_k in dim 0, mul_mat implicitly sums over it
+
+        const int64_t CHB = n_chunks * H_k * n_seqs;
+
+        // g_cumsum_t is [chunk_size, S_k, n_chunks, H_k * n_seqs]
+        // Reshape to [chunk_size, S_k, CHB] then build decay mask
+        ggml_tensor * gcs = ggml_reshape_3d(ctx0, g_cumsum_t, chunk_size, S_k, CHB);
+        ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, gcs, chunk_size, 1, S_k, CHB);
+        ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, gcs, 1, chunk_size, S_k, CHB);
+
+        // Build decay mask: [chunk_size, chunk_size, S_k, CHB]
+        ggml_tensor * gcs_j_bc = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, S_k, CHB);
+        decay_mask = ggml_sub(ctx0, gcs_j_bc, gcs_i);
+
+        cb(decay_mask, "decay_mask_kda", il);
+
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+
+        // Permute to [S_k, chunk_size_j, chunk_size_i, CHB] for mul_mat reduction over S_k
+        decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+
+        // Reshape k and k_beta for broadcasting with decay_mask
+        // k_i: indexed at position i (dim 2 of decay_mask)
+        // k_beta_j: indexed at position j (dim 1 of decay_mask)
+        ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
+        ggml_tensor * k_beta_j = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, 1, CHB);
+
+        // decay_k_beta_j[s,j,i,b] = decay[s,j,i,b] * k_beta[s,j,b]
+        ggml_tensor * decay_k_beta_j = ggml_mul(ctx0, decay_mask, k_beta_j);
+
+        // mul_mat sums over S_k: result[j,1,i,CHB] = sum_s decay_k_beta_j[s,j,i,b] * k_i[s,1,i,b]
+        k_decay = ggml_mul_mat(ctx0, decay_k_beta_j, k_i);
+        k_decay = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, k_decay, chunk_size, chunk_size, n_chunks, H_k * n_seqs)));
+
+        // g_exp_pos is still needed for later (kbeta_gexp, etc.)
+        g_exp_pos = ggml_exp(ctx0, g_cumsum);
+    } else {
+        // GDA: Use decay mask approach (g broadcasts over K dimension)
+        // g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs]
+        ggml_tensor * gcs_i = g_cumsum;
+        ggml_tensor * gcs_j = g_cumsum_t;
+        g_exp_pos = ggml_exp(ctx0, g_cumsum_t);
+        ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
+        decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
+
+        cb(decay_mask, "decay_mask", il);
+
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+
+        ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
+        k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
+    }
+
+    ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
+
+    cb(attn, "attn_pre_solve", il);
+
+    // Solve triangular system: (I + L) @ X = I, where L is strictly lower triangular
+    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
+    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
+    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
+    attn = ggml_mul(ctx0, lin_solve, causal_mask);
+    attn = ggml_add(ctx0, attn, identity);
+
+    cb(attn, "attn_solved", il);
+
+    // Compute u = A @ v and w = A @ (g.exp() * k)
+    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
+
+    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, g_exp_pos);
+    cb(kbeta_gexp, "kbeta_gexp", il);
+
+    ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0,
+        ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
+    cb(k_cumdecay, "k_cumdecay", il);
+
+    // Attention scores q @ k^T with decay
+    // For GDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j] - g[i]) * k[i,k])
+    // For KDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
+    ggml_tensor * attn_kq;
+    if (is_kda) {
+        // KDA: Same approach as k_decay - use decay_mask with S_k in leading dim
+        const int64_t CHB = n_chunks * H_k * n_seqs;
+
+        // Rebuild decay mask (same structure as k_decay)
+        ggml_tensor * gcs = ggml_reshape_3d(ctx0, g_cumsum_t, chunk_size, S_k, CHB);
+        ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, gcs, chunk_size, 1, S_k, CHB);
+        ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, gcs, 1, chunk_size, S_k, CHB);
+        ggml_tensor * gcs_j_bc = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, S_k, CHB);
+        ggml_tensor * decay_mask_kq = ggml_sub(ctx0, gcs_j_bc, gcs_i);
+
+        decay_mask_kq = ggml_mul(ctx0, decay_mask_kq, diag_mask);
+        decay_mask_kq = ggml_exp(ctx0, decay_mask_kq);
+        decay_mask_kq = ggml_mul(ctx0, decay_mask_kq, diag_mask);
+
+        // Permute to [S_k, chunk_size_j, chunk_size_i, CHB]
+        decay_mask_kq = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask_kq, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+
+        // q_j: indexed at position j, k_i: indexed at position i
+        ggml_tensor * q_j = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
+        ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
+
+        // decay_q_j[s,j,i,b] = decay[s,j,i,b] * q[s,j,b]
+        ggml_tensor * decay_q_j = ggml_mul(ctx0, decay_mask_kq, q_j);
+
+        // mul_mat sums over S_k
+        attn_kq = ggml_mul_mat(ctx0, decay_q_j, k_i);
+        attn_kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, attn_kq, chunk_size, chunk_size, n_chunks, H_k * n_seqs)));
+    } else {
+        // GDA: Use decay mask
+        attn_kq = ggml_mul_mat(ctx0, k, q);
+        attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
+        attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
+    }
+    cb(attn_kq, "attn_kq", il);
+
+    // Compute g_last and g_diff for state updates
+    ggml_tensor * g_last;
+    ggml_tensor * g_diff_exp;
+    ggml_tensor * g_last_exp;
+
+    if (is_kda) {
+        // KDA: g_cumsum [S_k, chunk_size, n_chunks, H_k * n_seqs]
+        // Get last element along chunk_size dimension (ne[1])
+        g_last = ggml_view_4d(ctx0, g_cumsum,
+            g_cumsum->ne[0], 1, g_cumsum->ne[2], g_cumsum->ne[3],
+            g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
+            (g_cumsum->ne[1] - 1) * g_cumsum->nb[1]);
+        g_last = ggml_cont(ctx0, g_last);
+        g_last_exp = ggml_exp(ctx0, g_last);
+
+        // g_diff = g_last - g_cumsum
+        ggml_tensor * g_last_broadcast = ggml_repeat_4d(ctx0, g_last,
+            g_cumsum->ne[0], g_cumsum->ne[1], g_cumsum->ne[2], g_cumsum->ne[3]);
+        ggml_tensor * g_diff = ggml_sub(ctx0, g_last_broadcast, g_cumsum);
+        g_diff_exp = ggml_exp(ctx0, g_diff);
+    } else {
+        // GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
+        g_last = ggml_view_4d(ctx0, g_cumsum,
+            1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
+            g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
+            (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
+        g_last = ggml_cont(ctx0, g_last);
+        g_last_exp = ggml_exp(ctx0, g_last);
+
+        ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
+        g_diff_exp = ggml_exp(ctx0, g_diff);
+    }
+
+    cb(g_last, "g_last", il);
+    cb(g_last_exp, "g_last_exp", il);
+
+    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
+    cb(key_gdiff, "key_gdiff", il);
+
+    // Process chunks
+    ggml_tensor * new_state = state_reshaped;
+    ggml_tensor * core_attn_out = nullptr;
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk);
+        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk);
+        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk);
+        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
+        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, g_exp_pos, chunk);
+
+        cb(attn_chunk, "attn_chunk", il);
+
+        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3),
+            S_v, S_v, 1, H_v * n_seqs);
+
+        // v_prime = k_cumdecay @ state
+        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
+        cb(v_prime, "v_prime_chunk", il);
+
+        // v_new = v - v_prime
+        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
+        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+        cb(v_new, "v_new_chunk", il);
+
+        // attn_inter = (q * g.exp()) @ state
+        ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
+        cb(attn_inter, "attn_inter_chunk", il);
+
+        // output = attn_inter + attn @ v_new
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
+        cb(v_attn, "v_attn_chunk", il);
+
+        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
+        cb(core_attn_out_chunk, "core_attn_out_chunk", il);
+
+        core_attn_out = core_attn_out == nullptr
+            ? core_attn_out_chunk
+            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
+
+        // State update: state = state * g_last_exp + key_gdiff^T @ v_new
+        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
+
+        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
+
+        if (is_kda) {
+            // KDA: g_last_exp [S_k, 1, n_chunks, H_k * n_seqs]
+            // State: [S_v, S_v, H_v, n_seqs]
+            // Need to reshape g_last_exp to broadcast correctly over V dimension only
+            gexp_last_chunk = ggml_reshape_4d(ctx0, gexp_last_chunk,
+                1, gexp_last_chunk->ne[0], H_v, n_seqs);  // [1, S_k, H_v, n_seqs]
+            // Transpose to [S_k, 1, H_v, n_seqs] then broadcast
+            gexp_last_chunk = ggml_cont(ctx0, ggml_permute(ctx0, gexp_last_chunk, 1, 0, 2, 3));
+        } else {
+            // GDA: g_last_exp [1, 1, n_chunks, H_k * n_seqs]
+            // Broadcasts over both K and V dimensions
+            gexp_last_chunk = ggml_reshape_4d(ctx0, gexp_last_chunk,
+                gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs);
+        }
+
+        new_state = ggml_add(ctx0,
+            ggml_mul(ctx0, new_state, gexp_last_chunk),
+            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
+    }
+
+    // Truncate padding and permute back
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+        S_v, n_tokens, H_v, n_seqs,
+        ggml_row_size(core_attn_out->type, S_v),
+        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
+        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+
+    cb(output_tokens, "output_tokens", il);
+
+    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+
+    return {output_tokens, new_state};
+}
+
+
+/**
+ * Unified autoregressive Delta Net implementation (single token processing).
+ *
+ * This implementation uses matrix multiplication instead of elementwise operations + summation,
+ * which is more efficient and mathematically equivalent. See inline comments for equivalences.
+ *
+ * Input tensor format matches qwen3next conventions:
+ * @param q         Query tensor [S_k, H_k, 1, n_seqs]
+ * @param k         Key tensor [S_k, H_k, 1, n_seqs]
+ * @param v         Value tensor [S_v, H_v, 1, n_seqs]
+ * @param g         Gate tensor:
+ *                    GDA: [H_v, 1, n_seqs]
+ *                    KDA: [S_k, H_v, 1, n_seqs]
+ * @param beta      Beta tensor [H_v, 1, 1, n_seqs]
+ * @param state     State tensor [S_v, S_v * H_v, 1, n_seqs]
+ * @param il        Layer index (for debugging callbacks)
+ * @param eps_norm  Epsilon for L2 normalization
+ *
+ * @return Pair of (output_tokens, new_state)
+ */
+std::pair llm_graph_context_delta::build_delta_net_unified_autoregressive(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        int           il,
+        float         eps_norm) {
+
+    // Input format: [S, H, n_tokens, n_seqs] (matching qwen3next convention)
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1);  // Autoregressive mode is for single token
+
+    // Detect KDA vs GDA based on g's shape
+    // GDA: g has shape [H_v, 1, n_seqs] or [H_v, n_tokens, n_seqs]
+    // KDA: g has shape [S_k, H_v, 1, n_seqs] or [S_k, H_v, n_tokens, n_seqs]
+    const bool is_kda = (g->ne[0] == S_k && g->ne[1] == H_v);
+
+    // Validate shapes
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(H_k == H_v);
+
+    if (is_kda) {
+        GGML_ASSERT(g->ne[0] == S_k && g->ne[1] == H_v);
+    } else {
+        GGML_ASSERT(g->ne[0] == H_v);
+    }
+
+    // L2 normalize q and k
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);
+
+    const float scale = 1.0f / sqrtf((float)S_v);
+    q = ggml_scale(ctx0, q, scale);
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(g, "g_in", il);
+
+    // Reshape g and beta for broadcasting
+    ggml_tensor * g_t;
+    ggml_tensor * beta_t;
+
+    if (is_kda) {
+        // KDA: g [S_k, H_v, 1, n_seqs] -> [S_k, 1, H_k, n_seqs]
+        // For state multiplication, need [1, S_k, H_v, n_seqs] to broadcast over V only
+        g_t = ggml_reshape_4d(ctx0, g, S_k, 1, H_k, n_seqs);
+    } else {
+        // GDA: g [H_v, 1, n_seqs] -> [1, 1, H_k, n_seqs]
+        // For state multiplication, broadcasts over both K and V
+        g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
+    }
+
+    beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
+
+    // Apply exponential to g_t
+    g_t = ggml_exp(ctx0, g_t);
+
+    // State decay: state = state * exp(g)
+    if (is_kda) {
+        // KDA: g_t [S_k, 1, H_k, n_seqs], state [S_v, S_v, H_v, n_seqs]
+        // Need to broadcast g_t over V dimension (ne[0] of state)
+        // Permute g_t to [1, S_k, H_k, n_seqs] for correct broadcasting
+        ggml_tensor * g_broadcast = ggml_cont(ctx0, ggml_permute(ctx0, g_t, 1, 0, 2, 3));
+        state = ggml_mul(ctx0, state, g_broadcast);
+    } else {
+        // GDA: g_t [1, 1, H_k, n_seqs] broadcasts over both dimensions
+        state = ggml_mul(ctx0, state, g_t);
+    }
+
+    // Equivalence to previous version:
+    // Previous: kv_mem = sum_k(state * k) using elementwise mult + sum_rows
+    // Current:  k_state = state_t @ k_t using matrix multiplication
+    // These are equivalent because: sum_k(A * B) = A @ B when dimensions align
+    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * k_t = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
+    ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k_t);
+
+    // v_diff = v - k_state (equivalent to v - kv_mem in previous version)
+    ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
+    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, k_state);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k_t, beta_t);
+
+    // Equivalence to previous version:
+    // Previous: state += k.unsqueeze(-1) * delta where delta = (v - kv_mem) * beta
+    // Current:  state += v_diff^T @ k_beta^T using matrix multiplication
+    // These are equivalent because: outer_product(k, v_diff * beta) = v_diff^T @ k^T
+    state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
+
+    // Equivalence to previous version:
+    // Previous: core_attn_out = sum_k(state * q) using elementwise mult + sum_rows
+    // Current:  core_attn_out = state_t @ q using matrix multiplication
+    // These are equivalent because: sum_k(A * B) = A @ B when dimensions align
+    q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
+    state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
+    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
+    cb(core_attn_out, "output_tokens", il);
+    cb(state, "new_state", il);
+
+    return {core_attn_out, state};
+}
+
+
+/**
+ * Main entry point that dispatches to chunked or autoregressive based on n_tokens.
+ *
+ * Input tensor format matches qwen3next conventions:
+ * @param q         Query tensor [S_k, H_k, n_tokens, n_seqs]
+ * @param k         Key tensor [S_k, H_k, n_tokens, n_seqs]
+ * @param v         Value tensor [S_v, H_v, n_tokens, n_seqs]
+ * @param g         Gate tensor (GDA: [H_v, n_tokens, n_seqs], KDA: [S_k, H_v, n_tokens, n_seqs])
+ * @param beta      Beta tensor [H_v, 1, n_tokens, n_seqs]
+ * @param state     State tensor [S_v, S_v * H_v, 1, n_seqs]
+ */
+std::pair llm_graph_context_delta::build_delta_net_unified(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il,
+        int64_t       chunk_size,
+        float         eps_norm) {
+
+    // Input format: [S, H, n_tokens, n_seqs] (matching qwen3next convention)
+    const int64_t n_tokens = q->ne[2];
+
+    if (n_tokens == 1) {
+        return build_delta_net_unified_autoregressive(
+            ctx0, q, k, v, g, beta, state, il, eps_norm);
+    }
+    return build_delta_net_unified_chunking(
+        ctx0, q, k, v, g, beta, state, causal_mask, identity, diag_mask,
+        il, chunk_size, eps_norm);
+}
diff --git a/src/models/gemma3n-iswa.cpp b/src/models/gemma3n-iswa.cpp
index 51acab14908..7db6d3bf4ec 100644
--- a/src/models/gemma3n-iswa.cpp
+++ b/src/models/gemma3n-iswa.cpp
@@ -245,12 +245,12 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
 // equivalent to get_per_layer_inputs() in python code
 // output shape: [n_embd_altup, n_layer, n_tokens]
 ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
-    auto inp = std::make_unique();
+    auto inp = std::make_unique(n_embd);
     ggml_tensor * inp_per_layer;
     if (ubatch.token) {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
         ggml_set_input(inp->tokens);
-        res->t_tokens = inp->tokens;
+        res->t_inp_tokens = inp->tokens;
         inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
         inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
         inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp
new file mode 100644
index 00000000000..d9ee6980751
--- /dev/null
+++ b/src/models/kimi-linear.cpp
@@ -0,0 +1,771 @@
+#include "models.h"
+
+#define CHUNK_SIZE 64
+
+// Causal Conv1d function for Q,K,V
+// When qkv is 0, it is Q, 1 is K, 2 is V
+static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
+    const int64_t d_inner = head_dim * n_head;
+    const int64_t conv_state_size = (d_conv - 1) * d_inner;
+    const int64_t n_embd_r_total = 3 * conv_state_size;  // Q + K + V
+
+    // conv_state_all is [n_embd_r_total, n_seqs], split into Q, K, V
+    // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs]
+    // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V
+    // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size
+    // View Q conv state: offset 0, size conv_state_size per seq
+    // conv_state_all is [n_embd_r_total, n_seqs] with memory layout:
+    //   state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V
+    // We want [d_conv-1, d_inner, n_seqs] view:
+    //   nb1 = (d_conv-1) * element_size (stride between channels)
+    //   nb2 = n_embd_r_total * element_size (stride between seqs)
+    ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
+        (d_conv - 1) * ggml_element_size(conv_state_all),  // nb1: stride between channels
+        n_embd_r_total * ggml_element_size(conv_state_all),  // nb2: stride between seqs
+        qkv * conv_state_size * ggml_element_size(conv_state_all));
+
+// Causal Conv1d function for Q,K,V
+// When qkv is 0, it is Q, 1 is K, 2 is V
+    // Step 1: Q, K, V projections -> [d_inner, n_tokens]
+    ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x);
+
+    // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs);
+
+    // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs}
+    ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0);
+
+    // Save last (d_conv-1) columns back to Q conv state
+    ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
+        conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]);
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0, last_conv_x,
+            ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs,
+                (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all))));
+    // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner]
+    // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv]
+    // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step]
+    // ggml_ssm_conv computes: c[conv_step + channel * d_conv]
+    // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner]
+    // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv
+    ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner);
+
+    // Apply conv1d
+    // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight);
+    // Reshape to 2D for bias add: {d_inner, n_tokens}
+    Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens);
+    Xcur = ggml_silu(ctx0, Xcur);
+
+    return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs);
+}
+
+llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context_mamba(params), model(model) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "model.embed_tokens", -1);
+
+    // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+    // So we don't need inp_pos
+
+    auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr;
+    auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr;
+    auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr();
+    auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
+    auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr;
+
+    // Output ids for selecting which tokens to output
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    ggml_tensor * chunked_causal_mask =
+        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
+                    GGML_TRI_TYPE_LOWER);
+
+    ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
+    ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
+
+    ggml_build_forward_expand(gf, chunked_causal_mask);
+    ggml_build_forward_expand(gf, chunked_identity);
+    ggml_build_forward_expand(gf, chunked_diag_mask);
+
+    // Kimi dimension constants
+    const int64_t n_head = hparams.n_head();
+    const int64_t head_dim = hparams.n_embd_head_kda;
+    const int64_t d_conv = hparams.ssm_d_conv;
+    const int64_t d_inner = n_head * head_dim;  // 32 * 128 = 4096
+    const int64_t n_seqs = ubatch.n_seqs;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    // Verify batch consistency for recurrent layers
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // MLA params
+    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
+    const int64_t kv_lora_rank = hparams.n_lora_kv;
+    // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
+    // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
+    const int64_t n_embd_head_qk_rope = hparams.n_rot;  // config.qk_rope_head_dim
+    const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;  // 192 - 64 = 128
+    // Attention scale for MLA
+    const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
+
+    for (int il = 0; il < n_layer; ++il) {
+        const auto & layer = model.layers[il];
+        ggml_tensor * inpSA = inpL;
+
+        // Attention Norm
+        cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // Check layer type by checking which tensors exist
+        // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor
+        bool is_kda = (layer.ssm_a != nullptr);
+        bool is_mla = (layer.wkv_a_mqa != nullptr);
+
+        if (is_kda) {
+            // === KDA Layer (Kimi Delta Attention) with Recurrent State ===
+            // Reference: vLLM kda.py
+            const auto * mctx_cur = inp_rs->mctx;
+            const auto kv_head = mctx_cur->get_head();
+
+            // Get conv states from r_l tensor (Q, K, V each have separate state)
+            ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+            cb(conv_states_all, "conv_states_all", il);
+            ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
+            ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+
+            // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
+            ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);
+            ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a);
+            cb(g1, "g1 f_b(f_a(cur))", il);
+            g1 = ggml_add(ctx0, g1, layer.ssm_dt_b);
+            g1 = ggml_softplus(ctx0, g1);
+            g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);
+
+            // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py
+            // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens]
+            ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1);
+            g1 = ggml_mul(ctx0, g1, A);
+            cb(g1, "kda_g1", il);
+
+            // Compute beta (mixing coefficient)
+            ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
+            beta = ggml_reshape_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs);
+            cb(beta, "kda_beta", il);
+
+            // Reshape for KDA recurrence
+            // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
+            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+            g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
+
+            // Get SSM state and compute KDA recurrence using ggml_kda_scan
+            ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+            ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
+            state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
+            // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
+            std::pair attn_out = n_seq_tokens == 1 ?
+                build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
+                build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
+
+            ggml_tensor * output = attn_out.first;
+            ggml_tensor * new_state = attn_out.second;
+            cb(output, "attn_output", il);
+            cb(new_state, "new_state", il);
+
+            // Update the recurrent states
+            ggml_build_forward_expand(gf,
+                                     ggml_cpy(ctx0, new_state,
+                                              ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                                                           kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+            // Output gating g2 = g_b(g_a(x))
+            ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+            ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
+            ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a);
+            cb(g2, "g2 g_b(g_a(cur_2d))", il);
+            g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs);
+
+            // Apply o_norm with sigmoid gating
+            // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
+            // Formula: output = RMSNorm(x) * sigmoid(g)
+            ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head,  n_seq_tokens * n_seqs);
+            ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il);
+            cb(normed, "kda_normed", il);
+            ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
+            ggml_tensor * gated = ggml_mul(ctx0, normed, gate);
+
+            // Output projection
+            gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens);
+            cur = ggml_mul_mat(ctx0, layer.wo, gated);
+            cb(cur, "kda_out", il);
+
+        } else if (is_mla) {
+            // === MLA Layer (Multi-head Latent Attention) without KV Cache ===
+            // Reference: vLLM mla.py
+            // Step 1: Q projection and reshape
+            // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
+            // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);
+
+            // Step 2: KV compression
+            // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
+            ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
+
+            // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:]
+            ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+            ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
+            // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM)
+            // k_pe is used directly without RoPE
+            // Normalize kv_c
+            kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
+
+            if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
+                // extract q_nope
+                ggml_tensor * q_nope =
+                    ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                                 ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_embd_head_qk_rope, n_head, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(
+                    ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                    ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd_head_qk_nope, n_tokens, n_head}
+                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                cb(q_nope, "q_nope_perm", il);
+
+                // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
+                ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
+                cb(q_nope_absorbed, "q_nope_absorbed", il);
+
+                // {kv_lora_rank, n_head, n_tokens}
+                q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
+                cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+                // note: rope must go first for in-place context shifting in build_rope_shift()
+                Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
+                cb(Qcur, "Qcur", il);
+
+                kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
+                cb(kv_cmpr, "kv_cmpr_reshape", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
+                cb(Kcur, "Kcur", il);
+
+                // {kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Vcur = kv_cmpr;
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            } else { // MLA KV cache disabled. Fall back to MHA KV cache.
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
+                cb(Qcur, "mla_Q", il);
+                // KV decompression: kv = kv_b_proj(kv_c_normed)
+                ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
+                const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
+
+                // Split kv into k_nope and v
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head), 0);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head),
+                    ggml_row_size(kv->type, n_embd_head_qk_nope));
+                Vcur = ggml_cont(ctx0, Vcur);
+                cb(Vcur, "mla_V", il);
+
+                // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
+                // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
+                // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
+                // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
+                ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
+                ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0);
+                cb(Kcur, "mla_K", il);
+
+                // Direct softmax attention (with MHA KV cache)
+                // Use build_attn with inp_attn for proper mask handling
+                cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            }
+        } else {
+            // Unknown layer type - this should not happen
+            GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors");
+        }
+
+        // On last layer, select only the output tokens
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // FFN Norm
+        cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        if ((uint32_t) il < hparams.n_layer_dense_lead) {
+            // Dense FFN layer
+            cur = build_ffn(cur,
+                layer.ffn_up, NULL, NULL,
+                layer.ffn_gate, NULL, NULL,
+                layer.ffn_down, NULL, NULL,
+                NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE layer
+            // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                layer.ffn_gate_inp,
+                layer.ffn_up_exps,
+                layer.ffn_gate_exps,
+                layer.ffn_down_exps,
+                layer.ffn_exp_probs_b,
+                hparams.n_expert,
+                hparams.n_expert_used,
+                LLM_FFN_SILU, true,
+                true, hparams.expert_weights_scale,
+                (llama_expert_gating_func_type) hparams.expert_gating_func,
+                il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // Shared expert
+            {
+                ggml_tensor * ffn_shexp = build_ffn(cur,
+                        layer.ffn_up_shexp, NULL, NULL,
+                        layer.ffn_gate_shexp, NULL, NULL,
+                        layer.ffn_down_shexp, NULL, NULL,
+                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+        // Residual
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final Norm
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // Output
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+/*
+    This is a ggml implementation of the naive_chunk_kda function of
+    https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
+*/
+std::pair llm_build_kimi_linear::build_kda_chunking(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * gk,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il) {
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
+
+    // TODO: can this ever be false?
+    const bool use_qk_l2norm = true;
+
+    if (use_qk_l2norm) {
+        const float eps_norm = hparams.f_norm_rms_eps;
+
+        q = ggml_l2_norm(ctx0, q, eps_norm);
+        k = ggml_l2_norm(ctx0, k, eps_norm);
+    }
+
+    const float scale = 1.0f / sqrtf(S_v);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(gk, "gk_in", il);
+
+    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+
+    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
+    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
+
+    cb(q, "q_perm", il);
+    cb(k, "k_perm", il);
+    cb(v, "v_perm", il);
+    cb(beta, "beta_perm", il);
+    cb(gk, "gk_perm", il);
+    cb(state, "state_in", il);
+
+    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
+
+    // Do padding
+    const int64_t chunk_size = CHUNK_SIZE;
+
+    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    gk = ggml_pad(ctx0, gk, 0, pad, 0, 0);
+    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
+
+    cb(q, "q_pad", il);
+    cb(k, "k_pad", il);
+    cb(v, "v_pad", il);
+    cb(beta, "beta_pad", il);
+    cb(gk, "gk_pad", il);
+
+    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
+
+    cb(v_beta, "v_beta", il);
+    cb(k_beta, "k_beta", il);
+
+    const int64_t HB = H_k * n_seqs;
+
+    q      = ggml_cont_4d(ctx0, q,      S_k, chunk_size, n_chunks, HB);
+    k      = ggml_cont_4d(ctx0, k,      S_k, chunk_size, n_chunks, HB);
+    k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB);
+    v      = ggml_cont_4d(ctx0, v,      S_v, chunk_size, n_chunks, HB);
+    v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB);
+
+    gk    = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB);
+    beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB);
+
+    // switch for cumsum
+    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB);
+    cb(gk, "gk", il);
+    ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
+    cb(gk_cumsum, "gk_cumsum", il);
+
+/*
+    Compute Akk and Aqk loop together
+    Akk loop:
+    for i in range(BT):
+        k_i = k[..., i, :] # k_i [B,H,NT,S]
+        g_i = g[..., i:i+1, :] # g_i [B,H,NT,1,S]
+        A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
+    Aqk loop:
+    for j in range(BT):
+        k_j = k[:, :, i, j]
+        g_j = g[:, :, i, j:j+1, :]
+        A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
+*/
+    const int64_t CHB = n_chunks * H_k * n_seqs;
+    ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
+    ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB]
+
+    ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
+    // decay_mask [chunk_size,chunk_size,S_k,CHB]
+    ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
+    cb(decay_mask, "decay_mask", il);
+
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+    cb(decay_mask, "decay_masked", il);
+    decay_mask = ggml_exp(ctx0, decay_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+
+    // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
+    decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+
+    ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
+    ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
+    ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
+
+    ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
+    ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
+
+    // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
+    ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
+    ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
+    Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
+    Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
+    cb(Akk, "Akk", il);
+    cb(Aqk, "Aqk", il);
+
+    Akk = ggml_mul(ctx0, Akk, beta);
+    Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
+    cb(Akk, "attn_pre_solve", il);
+
+    Aqk = ggml_mul(ctx0, Aqk, diag_mask);
+    Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
+    cb(Aqk, "Aqk_masked", il);
+
+    // for i in range(1, chunk_size):
+    //          row = attn[..., i, :i].clone()
+    //          sub = attn[..., :i, :i].clone()
+    //          attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    //
+    // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
+    ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask);
+    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
+
+    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
+    Akk                      = ggml_mul(ctx0, lin_solve, causal_mask);
+    Akk                      = ggml_add(ctx0, Akk, identity);
+
+    cb(Akk, "attn_solved", il);
+
+    // switch back for downstream
+    gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
+    ggml_tensor * gkexp      = ggml_exp(ctx0, gk_cumsum);
+    cb(gk_cumsum, "gk_cumsum", il);
+
+    // u = (A*beta[..., None, :]) @ v  aka U_[t]
+    ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
+
+    ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
+    cb(kbeta_gkexp, "kbeta_gkexp", il);
+
+    ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk);
+    cb(k_cumdecay, "k_cumdecay", il);
+
+    ggml_tensor * core_attn_out = nullptr;
+    ggml_tensor * new_state = ggml_dup(ctx0, state);
+
+    cb(new_state, "new_state", il);
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+// extract one chunk worth of data
+        auto chunkify = [=](ggml_tensor * t) {
+                    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
+                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+        auto chunkify_A = [=](ggml_tensor * t) {
+                    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3],
+                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+
+
+// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
+        ggml_tensor * k_chunk = chunkify(k);
+        ggml_tensor * q_chunk = chunkify(q);
+        ggml_tensor * vb_chunk = chunkify(vb);
+
+// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
+        ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
+        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
+        ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
+        ggml_tensor * Aqk_chunk = chunkify_A(Aqk);
+
+        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
+
+        // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B]
+        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t]
+        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
+
+        // v_new = v_i - v_prime or U_[t] - W_[t]*S_[t]
+        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime);
+        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+
+        // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B]
+        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        // or Gamma_[t]*Q_]t] @ S
+        ggml_tensor * q_gk_exp   = ggml_mul(ctx0, q_chunk, gkexp_chunk);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
+        attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q
+
+        // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B]
+        // core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t])
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk_chunk);
+
+        // o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i
+        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
+
+        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
+
+        ggml_tensor * gk_cum_last =
+            ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3],
+                                        gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3],
+                                        gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1)));
+
+        ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last)));
+
+        ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last));
+
+        ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff);
+
+        ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp);
+
+        // rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S)
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));
+
+        new_state = ggml_add(ctx0,
+            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)),
+            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
+    }
+
+    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
+
+    // truncate padded tokens
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+            S_v, n_tokens, H_v, n_seqs,
+            ggml_row_size(core_attn_out->type, S_v),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+    // permute back to (S_v, H_v, n_tokens, n_seqs)
+    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+
+    cb(new_state, "output_state", il);
+
+    return {output_tokens, new_state};
+}
+
+std::pair llm_build_kimi_linear::build_kda_autoregressive(
+    ggml_tensor * q,
+    ggml_tensor * k,
+    ggml_tensor * v,
+    ggml_tensor * gk,
+    ggml_tensor * beta,
+    ggml_tensor * state,
+    int il) {
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(gk));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1);
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
+
+    const float eps_norm = hparams.f_norm_rms_eps;
+
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);
+
+    const float scale = 1.0f / sqrtf(S_v);
+
+    q    = ggml_scale(ctx0, q, scale);
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(gk, "gk_in", il);
+
+// g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B]
+// gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B]
+// beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B]
+    gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs);
+    ggml_tensor * gk_t = ggml_cont(ctx0, ggml_transpose(ctx0, gk));
+    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
+
+    // Apply exponential to gk_t
+    gk_t = ggml_exp(ctx0, gk_t);
+    // Apply the gated delta rule for the single timestep
+    // last_recurrent_state = last_recurrent_state * gk_t
+    // S = S * g_i[..., None].exp()
+    state = ggml_mul(ctx0, state, gk_t);
+
+    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+
+// state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B]
+    k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
+    ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k);
+
+    // v_i - (k_i[..., None] * S).sum(-2)
+    v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
+    ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state);
+
+    // b_i[..., None] * k_i
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t);
+
+    // S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2))
+    // v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B]
+    state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
+
+    q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
+    state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
+    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
+    cb(core_attn_out, "output_tokens", il);
+    cb(state, "new_state", il);
+
+    return {core_attn_out, state};
+}
+
diff --git a/src/models/minicpm3.cpp b/src/models/minicpm3.cpp
index f374a9fd030..297cc34ba58 100644
--- a/src/models/minicpm3.cpp
+++ b/src/models/minicpm3.cpp
@@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
 
     const uint32_t n_embd_head_qk_rope = hparams.n_rot;
     const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     ggml_tensor * cur;
diff --git a/src/models/models.h b/src/models/models.h
index 3a44f7f140f..2a750c168ea 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -17,6 +17,53 @@ struct llm_graph_context_mamba : public llm_graph_context {
 
 };
 
+struct llm_graph_context_delta : public llm_graph_context_mamba {
+    llm_graph_context_delta(const llm_graph_params & params);
+
+    virtual ~llm_graph_context_delta() = default;
+
+    std::pair build_delta_net_unified_chunking(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il,
+        int64_t       chunk_size,
+        float         eps_norm);
+
+    std::pair build_delta_net_unified_autoregressive(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        int           il,
+        float         eps_norm);
+
+    std::pair build_delta_net_unified(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il,
+        int64_t       chunk_size,
+        float         eps_norm);
+};
+
 // Base class for RWKV-related models
 struct llm_build_rwkv6_base : public llm_graph_context {
     const llama_model & model;
@@ -288,6 +335,33 @@ struct llm_build_jamba : public llm_graph_context_mamba {
     llm_build_jamba(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_kimi_linear : public llm_graph_context_mamba {
+    llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
+
+    std::pair build_kda_autoregressive(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                        int   il);
+
+    std::pair build_kda_chunking(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                ggml_tensor * diag_mask,
+                        int   il);
+
+    const llama_model & model;
+};
+
 struct llm_build_lfm2 : public llm_graph_context {
     const llama_model & model;
 
@@ -449,7 +523,7 @@ struct llm_build_qwen3vl : public llm_graph_context {
 struct llm_build_qwen3vlmoe : public llm_graph_context {
     llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
 };
-struct llm_build_qwen3next : public llm_graph_context_mamba {
+struct llm_build_qwen3next : public llm_graph_context_delta {
     llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
 private:
     ggml_tensor * build_layer_attn(
@@ -507,6 +581,59 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
     const llama_model & model;
 };
 
+struct llm_build_qwen3_5 : public llm_graph_context_delta {
+    llm_build_qwen3_5(const llama_model & model, const llm_graph_params & params);
+
+protected:
+    // Tag type for subclass constructors that need to call build_graph() themselves
+    // (to ensure virtual dispatch works correctly)
+    struct defer_graph_build_t {};
+
+    llm_build_qwen3_5(const llama_model & model, const llm_graph_params & params, defer_graph_build_t);
+
+    void build_graph();
+
+    virtual ggml_tensor * build_layer_ffn(
+                ggml_tensor * cur,
+                        int   il);
+
+    const llama_model & model;
+
+private:
+    ggml_tensor * build_layer_attn(
+    llm_graph_input_attn_kv * inp_attn,
+                ggml_tensor * cur,
+                ggml_tensor * inp_pos,
+                        int   il);
+
+    ggml_tensor * build_layer_attn_linear(
+         llm_graph_input_rs * inp,
+                ggml_tensor * cur,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                ggml_tensor * diag_mask,
+                        int   il);
+
+    ggml_tensor * build_norm_gated(
+                ggml_tensor * input,
+                ggml_tensor * weights,
+                ggml_tensor * gate,
+                        int   layer);
+
+    std::pair build_qkvz(
+                ggml_tensor * input,
+                        int   il);
+};
+
+struct llm_build_qwen3_5_moe : public llm_build_qwen3_5 {
+    llm_build_qwen3_5_moe(const llama_model & model, const llm_graph_params & params);
+
+protected:
+    ggml_tensor * build_layer_ffn(
+                ggml_tensor * cur,
+                        int   il) override;
+};
+
 struct llm_build_qwen : public llm_graph_context {
     llm_build_qwen(const llama_model & model, const llm_graph_params & params);
 };
@@ -556,6 +683,10 @@ struct llm_build_starcoder : public llm_graph_context {
     llm_build_starcoder(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_step35_iswa : public llm_graph_context {
+    llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_t5_dec : public llm_graph_context {
     llm_build_t5_dec(const llama_model & model, const llm_graph_params & params);
 };
diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp
index eb135e63f18..079c730ac29 100644
--- a/src/models/nemotron-h.cpp
+++ b/src/models/nemotron-h.cpp
@@ -67,7 +67,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *
                                                           const llama_model &       model,
                                                           const int64_t             n_embd_head,
                                                           const int                 il) {
-    // compute Q and K and (optionally) RoPE them
+    // compute Q and K
     ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
     cb(Qcur, "Qcur", il);
     if (model.layers[il].bq) {
diff --git a/src/models/openelm.cpp b/src/models/openelm.cpp
index ee46a3375e8..fbf682ec835 100644
--- a/src/models/openelm.cpp
+++ b/src/models/openelm.cpp
@@ -43,7 +43,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_
             ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head);
             cb(Kcur, "Kcur", il);
 
-            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv));
             cb(Vcur, "Vcur", il);
 
             Qcur = build_norm(Qcur,
diff --git a/src/models/plm.cpp b/src/models/plm.cpp
index 481cbba6907..612a487c564 100644
--- a/src/models/plm.cpp
+++ b/src/models/plm.cpp
@@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &
 
     const uint32_t n_embd_head_qk_rope = hparams.n_rot;
     const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     ggml_tensor * cur;
diff --git a/src/models/qwen3-5.cpp b/src/models/qwen3-5.cpp
new file mode 100644
index 00000000000..0947299d73a
--- /dev/null
+++ b/src/models/qwen3-5.cpp
@@ -0,0 +1,421 @@
+#include "models.h"
+
+#define CHUNK_SIZE 64
+
+llm_build_qwen3_5::llm_build_qwen3_5(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context_delta(params), model(model) {
+    build_graph();
+}
+
+// virtual call in constructor fix
+llm_build_qwen3_5::llm_build_qwen3_5(const llama_model & model, const llm_graph_params & params, defer_graph_build_t /*tag*/) :
+    llm_graph_context_delta(params), model(model) {
+}
+
+void llm_build_qwen3_5::build_graph() {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "model.embed_tokens", -1);
+
+    auto * inp = build_inp_mem_hybrid();
+
+    ggml_tensor * inp_pos     = build_inp_pos();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    ggml_tensor * causal_mask =
+        ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
+                    GGML_TRI_TYPE_LOWER);
+
+    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
+    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
+
+    ggml_build_forward_expand(gf, causal_mask);
+    ggml_build_forward_expand(gf, identity);
+    ggml_build_forward_expand(gf, diag_mask);
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        if (hparams.is_recurrent(il)) {
+            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
+        } else {
+            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        cur = ggml_add(ctx0, cur, inpSA);
+        cb(cur, "attn_residual", il);
+
+        ggml_tensor * ffn_residual = cur;
+
+        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
+        cb(attn_post_norm, "attn_post_norm", il);
+
+        cur = build_layer_ffn(attn_post_norm, il);
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, cur, ffn_residual);
+        cb(cur, "post_ffn", il);
+
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+ggml_tensor * llm_build_qwen3_5::build_norm_gated(
+        ggml_tensor * input,
+        ggml_tensor * weights,
+        ggml_tensor * gate,
+        int           layer) {
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
+ggml_tensor * llm_build_qwen3_5::build_layer_attn(
+        llm_graph_input_attn_kv * inp,
+        ggml_tensor *             cur,
+        ggml_tensor *             inp_pos,
+        int                       il) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
+    cb(Qcur_full, "Qcur_full", il);
+
+    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0);
+    cb(Qcur, "Qcur_reshaped", il);
+
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
+
+    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+    cb(Kcur, "Kcur", il);
+
+    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+    cb(Vcur, "Vcur", il);
+
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "Kcur_normed", il);
+
+    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+        ggml_element_size(Qcur_full) * n_embd_head);
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "gate_reshaped", il);
+
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+    Qcur = ggml_rope_ext(
+            ctx0, Qcur, inp_pos, nullptr,
+            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow);
+
+    Kcur = ggml_rope_ext(
+            ctx0, Kcur, inp_pos, nullptr,
+            n_rot, rope_type, n_ctx_orig, freq_base,
+            freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+
+    cb(Qcur, "Qcur", il);
+    cb(Kcur, "Kcur", il);
+    cb(Vcur, "Vcur", il);
+
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    cur = build_attn(inp,
+                nullptr, nullptr,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "attn_pregate", il);
+
+    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
+    cb(gate_sigmoid, "gate_sigmoid", il);
+
+    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    cb(cur, "attn_gated", il);
+
+    cur = build_lora_mm(model.layers[il].wo, cur);
+    cb(cur, "attn_output", il);
+
+    return cur;
+}
+
+std::pair llm_build_qwen3_5::build_qkvz(
+                ggml_tensor * input,
+                        int   il) {
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    if (model.layers[il].wqkv) {
+        ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
+        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
+        cb(qkv_mixed, "linear_attn_qkv_mixed", il);
+
+        ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
+        cb(z, "z", il);
+
+        return { qkv_mixed, z };
+
+    }
+    // legacy path for combined in_proj_qkvz
+    ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
+    cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+
+    int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
+    ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+
+    int64_t split_sizes_qkvz[4] = {
+        head_k_dim,
+        head_k_dim,
+        head_v_dim * num_v_heads / num_k_heads,
+        head_v_dim * num_v_heads / num_k_heads
+    };
+
+    ggml_tensor * query =
+        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
+                    mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
+    cb(query, "q", il);
+
+    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+                                    mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                    split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
+    cb(key, "k", il);
+
+    ggml_tensor * value =
+        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+                    mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                    (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
+    cb(value, "v", il);
+
+    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+                                mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
+    z = ggml_cont(ctx0, z);
+    cb(z, "z", il);
+
+    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    cb(query_flat, "query_flat", il);
+
+    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    cb(key_flat, "key_flat", il);
+
+    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(value_flat, "value_flat", il);
+
+    ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
+    qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
+    cb(qkv_mixed, "qkv_mixed", il);
+
+    return { qkv_mixed, z };
+}
+
+ggml_tensor * llm_build_qwen3_5::build_layer_attn_linear(
+        llm_graph_input_rs * inp,
+        ggml_tensor *        cur,
+        ggml_tensor *        causal_mask,
+        ggml_tensor *        identity,
+        ggml_tensor *        diag_mask,
+        int                  il) {
+    const auto * mctx_cur = inp->mctx;
+
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    const auto kv_head = mctx_cur->get_head();
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    auto qkvz = build_qkvz(cur, il);
+    ggml_tensor * qkv_mixed = qkvz.first;
+    ggml_tensor * z         = qkvz.second;
+
+    ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
+    cb(mixed_ba, "linear_attn_mixed_ba", il);
+
+    int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
+    ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+
+    int64_t split_sizes_ba[2] = {
+        num_v_heads / num_k_heads,
+        num_v_heads / num_k_heads
+    };
+
+    ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs,
+                                   mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
+    cb(b, "b", il);
+
+    ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs,
+                                   mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
+                                   split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
+    cb(a, "a", il);
+
+    ggml_tensor * beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
+
+    ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
+
+    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
+    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
+    cb(alpha_softplus, "a_softplus", il);
+    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);
+    cb(gate, "gate", il);
+
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
+    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
+    const int64_t conv_kernel_size = conv_kernel->ne[0];
+    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
+    conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+    cb(conv_states, "conv_states_reshaped", il);
+
+    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
+    cb(qkv_mixed, "qkv_mixed_permuted", il);
+
+    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
+    cb(conv_input, "conv_input", il);
+
+    ggml_tensor * last_conv_states =
+        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
+                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+    cb(last_conv_states, "last_conv_states", il);
+
+    ggml_tensor * state_update_target =
+        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+    cb(state_update_target, "state_update_target", il);
+
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+    cb(conv_states_all, "conv_states_updated", il);
+
+    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
+    cb(conv_output_proper, "conv_output_raw", il);
+
+    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
+    cb(conv_output_silu, "conv_output_silu", il);
+
+    ggml_tensor * conv_qkv_mix = conv_output_silu;
+
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
+
+    ggml_tensor * q_conv =
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
+    cb(q_conv, "q_conv", il);
+    ggml_tensor * k_conv =
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
+                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+    cb(k_conv, "k_conv", il);
+    ggml_tensor * v_conv =
+        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
+                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+    cb(v_conv, "v_conv", il);
+
+    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim,  num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
+    if (num_k_heads != num_v_heads) {
+        GGML_ASSERT(num_v_heads % num_k_heads == 0);
+        int64_t repeat_factor = num_v_heads / num_k_heads;
+
+        ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
+        ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
+
+        ggml_tensor * q_repeated =
+            ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
+        ggml_tensor * k_repeated =
+            ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
+
+        q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+        k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+    }
+
+    cb(q_conv, "q_conv_predelta", il);
+    cb(k_conv, "k_conv_predelta", il);
+    cb(v_conv, "v_conv_predelta", il);
+
+    std::pair attn_out = build_delta_net_unified(ctx0, q_conv, k_conv, v_conv,
+            gate, beta, state, causal_mask, identity, diag_mask,
+            il, CHUNK_SIZE, hparams.f_norm_rms_eps);
+
+    ggml_tensor * output    = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
+
+    ggml_build_forward_expand(gf,
+                              ggml_cpy(ctx0, new_state,
+                                       ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                                                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+
+    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+
+    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(final_output, "final_output", il);
+
+    cur = build_lora_mm(model.layers[il].ssm_out, final_output);
+    cb(cur, "linear_attn_out", il);
+
+    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen3_5::build_layer_ffn(ggml_tensor * cur, const int il) {
+    // Qwen3.5 Dense always uses dense FFN
+    cur = build_ffn(cur,
+        model.layers[il].ffn_up, NULL, NULL,
+        model.layers[il].ffn_gate, NULL, NULL,
+        model.layers[il].ffn_down, NULL, NULL,
+        NULL,
+        LLM_FFN_SILU, LLM_FFN_PAR, il);
+    cb(cur, "ffn_out", il);
+    return cur;
+}
diff --git a/src/models/qwen3-5moe.cpp b/src/models/qwen3-5moe.cpp
new file mode 100644
index 00000000000..a4884432188
--- /dev/null
+++ b/src/models/qwen3-5moe.cpp
@@ -0,0 +1,52 @@
+#include "models.h"
+
+llm_build_qwen3_5_moe::llm_build_qwen3_5_moe(const llama_model & model, const llm_graph_params & params) :
+    llm_build_qwen3_5(model, params, defer_graph_build_t{}) {
+    build_graph();
+}
+
+ggml_tensor * llm_build_qwen3_5_moe::build_layer_ffn(ggml_tensor * cur, const int il) {
+    // Check if this is an MoE layer
+    if (model.layers[il].ffn_gate_inp != nullptr) {
+        // MoE branch
+        ggml_tensor * moe_out =
+            build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+                nullptr,
+                n_expert, n_expert_used, LLM_FFN_SILU,
+                true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+        cb(moe_out, "ffn_moe_out", il);
+
+        // Add shared experts if present
+        if (model.layers[il].ffn_up_shexp != nullptr) {
+            ggml_tensor * ffn_shexp =
+                build_ffn(cur,
+                    model.layers[il].ffn_up_shexp, NULL, NULL,
+                    model.layers[il].ffn_gate_shexp, NULL, NULL,
+                    model.layers[il].ffn_down_shexp, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(ffn_shexp, "ffn_shexp", il);
+
+            // Apply shared expert gating (sigmoid)
+            ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
+            cb(shared_gate, "shared_expert_gate", il);
+
+            shared_gate = ggml_sigmoid(ctx0, shared_gate);
+            cb(shared_gate, "shared_expert_gate_sigmoid", il);
+
+            ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
+            cb(ffn_shexp, "ffn_shexp_gated", il);
+
+            cur = ggml_add(ctx0, moe_out, ffn_shexp);
+            cb(cur, "ffn_out", il);
+        } else {
+            cur = moe_out;
+        }
+    } else {
+        // Dense FFN branch (fallback)
+        cur = llm_build_qwen3_5::build_layer_ffn(cur, il);
+    }
+    return cur;
+}
diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp
index 57b6659baf0..0335f5ab766 100644
--- a/src/models/qwen3next.cpp
+++ b/src/models/qwen3next.cpp
@@ -1,10 +1,9 @@
-#include "ggml.h"
 #include "models.h"
 
 #define CHUNK_SIZE 64
 
 llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params), model(model) {
+    llm_graph_context_delta(params), model(model) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
@@ -86,356 +85,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_build_forward_expand(gf, cur);
 }
 
-// utility to get one slice from the third dimension
-// input dim:  [x, y, c, b]
-// output dim: [x, y, 1, b]
-static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
-    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
-        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
-}
-
-std::pair llm_build_qwen3next::build_delta_net_chunking(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        ggml_tensor * causal_mask,
-        ggml_tensor * identity,
-        ggml_tensor * diag_mask,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q = ggml_scale(ctx0, q, scale);
-
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-
-    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    cb(q, "q_perm", il);
-    cb(k, "k_perm", il);
-    cb(v, "v_perm", il);
-    cb(beta, "beta_perm", il);
-    cb(g, "g_perm", il);
-    cb(state, "state_in", il);
-
-    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
-
-    // Do padding
-    const int64_t chunk_size = CHUNK_SIZE;
-
-    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
-
-    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
-    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
-    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
-    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
-    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
-
-    cb(q, "q_pad", il);
-    cb(k, "k_pad", il);
-    cb(v, "v_pad", il);
-    cb(beta, "beta_pad", il);
-    cb(g, "g_pad", il);
-
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-
-    cb(v_beta, "v_beta", il);
-    cb(k_beta, "k_beta", il);
-
-    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
-    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
-
-    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
-    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
-
-    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
-    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
-    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-
-    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
-
-    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
-    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
-
-    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
-    attn                     = ggml_add(ctx0, attn, identity);
-    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
-
-    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
-
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * k_cumdecay =
-        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
-    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
-    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
-    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
-    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-
-    // vectorized calculation of key_gdiff
-    // improved from the chunked version:
-    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-    //   key_gdiff = key * g_diff.unsqueeze(-1)
-    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-    // get last element in g_cumsum along chunk_size dimension (ne0)
-    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
-    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
-                                        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
-                                        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
-    g_last = ggml_cont(ctx0, g_last);
-    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
-    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
-    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
-    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
-
-
-    // state to be updated per chunk
-    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
-    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
-
-    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
-    ggml_tensor * core_attn_out = nullptr;
-
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
-        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
-
-        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
-
-        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
-
-        // shape: (chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
-
-        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-        // replaced by precomputed attn_kq
-        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
-        cb(attn_chunk, "attn_chunk", il);
-
-        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
-
-        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
-        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
-
-        // v_new = v_i - v_prime
-        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
-        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-        cb(v_new, "v_new_chunk", il);
-
-        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
-        cb(attn_inter, "attn_inter_chunk", il);
-
-        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
-        cb(v_attn, "v_attn_chunk", il);
-
-        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
-        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-
-        core_attn_out = core_attn_out == nullptr
-            ? core_attn_out_chunk
-            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
-
-        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
-        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
-
-        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
-        new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
-            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
-    }
-
-    // truncate padded tokens
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
-            S_v, n_tokens, H_v, n_seqs,
-            ggml_row_size(core_attn_out->type, S_v),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-    cb(output_tokens, "output_tokens", il);
-
-    // permute back to (S_v, H_v, n_tokens, n_seqs)
-    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-
-    return {output_tokens, new_state};
-}
-
-std::pair llm_build_qwen3next::build_delta_net_autoregressive(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q    = ggml_scale(ctx0, q, scale);
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
-    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
-
-    // Apply exponential to g_t
-    g_t = ggml_exp(ctx0, g_t);
-
-    // Apply the gated delta rule for the single timestep
-    // last_recurrent_state = last_recurrent_state * g_t
-    state = ggml_mul(ctx0, state, g_t);
-
-    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
-    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
-    // we need to sum over dim=-2, so we transpose, sum, then transpose again
-    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
-
-    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
-    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
-    // delta = (v_t - kv_mem) * beta_t
-    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
-    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);
-
-    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
-    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
-    state                   = ggml_add(ctx0, state, k_t_delta);
-
-    // Compute the attention output
-    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
-    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
-    // again, since it's over dim = -2, transpose, sum, transpose back
-    ggml_tensor * core_attn_out =
-        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
-
-    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
-    cb(core_attn_out, "output_tokens", il);
-    cb(state, "new_state", il);
-
-    return {core_attn_out, state};
-}
-
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
         ggml_tensor * input,
         ggml_tensor * weights,
@@ -746,7 +395,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
+    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim,  num_v_heads, n_seqs);
     cb(state, "state_predelta", il);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
@@ -775,13 +424,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
-    std::pair attn_out; // pair of (output, new_state)
-    if (n_seq_tokens == 1) {
-        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
-    } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
-    }
+    std::pair attn_out = build_delta_net_unified(ctx0, q_conv, k_conv, v_conv,
+            gate, beta, state, causal_mask, identity, diag_mask,
+            il, CHUNK_SIZE, hparams.f_norm_rms_eps);
+
     ggml_tensor * output    = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
     cb(output, "attn_output", il);
diff --git a/src/models/qwen3vl-moe.cpp b/src/models/qwen3vl-moe.cpp
index f72f80a8376..e5e1a2150c8 100644
--- a/src/models/qwen3vl-moe.cpp
+++ b/src/models/qwen3vl-moe.cpp
@@ -2,7 +2,8 @@
 
 llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = hparams.n_embd;
+
+    const int64_t n_embd      = hparams.n_embd;
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -16,17 +17,6 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
     int sections[4];
     std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
 
-    std::vector deepstack_features(n_deepstack_layers, nullptr);
-
-    if (ubatch.embd) {
-        // Image input: split main embd and deepstack embds
-        ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
-        for (size_t i = 0; i < n_deepstack_layers; i++) {
-            deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
-        }
-        inpL = inpL_main;
-    }
-
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
@@ -120,8 +110,9 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
         cur = build_cvec(cur, il);
         cb(cur, "l_out", il);
 
-        if (ubatch.embd && (size_t)il < n_deepstack_layers) {
-            cur = ggml_add(ctx0, cur, deepstack_features[il]);
+        if (il < (int) n_deepstack_layers) {
+            ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
+            cur = ggml_add(ctx0, cur, ds);
             cb(cur, "deepstack_out", il);
         }
 
diff --git a/src/models/qwen3vl.cpp b/src/models/qwen3vl.cpp
index 0bae52239ca..0f8315b3240 100644
--- a/src/models/qwen3vl.cpp
+++ b/src/models/qwen3vl.cpp
@@ -2,7 +2,8 @@
 
 llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = hparams.n_embd;
+
+    const int64_t n_embd      = hparams.n_embd;
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -16,17 +17,6 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
     int sections[4];
     std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
 
-    std::vector deepstack_features(n_deepstack_layers, nullptr);
-
-    if (ubatch.embd) {
-        // Image input: split main embd and deepstack embds
-        ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
-        for (size_t i = 0; i < n_deepstack_layers; i++) {
-            deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
-        }
-        inpL = inpL_main;
-    }
-
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
@@ -113,8 +103,9 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
         cur = build_cvec(cur, il);
         cb(cur, "l_out", il);
 
-        if (ubatch.embd && (size_t)il < n_deepstack_layers) {
-            cur = ggml_add(ctx0, cur, deepstack_features[il]);
+        if (il < (int) n_deepstack_layers) {
+            ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
+            cur = ggml_add(ctx0, cur, ds);
             cb(cur, "deepstack_out", il);
         }
 
diff --git a/src/models/step35-iswa.cpp b/src/models/step35-iswa.cpp
new file mode 100644
index 00000000000..f8737815a67
--- /dev/null
+++ b/src/models/step35-iswa.cpp
@@ -0,0 +1,168 @@
+#include "models.h"
+
+llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    ggml_tensor * inp_pos     = build_inp_pos();
+    auto        * inp_attn    = build_attn_inp_kv_iswa();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        const uint32_t n_head_l    = hparams.n_head(il);
+        const uint32_t n_head_kv_l = hparams.n_head_kv(il);
+
+        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        cur = inpL;
+
+        // dump pre-attn RMSNorm input to pinpoint layer boundary issues
+        cb(cur, "attn_norm_in", il);
+
+        // self-attention
+        {
+            cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
+
+            // Q/K per-head RMSNorm (Step35 q_norm / k_norm)
+            if (model.layers[il].attn_q_norm) {
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+            }
+            if (model.layers[il].attn_k_norm) {
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+            }
+
+            // RoPE (partial rotary factors per layer)
+            const bool is_swa = hparams.is_swa(il);
+            ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il);
+            const int64_t n_rot_l = is_swa ? hparams.n_rot : (hparams.n_rot / 2);
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Qcur, "Qcur_pos", il);
+            cb(Kcur, "Kcur_pos", il);
+
+            const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
+            ggml_tensor * attn_out = build_attn(inp_attn,
+                    nullptr, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+            cb(attn_out, "attn_out", il);
+            // head-wise attention gate: sigmoid(g_proj(x)) in torch
+            if (model.layers[il].wqkv_gate) {
+                ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens]
+                cb(gate, "attn_gate", il);
+
+                gate = ggml_sigmoid(ctx0, gate);
+                cb(gate, "attn_gate_sigmoid", il);
+
+                // reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
+                ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens);
+                ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate,       1,          n_head_l, n_tokens);
+                cb(gate_3d, "attn_gate_3d", il);
+
+                attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
+                cb(attn_3d, "attn_gated_3d", il);
+
+                attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
+                cb(attn_out, "attn_gated", il);
+            }
+
+            // output projection
+            cur = build_lora_mm(model.layers[il].wo, attn_out);
+            cb(cur, "attn_proj", il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense MLP
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   nullptr,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE routed experts
+            const bool  norm_w  = hparams.expert_weights_norm;
+            const float w_scale = hparams.expert_weights_scale;
+            const bool  scale_w = w_scale != 0.0f;
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    model.layers[il].ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU,
+                    norm_w, scale_w, w_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // shared expert MLP (always added on MoE layers in Step35)
+            ggml_tensor * sh_out = build_ffn(cur,
+                    model.layers[il].ffn_up_shexp,   nullptr, nullptr,
+                    model.layers[il].ffn_gate_shexp, nullptr, nullptr,
+                    model.layers[il].ffn_down_shexp, nullptr, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(sh_out, "ffn_shared_out", il);
+
+            cur = ggml_add(ctx0, moe_out, sh_out);
+            cb(cur, "ffn_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    cur = build_lora_mm(model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/src/unicode.cpp b/src/unicode.cpp
index b47dcbe6198..adfc489d1f0 100644
--- a/src/unicode.cpp
+++ b/src/unicode.cpp
@@ -497,49 +497,26 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
     return bpe_offsets;
 }
 
-// use std::wregex to split the text
-static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) {
-    std::wregex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
-    std::vector bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
-        std::wcregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::wcmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t) offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-// use std::regex to split the text
-static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) {
-    std::regex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
+template 
+static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) {
+    using BidirIt = typename std::basic_string::const_iterator;
+#ifdef _MSC_VER
+    // Bypass bug in MSVC: https://github.com/ggml-org/llama.cpp/issues/17830
+    constexpr auto regex_flags = std::regex_constants::ECMAScript;
+#else
+    constexpr auto regex_flags = std::regex_constants::optimize | std::regex_constants::nosubs;
+#endif
+    std::basic_regex expr(regex, regex_flags);
     std::vector bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
     size_t start = 0;
     for (auto offset : offsets) {
-        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
-        std::cregex_iterator end;
+        std::regex_iterator it(text.begin() + start, text.begin() + start + offset, expr);
+        std::regex_iterator end;
 
         int64_t start_idx = 0;
         while (it != end) {
-            std::cmatch match = *it;
+            std::match_results match = *it;
             if (match.position() > start_idx) {
                 bpe_offsets.emplace_back(match.position() - start_idx);
             }
diff --git a/tests/test-autorelease.cpp b/tests/test-autorelease.cpp
index 35b09aaeaca..ca87c56a8fd 100644
--- a/tests/test-autorelease.cpp
+++ b/tests/test-autorelease.cpp
@@ -1,4 +1,4 @@
-// ref: https://github.com/ggerganov/llama.cpp/issues/4952#issuecomment-1892864763
+// ref: https://github.com/ggml-org/llama.cpp/issues/4952#issuecomment-1892864763
 
 #include 
 #include 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 6bb781737e9..6fe1780f3ba 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -169,20 +169,22 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float m
     const int blck0 = 128;
     const int blck1 = 64;
 
-    // number of INF blocks
-    const int n_inf_blocks = 0.1*(ne0*ne1*ne2*ne3)/(blck0*blck1);
+    // number of INF/zero blocks
+    const int n_inf_zero_blocks = 0.2*(ne0*ne1*ne2*ne3)/(blck0*blck1);
 
-    for (int b = 0; b < n_inf_blocks; b++) {
+    for (int b = 0; b < n_inf_zero_blocks; b++) {
         const int p3 = (rd() % ne3);
         const int p2 = (rd() % ne2);
         const int p1 = (rd() % ne1);
         const int p0 = (rd() % ne0);
 
+        bool inf = rd() & 1;
+
         for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) {
             const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0;
 
             for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) {
-                data_f32[idx + i0] = -INFINITY;
+                data_f32[idx + i0] = inf ? -INFINITY : 0.0f;
             }
         }
     }
@@ -6122,7 +6124,19 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,         nr23[1], true); // the K tensor is usually a view of the K cache
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1], true); // the V tensor is usually a view of the V cache
+        ggml_tensor * v = nullptr;
+        if (hsk_padded == 576 && hsv_padded == 512) {
+            // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
+
+            // in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models
+            // for more info:
+            //   - https://github.com/ggml-org/llama.cpp/pull/13435
+            //   - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
+            //   - https://github.com/ggml-org/llama.cpp/pull/18986
+            v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
+        } else {
+            v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1], true); // the V tensor is usually a view of the V cache
+        }
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -8020,6 +8034,8 @@ static std::vector> make_test_cases_eval() {
         for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_IMROPE, GGML_ROPE_TYPE_VISION}) {
             for (bool ff : {false, true}) {
                 test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
+                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 1, true, true));
+                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 3}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 1, true, true));
             }
         }
     }
@@ -8201,11 +8217,13 @@ static std::vector> make_test_cases_eval() {
                         if (!mask && max_bias > 0.0f) continue;
                         for (float logit_softcap : {0.0f, 10.0f}) {
                             if (hsk != 128 && logit_softcap != 0.0f) continue;
-                            for (int nh : { 4, }) {
+                            for (int nh : { 1, 4 }) {
+                                if (nh == 1 && hsk != 576) continue; // GLM 4.7 Flash
                                 for (int nr3 : { 1, 3, }) {
                                     if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
-                                    for (int nr2 : { 1, 4, 16 }) {
-                                        if (nr2 == 16 && hsk != 128) continue;
+                                    for (int nr2 : { 1, 4, 12, 20 }) {
+                                        if (nr2 == 12 && hsk != 128) continue;
+                                        if (nr2 == 20 && (nh != 1 || hsk != 576)) continue;
                                         //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
                                         for (int kv : { 113, 512, 1024, }) {
                                             if (nr2 != 1 && kv != 512) continue;
@@ -8213,6 +8231,7 @@ static std::vector> make_test_cases_eval() {
                                                 for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
                                                     if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
                                                     for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                                        if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
                                                         test_cases.emplace_back(new test_flash_attn_ext(
                                                                     hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
                                                         // run fewer test cases permuted
@@ -8460,6 +8479,9 @@ static std::vector> make_test_cases_perf() {
     // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
     test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
 
+    test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+    test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
             for (int nr : { 1, 4, }) {
@@ -8574,6 +8596,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             output_printer->print_operation(info);
             return false;
         }
+        // Use reference implementation on the CPU backend for comparison
+        using ggml_backend_cpu_set_use_ref_t = void (*)(ggml_backend_t, bool);
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
+        auto * set_use_ref = (ggml_backend_cpu_set_use_ref_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_use_ref");
+        if (set_use_ref) {
+            set_use_ref(backend_cpu, true);
+        }
 
         size_t n_ok = 0;
         size_t                   tests_run = 0;
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index e1429007237..27b537a0369 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -54,7 +54,6 @@ std::string DEFAULT_JSON = R"({
     ],
     "bos_token": "",
     "eos_token": "",
-    "tools": [],
     "add_generation_prompt": true
 })";
 
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 6820acf6792..4378a8db716 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -462,9 +462,9 @@ static void test_parser_with_streaming(const common_chat_msg & expected, const s
     for (size_t i = 1; i <= raw_message.size(); ++i) {
         auto curr_msg = parse_msg(std::string(utf8_truncate_safe_view(std::string_view(raw_message).substr(0, i))));
         if (curr_msg == simple_assist_msg("")) continue;
-        LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({curr_msg}).dump().c_str());
+        LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({curr_msg}).dump().c_str());
         for (auto diff: common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) {
-            LOG_INF("Streaming diff: %s\n", common_chat_msg_diff_to_json_oaicompat(diff).dump().c_str());
+            LOG_INF("Streaming diff: %s\n", common_chat_msg_diff_to_json_oaicompat(diff).dump().c_str());
             if (!diff.reasoning_content_delta.empty()) {
                 merged.reasoning_content += diff.reasoning_content_delta;
             }
@@ -480,7 +480,7 @@ static void test_parser_with_streaming(const common_chat_msg & expected, const s
                     merged.tool_calls.back().arguments += diff.tool_call_delta.arguments;
                 }
             }
-            LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({merged}).dump().c_str());
+            LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({merged}).dump().c_str());
         }
         assert_msg_equals(curr_msg, merged, true);
         last_msg = curr_msg;
@@ -592,7 +592,7 @@ static void test_peg_parser(common_chat_templates * tmpls, const std::function({msg});
+        auto oai_json = common_chat_msgs_to_json_oaicompat({msg});
         auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
         assert_equals((size_t) 1, msgs2.size());
         auto msg2 = msgs2[0];
@@ -646,7 +646,7 @@ static void test_msgs_oaicompat_json_conversion() {
             "  }\n"
             "]"
         ),
-        common_chat_msgs_to_json_oaicompat({message_user_parts}).dump(2));
+        common_chat_msgs_to_json_oaicompat({message_user_parts}).dump(2));
 
     assert_equals(
         std::string(
@@ -666,7 +666,7 @@ static void test_msgs_oaicompat_json_conversion() {
             "  }\n"
             "]"
         ),
-        common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2));
+        common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2));
 
     auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]"));
     assert_equals(1, res.size());
@@ -693,7 +693,7 @@ static void test_tools_oaicompat_json_conversion() {
     };
 
     for (const auto & tool : tools) {
-        auto oai_json = common_chat_tools_to_json_oaicompat({tool});
+        auto oai_json = common_chat_tools_to_json_oaicompat({tool});
         auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
         assert_equals((size_t) 1, tools2.size());
         auto tool2 = tools2[0];
@@ -726,7 +726,7 @@ static void test_tools_oaicompat_json_conversion() {
             "  }\n"
             "]"
         ),
-        common_chat_tools_to_json_oaicompat({special_function_tool}).dump(2));
+        common_chat_tools_to_json_oaicompat({special_function_tool}).dump(2));
 
     {
         auto tools_no_params = common_chat_tools_parse_oaicompat(json::parse(
@@ -3799,6 +3799,134 @@ static void test_template_output_peg_parsers() {
         });
     }
 
+    {
+        // Solar-Open-100B
+        auto tmpls = read_templates("models/templates/upstage-Solar-Open-100B.jinja");
+
+        // Test basic message
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|content|>Hello, world!\nWhat's up?";
+            t.expect = message_assist;
+        });
+
+        // Test basic message and reasoning
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|think|>I'm\nthinking<|end|><|begin|>assistant<|content|>Hello, world!\nWhat's up?";
+            t.expect = message_assist_thoughts;
+        });
+
+        // Test basic message and reasoning_effort = low
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|content|>Hello, world!\nWhat's up?";
+            t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
+            t.expect = message_assist;
+        });
+
+        // Test tool call
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|tool_calls|>"
+                      "<|tool_call:begin|>123456789"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>";
+
+            t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
+            t.params.tools = {special_function_tool};
+            t.expect = message_assist_call_id;
+        });
+
+        // Test tool call with reasoning
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|think|>I'm\nthinking<|end|>"
+                      "<|begin|>assistant<|tool_calls|>"
+                      "<|tool_call:begin|>0"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>";
+
+            t.params.tools = {special_function_tool};
+            t.expect = message_assist_thoughts_call_idx;
+        });
+
+        // Test tool call with reasoning and tool_choice = required
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|think|>I'm\nthinking<|end|>"
+                      "<|begin|>assistant<|tool_calls|>"
+                      "<|tool_call:begin|>0"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>";
+
+            t.params.tools = {special_function_tool};
+            t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+            t.expect = message_assist_thoughts_call_idx;
+        });
+
+        // Test tool call without reasoning and tool_choice = required
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|tool_calls|>"
+                      "<|tool_call:begin|>0"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>";
+
+            t.params.tools = {special_function_tool};
+            t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+            t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
+            t.expect = message_assist_call_idx;
+        });
+
+        // Test parallel tool calls
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|think|>I'm\nthinking<|end|>"
+                      "<|begin|>assistant<|tool_calls|>"
+                      "<|tool_call:begin|>0"
+                      "<|tool_call:name|>special_function"
+                      "<|tool_call:args|>{\"arg1\":1}"
+                      "<|tool_call:end|>"
+                      "<|tool_call:begin|>1"
+                      "<|tool_call:name|>special_function_with_opt"
+                      "<|tool_call:args|>{\"arg1\": 1, \"arg2\": 2}"
+                      "<|tool_call:end|>";
+
+            t.params.parallel_tool_calls = true;
+            t.params.tools = {special_function_tool, special_function_tool_with_optional_param};
+
+            t.expect.reasoning_content = "I'm\nthinking";
+            t.expect.tool_calls = {{
+                /* .name = */      "special_function",
+                /* .arguments = */ R"({"arg1": 1})",
+                /* .id = */        "0",
+            }, {
+                /* .name = */      "special_function_with_opt",
+                /* .arguments = */ R"({"arg1": 1, "arg2": 2})",
+                /* .id = */        "1",
+            }};
+        });
+
+        // Test response format
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|think|>I need to output the invoice details in JSON<|end|>"
+                      "<|begin|>assistant<|content|>"
+                      R"({"amount": 123.45, "date": "2025-12-03"})";
+
+            t.params.json_schema = invoice_schema;
+
+            t.expect.reasoning_content = "I need to output the invoice details in JSON";
+            t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})";
+        });
+
+        // Test response format no reasoning
+        test_peg_parser(tmpls.get(), [&](auto & t) {
+            t.input = "<|content|>"
+                      R"({"amount": 123.45, "date": "2025-12-03"})";
+
+            t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
+            t.params.json_schema = invoice_schema;
+
+            t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})";
+        });
+    }
 }
 
 static void test_msg_diffs_compute() {
diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp
index 3f0c312e2f0..84b7f3bc491 100644
--- a/tests/test-gguf.cpp
+++ b/tests/test-gguf.cpp
@@ -1,9 +1,11 @@
 #include "ggml.h"
 #include "ggml-backend.h"
 #include "../ggml/src/ggml-impl.h"
+#include "gguf.h"
 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -34,6 +36,7 @@ enum handcrafted_file_type {
     HANDCRAFTED_TENSORS_BAD_N_DIMS         =  20 + offset_has_tensors,
     HANDCRAFTED_TENSORS_BAD_SHAPE          =  30 + offset_has_tensors,
     HANDCRAFTED_TENSORS_NE_TOO_BIG         =  40 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_NBYTES_TOO_BIG     =  45 + offset_has_tensors,
     HANDCRAFTED_TENSORS_BAD_TYPE           =  50 + offset_has_tensors,
     HANDCRAFTED_TENSORS_BAD_OFFSET         =  60 + offset_has_tensors,
     HANDCRAFTED_TENSORS_DUPLICATE_NAME     =  70 + offset_has_tensors,
@@ -69,6 +72,7 @@ static std::string handcrafted_file_type_name(const enum handcrafted_file_type h
         case HANDCRAFTED_TENSORS_BAD_N_DIMS:         return "TENSORS_BAD_N_DIMS";
         case HANDCRAFTED_TENSORS_BAD_SHAPE:          return "TENSORS_BAD_SHAPE";
         case HANDCRAFTED_TENSORS_NE_TOO_BIG:         return "TENSORS_NE_TOO_BIG";
+        case HANDCRAFTED_TENSORS_NBYTES_TOO_BIG:     return "TENSORS_NBYTES_TOO_BIG";
         case HANDCRAFTED_TENSORS_BAD_TYPE:           return "TENSORS_BAD_TYPE";
         case HANDCRAFTED_TENSORS_BAD_OFFSET:         return "TENSORS_BAD_OFFSET";
         case HANDCRAFTED_TENSORS_DUPLICATE_NAME:     return "TENSORS_DUPLICATE_NAME";
@@ -326,7 +330,7 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
 
     uint64_t offset = 0;
     for (int i = 0; i < int(tensor_configs.size()); ++i) {
-        const ggml_type                          type  = tensor_configs[i].first;
+        const ggml_type                          type  = hft == HANDCRAFTED_TENSORS_NBYTES_TOO_BIG ? GGML_TYPE_I64 : tensor_configs[i].first;
         const std::array shape = tensor_configs[i].second;
 
         std::string name = "my_tensor";
@@ -343,7 +347,7 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         }
         helper_write(file, name.data(), name.length());
 
-        uint32_t n_dims = hft == HANDCRAFTED_TENSORS_NE_TOO_BIG ? 2 : 1;
+        uint32_t n_dims = (hft == HANDCRAFTED_TENSORS_NE_TOO_BIG || hft == HANDCRAFTED_TENSORS_NBYTES_TOO_BIG) ? 2 : 1;
         for (int i = GGML_MAX_DIMS-1; i >= 1; --i) {
             if (shape[i] != 1) {
                 n_dims = i + 1;
@@ -358,13 +362,19 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         }
 
         if (hft == HANDCRAFTED_TENSORS_BAD_SHAPE) {
+            const int64_t bad_dim = -1;
             for (uint32_t j = 0; j < n_dims; ++j) {
-                const int64_t bad_dim = -1;
                 helper_write(file, bad_dim);
             }
         } else if (hft == HANDCRAFTED_TENSORS_NE_TOO_BIG){
+            const int64_t big_dim = 4*int64_t(INT32_MAX);
+            for (uint32_t j = 0; j < n_dims; ++j) {
+                helper_write(file, big_dim);
+            }
+        } else if (hft == HANDCRAFTED_TENSORS_NBYTES_TOO_BIG){
+            const size_t  big_ne  = SIZE_MAX/ggml_type_size(type);
+            const int64_t big_dim = GGML_PAD(int64_t(1.01f*std::pow(big_ne, 1.0f/n_dims)) + 1, ggml_blck_size(type));
             for (uint32_t j = 0; j < n_dims; ++j) {
-                const int64_t big_dim = 4*int64_t(INT32_MAX);
                 helper_write(file, big_dim);
             }
         } else {
@@ -682,6 +692,7 @@ static std::pair test_handcrafted_file(const unsigned int seed) {
         HANDCRAFTED_TENSORS_BAD_N_DIMS,
         HANDCRAFTED_TENSORS_BAD_SHAPE,
         HANDCRAFTED_TENSORS_NE_TOO_BIG,
+        HANDCRAFTED_TENSORS_NBYTES_TOO_BIG,
         HANDCRAFTED_TENSORS_BAD_TYPE,
         HANDCRAFTED_TENSORS_BAD_OFFSET,
         HANDCRAFTED_TENSORS_DUPLICATE_NAME,
diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp
index 99630ecb3b8..1f25c6ae71a 100644
--- a/tests/test-jinja.cpp
+++ b/tests/test-jinja.cpp
@@ -9,6 +9,7 @@
 #include "jinja/runtime.h"
 #include "jinja/parser.h"
 #include "jinja/lexer.h"
+#include "jinja/utils.h"
 
 #include "testing.h"
 
@@ -30,6 +31,7 @@ static void test_tests(testing & t);
 static void test_string_methods(testing & t);
 static void test_array_methods(testing & t);
 static void test_object_methods(testing & t);
+static void test_hasher(testing & t);
 static void test_fuzzing(testing & t);
 
 static bool g_python_mode = false;
@@ -67,6 +69,7 @@ int main(int argc, char *argv[]) {
     t.test("array methods", test_array_methods);
     t.test("object methods", test_object_methods);
     if (!g_python_mode) {
+        t.test("hasher", test_hasher);
         t.test("fuzzing", test_fuzzing);
     }
 
@@ -156,6 +159,18 @@ static void test_conditionals(testing & t) {
         "big"
     );
 
+    test_template(t, "object comparison",
+        "{% if {0: 1, none: 2, 1.0: 3, '0': 4, true: 5} == {false: 1, none: 2, 1: 5, '0': 4} %}equal{% endif %}",
+        json::object(),
+        "equal"
+    );
+
+    test_template(t, "array comparison",
+        "{% if [0, 1.0, false] == [false, 1, 0.0] %}equal{% endif %}",
+        json::object(),
+        "equal"
+    );
+
     test_template(t, "logical and",
         "{% if a and b %}both{% endif %}",
         {{"a", true}, {"b", true}},
@@ -174,12 +189,24 @@ static void test_conditionals(testing & t) {
         "negated"
     );
 
-    test_template(t, "in operator",
+    test_template(t, "in operator (element in array)",
         "{% if 'x' in items %}found{% endif %}",
         {{"items", json::array({"x", "y"})}},
         "found"
     );
 
+    test_template(t, "in operator (substring)",
+        "{% if 'bc' in 'abcd' %}found{% endif %}",
+        json::object(),
+        "found"
+    );
+
+    test_template(t, "in operator (object key)",
+        "{% if 'key' in obj %}found{% endif %}",
+        {{"obj", {{"key", 1}, {"other", 2}}}},
+        "found"
+    );
+
     test_template(t, "is defined",
         "{% if x is defined %}yes{% else %}no{% endif %}",
         {{"x", 1}},
@@ -314,6 +341,12 @@ static void test_loops(testing & t) {
         "empty"
     );
 
+    test_template(t, "for undefined empty",
+        "{% for i in items %}{{ i }}{% else %}empty{% endfor %}",
+        json::object(),
+        "empty"
+    );
+
     test_template(t, "nested for",
         "{% for i in a %}{% for j in b %}{{ i }}{{ j }}{% endfor %}{% endfor %}",
         {{"a", json::array({1, 2})}, {"b", json::array({"x", "y"})}},
@@ -358,6 +391,30 @@ static void test_expressions(testing & t) {
         "b"
     );
 
+    test_template(t, "array negative access",
+        "{{ items[-1] }}",
+        {{"items", json::array({"a", "b", "c"})}},
+        "c"
+    );
+
+    test_template(t, "array slice",
+        "{{ items[1:-1]|string }}",
+        {{"items", json::array({"a", "b", "c"})}},
+        "['b']"
+    );
+
+    test_template(t, "array slice step",
+        "{{ items[::2]|string }}",
+        {{"items", json::array({"a", "b", "c"})}},
+        "['a', 'c']"
+    );
+
+    test_template(t, "tuple slice",
+        "{{ ('a', 'b', 'c')[::-1]|string }}",
+        json::object(),
+        "('c', 'b', 'a')"
+    );
+
     test_template(t, "arithmetic",
         "{{ (a + b) * c }}",
         {{"a", 2}, {"b", 3}, {"c", 4}},
@@ -401,6 +458,36 @@ static void test_set_statement(testing & t) {
         json::object(),
         "1"
     );
+
+    test_template(t, "set dict with mixed type keys",
+        "{% set d = {0: 1, none: 2, 1.0: 3, '0': 4, (0, 0): 5, false: 6, 1: 7} %}{{ d[(0, 0)] + d[0] + d[none] + d['0'] + d[false] + d[1.0] + d[1] }}",
+        json::object(),
+        "37"
+    );
+
+    test_template(t, "print dict with mixed type keys",
+        "{% set d = {0: 1, none: 2, 1.0: 3, '0': 4, (0, 0): 5, true: 6} %}{{ d|string }}",
+        json::object(),
+        "{0: 1, None: 2, 1.0: 6, '0': 4, (0, 0): 5}"
+    );
+
+    test_template(t, "print array with mixed types",
+        "{% set d = [0, none, 1.0, '0', true, (0, 0)] %}{{ d|string }}",
+        json::object(),
+        "[0, None, 1.0, '0', True, (0, 0)]"
+    );
+
+    test_template(t, "object member assignment with mixed key types",
+        "{% set d = namespace() %}{% set d.a = 123 %}{{ d['a'] == 123 }}",
+        json::object(),
+        "True"
+    );
+
+    test_template(t, "tuple unpacking",
+        "{% set t = (1, 2, 3) %}{% set a, b, c = t %}{{ a + b + c }}",
+        json::object(),
+        "6"
+    );
 }
 
 static void test_filters(testing & t) {
@@ -609,6 +696,12 @@ static void test_filters(testing & t) {
         json::object(),
         "hello"
     );
+
+    test_template(t, "none to string",
+        "{{ x|string }}",
+        {{"x", nullptr}},
+        "None"
+    );
 }
 
 static void test_literals(testing & t) {
@@ -943,6 +1036,54 @@ static void test_tests(testing & t) {
         {{"x", {{"a", 1}}}},
         "yes"
     );
+
+    test_template(t, "undefined is sequence",
+        "{{ 'yes' if x is sequence }}",
+        json::object(),
+        "yes"
+    );
+
+    test_template(t, "undefined is iterable",
+        "{{ 'yes' if x is iterable }}",
+        json::object(),
+        "yes"
+    );
+
+    test_template(t, "is in (array, true)",
+        "{{ 'yes' if 2 is in([1, 2, 3]) }}",
+        json::object(),
+        "yes"
+    );
+
+    test_template(t, "is in (array, false)",
+        "{{ 'yes' if 5 is in([1, 2, 3]) else 'no' }}",
+        json::object(),
+        "no"
+    );
+
+    test_template(t, "is in (string)",
+        "{{ 'yes' if 'bc' is in('abcde') }}",
+        json::object(),
+        "yes"
+    );
+
+    test_template(t, "is in (object keys)",
+        "{{ 'yes' if 'a' is in(obj) }}",
+        {{"obj", {{"a", 1}, {"b", 2}}}},
+        "yes"
+    );
+
+    test_template(t, "reject with in test",
+        "{{ items | reject('in', skip) | join(', ') }}",
+        {{"items", json::array({"a", "b", "c", "d"})}, {"skip", json::array({"b", "d"})}},
+        "a, c"
+    );
+
+    test_template(t, "select with in test",
+        "{{ items | select('in', keep) | join(', ') }}",
+        {{"items", json::array({"a", "b", "c", "d"})}, {"keep", json::array({"b", "c"})}},
+        "b, c"
+    );
 }
 
 static void test_string_methods(testing & t) {
@@ -1047,6 +1188,54 @@ static void test_string_methods(testing & t) {
         {{"s", "banana"}},
         "bXnXna"
     );
+
+    test_template(t, "undefined|capitalize",
+        "{{ arr|capitalize }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|title",
+        "{{ arr|title }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|truncate",
+        "{{ arr|truncate(9) }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|upper",
+        "{{ arr|upper }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|lower",
+        "{{ arr|lower }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|replace",
+        "{{ arr|replace('a', 'b') }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|trim",
+        "{{ arr|trim }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|wordcount",
+        "{{ arr|wordcount }}",
+        json::object(),
+        "0"
+    );
 }
 
 static void test_array_methods(testing & t) {
@@ -1214,6 +1403,108 @@ static void test_array_methods(testing & t) {
     //     {{"arr", json::array({"a", "b", "c"})}},
     //     "a,x,b,c"
     // );
+
+    test_template(t, "undefined|select",
+        "{% for item in items|select('odd') %}{{ item.name }} {% endfor %}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|selectattr",
+        "{% for item in items|selectattr('active') %}{{ item.name }} {% endfor %}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|reject",
+        "{% for item in items|reject('even') %}{{ item.name }} {% endfor %}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|rejectattr",
+        "{% for item in items|rejectattr('active') %}{{ item.name }} {% endfor %}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|list",
+        "{{ arr|list|string }}",
+        json::object(),
+        "[]"
+    );
+
+    test_template(t, "undefined|string",
+        "{{ arr|string }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|first",
+        "{{ arr|first }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|last",
+        "{{ arr|last }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|length",
+        "{{ arr|length }}",
+        json::object(),
+        "0"
+    );
+
+    test_template(t, "undefined|join",
+        "{{ arr|join }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|sort",
+        "{{ arr|sort|string }}",
+        json::object(),
+        "[]"
+    );
+
+    test_template(t, "undefined|reverse",
+        "{{ arr|reverse|join }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|map",
+        "{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|min",
+        "{{ arr|min }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|max",
+        "{{ arr|max }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|unique",
+        "{{ arr|unique|join }}",
+        json::object(),
+        ""
+    );
+
+    test_template(t, "undefined|sum",
+        "{{ arr|sum }}",
+        json::object(),
+        "0"
+    );
 }
 
 static void test_object_methods(testing & t) {
@@ -1306,6 +1597,160 @@ static void test_object_methods(testing & t) {
         {{"obj", {{"a", "b"}}}},
         "True True"
     );
+
+    test_template(t, "expression as object key",
+        "{% set d = {'ab': 123} %}{{ d['a' + 'b'] == 123 }}",
+        json::object(),
+        "True"
+    );
+
+    test_template(t, "numeric as object key (template: Seed-OSS)",
+        "{% set d = {1: 'a', 2: 'b'} %}{{ d[1] == 'a' and d[2] == 'b' }}",
+        json::object(),
+        "True"
+    );
+
+    test_template(t, "undefined|items",
+        "{{ arr|items|join }}",
+        json::object(),
+        ""
+    );
+}
+
+static void test_hasher(testing & t) {
+    static const std::vector> chunk_sizes = {
+        {1, 2},
+        {1, 16},
+        {8, 1},
+        {1, 1024},
+        {5, 512},
+        {16, 256},
+        {45, 122},
+        {70, 634},
+    };
+
+    static auto random_bytes = [](size_t length) -> std::string {
+        std::string data;
+        data.resize(length);
+        for (size_t i = 0; i < length; ++i) {
+            data[i] = static_cast(rand() % 256);
+        }
+        return data;
+    };
+
+    t.test("state unchanged with empty input", [](testing & t) {
+        jinja::hasher hasher;
+        hasher.update("some data");
+        size_t initial_state = hasher.digest();
+        hasher.update("", 0);
+        size_t final_state = hasher.digest();
+        t.assert_true("Hasher state should remain unchanged", initial_state == final_state);
+    });
+
+    t.test("different inputs produce different hashes", [](testing & t) {
+        jinja::hasher hasher1;
+        hasher1.update("data one");
+        size_t hash1 = hasher1.digest();
+
+        jinja::hasher hasher2;
+        hasher2.update("data two");
+        size_t hash2 = hasher2.digest();
+
+        t.assert_true("Different inputs should produce different hashes", hash1 != hash2);
+    });
+
+    t.test("same inputs produce same hashes", [](testing & t) {
+        jinja::hasher hasher1;
+        hasher1.update("consistent data");
+        size_t hash1 = hasher1.digest();
+
+        jinja::hasher hasher2;
+        hasher2.update("consistent data");
+        size_t hash2 = hasher2.digest();
+
+        t.assert_true("Same inputs should produce same hashes", hash1 == hash2);
+    });
+
+    t.test("property: update(a ~ b) == update(a).update(b)", [](testing & t) {
+        for (const auto & [size1, size2] : chunk_sizes) {
+            std::string data1 = random_bytes(size1);
+            std::string data2 = random_bytes(size2);
+
+            jinja::hasher hasher1;
+            hasher1.update(data1);
+            hasher1.update(data2);
+            size_t hash1 = hasher1.digest();
+
+            jinja::hasher hasher2;
+            hasher2.update(data1 + data2);
+            size_t hash2 = hasher2.digest();
+
+            t.assert_true(
+                "Hashing in multiple updates should match single update (" + std::to_string(size1) + ", " + std::to_string(size2) + ")",
+                hash1 == hash2);
+        }
+    });
+
+    t.test("property: update(a ~ b) == update(a).update(b) with more update passes", [](testing & t) {
+        static const std::vector sizes = {3, 732, 131, 13, 17, 256, 436, 99, 4};
+
+        jinja::hasher hasher1;
+        jinja::hasher hasher2;
+
+        std::string combined_data;
+        for (size_t size : sizes) {
+            std::string data = random_bytes(size);
+            hasher1.update(data);
+            combined_data += data;
+        }
+
+        hasher2.update(combined_data);
+        size_t hash1 = hasher1.digest();
+        size_t hash2 = hasher2.digest();
+        t.assert_true(
+            "Hashing in multiple updates should match single update with many chunks",
+            hash1 == hash2);
+    });
+
+    t.test("property: non associativity of update", [](testing & t) {
+        for (const auto & [size1, size2] : chunk_sizes) {
+            std::string data1 = random_bytes(size1);
+            std::string data2 = random_bytes(size2);
+
+            jinja::hasher hasher1;
+            hasher1.update(data1);
+            hasher1.update(data2);
+            size_t hash1 = hasher1.digest();
+
+            jinja::hasher hasher2;
+            hasher2.update(data2);
+            hasher2.update(data1);
+            size_t hash2 = hasher2.digest();
+
+            t.assert_true(
+                "Hashing order should matter (" + std::to_string(size1) + ", " + std::to_string(size2) + ")",
+                hash1 != hash2);
+        }
+    });
+
+    t.test("property: different lengths produce different hashes (padding block size)", [](testing & t) {
+        std::string random_data = random_bytes(64);
+
+        jinja::hasher hasher1;
+        hasher1.update(random_data);
+        size_t hash1 = hasher1.digest();
+
+        for (int i = 0; i < 16; ++i) {
+            random_data.push_back('A');  // change length
+            jinja::hasher hasher2;
+            hasher2.update(random_data);
+            size_t hash2 = hasher2.digest();
+
+            t.assert_true("Different lengths should produce different hashes (length " + std::to_string(random_data.size()) + ")", hash1 != hash2);
+
+            hash1 = hash2;
+        }
+    });
 }
 
 static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
diff --git a/tools/cli/README.md b/tools/cli/README.md
index 3b6f0708ed0..4a15cbad9d7 100644
--- a/tools/cli/README.md
+++ b/tools/cli/README.md
@@ -45,10 +45,10 @@
 | `--rope-freq-base N` | RoPE base frequency, used by NTK-aware scaling (default: loaded from model)
(env: LLAMA_ARG_ROPE_FREQ_BASE) | | `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N
(env: LLAMA_ARG_ROPE_FREQ_SCALE) | | `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)
(env: LLAMA_ARG_YARN_ORIG_CTX) | -| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | -| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | -| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | -| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | +| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.00, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | +| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.00)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | +| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.00)
(env: LLAMA_ARG_YARN_BETA_SLOW) | +| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.00)
(env: LLAMA_ARG_YARN_BETA_FAST) | | `-kvo, --kv-offload, -nkvo, --no-kv-offload` | whether to enable KV cache offloading (default: enabled)
(env: LLAMA_ARG_KV_OFFLOAD) | | `--repack, -nr, --no-repack` | whether to enable weight repacking (default: enabled)
(env: LLAMA_ARG_REPACK) | | `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | @@ -109,30 +109,30 @@ | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | | `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | -| `--temp N` | temperature (default: 0.8) | +| `--temp N` | temperature (default: 0.80) | | `--top-k N` | top-k sampling (default: 40, 0 = disabled)
(env: LLAMA_ARG_TOP_K) | -| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | -| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) | -| `--adaptive-target N` | adaptive-p: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) | -| `--adaptive-decay N` | adaptive-p: EMA decay for adaptation; effective history length ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) | -| `--top-nsigma N` | top-n-sigma sampling (default: -1.0, -1.0 = disabled) | -| `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) | -| `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) | -| `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) | +| `--top-p N` | top-p sampling (default: 0.95, 1.0 = disabled) | +| `--min-p N` | min-p sampling (default: 0.05, 0.0 = disabled) | +| `--top-nsigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) | +| `--xtc-probability N` | xtc probability (default: 0.00, 0.0 = disabled) | +| `--xtc-threshold N` | xtc threshold (default: 0.10, 1.0 = disabled) | +| `--typical N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) | | `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) | -| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) | -| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) | -| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) | -| `--dry-multiplier N` | set DRY sampling multiplier (default: 0.0, 0.0 = disabled) | +| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.00, 1.0 = disabled) | +| `--presence-penalty N` | repeat alpha presence penalty (default: 0.00, 0.0 = disabled) | +| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.00, 0.0 = disabled) | +| `--dry-multiplier N` | set DRY sampling multiplier (default: 0.00, 0.0 = disabled) | | `--dry-base N` | set DRY sampling base value (default: 1.75) | | `--dry-allowed-length N` | set allowed length for DRY sampling (default: 2) | | `--dry-penalty-last-n N` | set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) | | `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers ('\n', ':', '"', '*') in the process; use "none" to not use any sequence breakers | -| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) | -| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) | +| `--adaptive-target N` | adaptive-p: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) (default: -1.00)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/17927) | +| `--adaptive-decay N` | adaptive-p: decay rate for target adaptation over time. lower values are more reactive, higher values are more stable.
(valid range 0.0 to 0.99) (default: 0.90) | +| `--dynatemp-range N` | dynamic temperature range (default: 0.00, 0.0 = disabled) | +| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.00) | | `--mirostat N` | use Mirostat sampling.
Top K, Nucleus and Locally Typical samplers are ignored if used.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | -| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.1) | -| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.0) | +| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.10) | +| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.00) | | `-l, --logit-bias TOKEN_ID(+/-)BIAS` | modifies the likelihood of token appearing in the completion,
i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',
or `--logit-bias 15043-1` to decrease likelihood of token ' Hello' | | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | | `--grammar-file FNAME` | file to read grammar from | @@ -173,12 +173,12 @@ | `--jinja, --no-jinja` | whether to use jinja template engine for chat (default: enabled)
(env: LLAMA_ARG_JINJA) | | `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `` tags in `message.content` while also populating `message.reasoning_content`
(default: auto)
(env: LLAMA_ARG_THINK) | | `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | -| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | -| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | +| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--simple-io` | use basic IO for better compatibility in subprocesses and limited consoles | | `--draft, --draft-n, --draft-max N` | number of tokens to draft for speculative decoding (default: 16)
(env: LLAMA_ARG_DRAFT_MAX) | | `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)
(env: LLAMA_ARG_DRAFT_MIN) | -| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.8)
(env: LLAMA_ARG_DRAFT_P_MIN) | +| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)
(env: LLAMA_ARG_DRAFT_P_MIN) | | `-cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE_DRAFT) | | `-devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 0926e552e92..02ccb725981 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -84,6 +84,9 @@ struct cli_context { // chat template settings task.params.chat_parser_params = common_chat_parser_params(chat_params); task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + if (!chat_params.parser.empty()) { + task.params.chat_parser_params.parser.load(chat_params.parser); + } rd.post_task({std::move(task)}); } diff --git a/tools/completion/README.md b/tools/completion/README.md index a16be3f684a..3ca3e684541 100644 --- a/tools/completion/README.md +++ b/tools/completion/README.md @@ -128,10 +128,10 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 | `--rope-freq-base N` | RoPE base frequency, used by NTK-aware scaling (default: loaded from model)
(env: LLAMA_ARG_ROPE_FREQ_BASE) | | `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N
(env: LLAMA_ARG_ROPE_FREQ_SCALE) | | `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)
(env: LLAMA_ARG_YARN_ORIG_CTX) | -| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | -| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | -| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | -| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | +| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.00, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | +| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.00)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | +| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.00)
(env: LLAMA_ARG_YARN_BETA_SLOW) | +| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.00)
(env: LLAMA_ARG_YARN_BETA_FAST) | | `-kvo, --kv-offload, -nkvo, --no-kv-offload` | whether to enable KV cache offloading (default: enabled)
(env: LLAMA_ARG_KV_OFFLOAD) | | `--repack, -nr, --no-repack` | whether to enable weight repacking (default: enabled)
(env: LLAMA_ARG_REPACK) | | `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | @@ -192,28 +192,30 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | | `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | -| `--temp N` | temperature (default: 0.8) | +| `--temp N` | temperature (default: 0.80) | | `--top-k N` | top-k sampling (default: 40, 0 = disabled)
(env: LLAMA_ARG_TOP_K) | -| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | -| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) | -| `--top-nsigma N` | top-n-sigma sampling (default: -1.0, -1.0 = disabled) | -| `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) | -| `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) | -| `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) | +| `--top-p N` | top-p sampling (default: 0.95, 1.0 = disabled) | +| `--min-p N` | min-p sampling (default: 0.05, 0.0 = disabled) | +| `--top-nsigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) | +| `--xtc-probability N` | xtc probability (default: 0.00, 0.0 = disabled) | +| `--xtc-threshold N` | xtc threshold (default: 0.10, 1.0 = disabled) | +| `--typical N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) | | `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) | -| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) | -| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) | -| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) | -| `--dry-multiplier N` | set DRY sampling multiplier (default: 0.0, 0.0 = disabled) | +| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.00, 1.0 = disabled) | +| `--presence-penalty N` | repeat alpha presence penalty (default: 0.00, 0.0 = disabled) | +| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.00, 0.0 = disabled) | +| `--dry-multiplier N` | set DRY sampling multiplier (default: 0.00, 0.0 = disabled) | | `--dry-base N` | set DRY sampling base value (default: 1.75) | | `--dry-allowed-length N` | set allowed length for DRY sampling (default: 2) | | `--dry-penalty-last-n N` | set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) | | `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers ('\n', ':', '"', '*') in the process; use "none" to not use any sequence breakers | -| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) | -| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) | +| `--adaptive-target N` | adaptive-p: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) (default: -1.00)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/17927) | +| `--adaptive-decay N` | adaptive-p: decay rate for target adaptation over time. lower values are more reactive, higher values are more stable.
(valid range 0.0 to 0.99) (default: 0.90) | +| `--dynatemp-range N` | dynamic temperature range (default: 0.00, 0.0 = disabled) | +| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.00) | | `--mirostat N` | use Mirostat sampling.
Top K, Nucleus and Locally Typical samplers are ignored if used.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | -| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.1) | -| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.0) | +| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.10) | +| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.00) | | `-l, --logit-bias TOKEN_ID(+/-)BIAS` | modifies the likelihood of token appearing in the completion,
i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',
or `--logit-bias 15043-1` to decrease likelihood of token ' Hello' | | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | | `--grammar-file FNAME` | file to read grammar from | @@ -251,8 +253,8 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 | `--jinja, --no-jinja` | whether to use jinja template engine for chat (default: disabled)
(env: LLAMA_ARG_JINJA) | | `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `` tags in `message.content` while also populating `message.reasoning_content`
(default: auto)
(env: LLAMA_ARG_THINK) | | `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | -| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | -| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | +| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--simple-io` | use basic IO for better compatibility in subprocesses and limited consoles | diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index a9eda119d72..977132756f7 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -342,44 +342,51 @@ int main(int argc, char ** argv) { return 1; } - // debug message about similarity of saved session, if applicable - size_t n_matching_session_tokens = 0; - if (!session_tokens.empty()) { - for (llama_token id : session_tokens) { - if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) { - break; - } - n_matching_session_tokens++; - } - if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { - LOG_INF("%s: using full prompt from session file\n", __func__); - } else if (n_matching_session_tokens >= embd_inp.size()) { - LOG_INF("%s: session file has exact match for prompt!\n", __func__); - } else if (n_matching_session_tokens < (embd_inp.size() / 2)) { - LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n", - __func__, n_matching_session_tokens, embd_inp.size()); - } else { - LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n", - __func__, n_matching_session_tokens, embd_inp.size()); - } + bool session_do_save = false; - // remove any "future" tokens that we might have inherited from the previous session - if (!llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1)) { - LOG_INF("%s: unable to resuse common prefix\n", __func__); - n_matching_session_tokens = 0; - llama_memory_seq_rm(mem, -1, -1, -1); - } - } + { + size_t n_match = 0; + + if (!session_tokens.empty()) { + for (llama_token id : session_tokens) { + if (n_match >= embd_inp.size() || id != embd_inp[n_match]) { + break; + } + n_match++; + } + if (params.prompt.empty() && n_match == embd_inp.size()) { + LOG_INF("%s: using full prompt from session file\n", __func__); + } else if (n_match >= embd_inp.size()) { + LOG_INF("%s: session file has exact match for prompt!\n", __func__); + } else if (n_match < (embd_inp.size() / 2)) { + LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n", + __func__, n_match, embd_inp.size()); + } else { + LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n", + __func__, n_match, embd_inp.size()); + } - LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", - embd_inp.size(), n_matching_session_tokens, embd_inp.size(), session_tokens.size()); + if (session_tokens.size() == n_match) { + // [TAG_CONTEXT_STATE_LOGITS] + // in this case, we are going to reuse the logits from the session + // if we ever decide to remove the logits from the session, we need to handle this somehow + // ref: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941 + } - // if we will use the cache for the full prompt without reaching the end of the cache, force - // reevaluation of the last token to recalculate the cached logits - if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) { - LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1); + // remove any "future" tokens that we might have inherited from the previous session + if (session_tokens.size() > n_match) { + if (!llama_memory_seq_rm(mem, -1, n_match, -1)) { + LOG_WRN("%s: unable to resuse common prefix (for example, when the memory is recurrent)\n", __func__); + llama_memory_clear(mem, true); + session_tokens.clear(); + n_match = 0; + } else { + session_tokens.resize(n_match); + } + } + } - session_tokens.resize(embd_inp.size() - 1); + session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro; } // number of tokens to keep when resetting context @@ -521,10 +528,9 @@ int main(int argc, char ** argv) { is_interacting = params.interactive_first; } - bool is_antiprompt = false; - bool input_echo = true; - bool display = true; - bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size(); + bool is_antiprompt = false; + bool input_echo = true; + bool display = true; int n_past = 0; int n_remain = params.n_predict; @@ -668,15 +674,12 @@ int main(int argc, char ** argv) { } } - for (int i = 0; i < (int) embd.size(); i += params.n_batch) { - int n_eval = (int) embd.size() - i; - if (n_eval > params.n_batch) { - n_eval = params.n_batch; - } - + if (!embd.empty()) { + int n_eval = (int) embd.size(); LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { + GGML_ASSERT(n_eval <= params.n_batch); + if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } @@ -700,8 +703,8 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { // optionally save the session on first sample (for faster prompt loading next time) - if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) { - need_to_save_session = false; + if (session_do_save) { + session_do_save = false; llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); LOG_DBG("saved session to %s\n", path_session.c_str()); @@ -737,7 +740,7 @@ int main(int argc, char ** argv) { common_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false); ++n_consumed; - if ((int) embd.size() >= params.n_batch) { + if ((int) embd.size() == params.n_batch) { break; } } diff --git a/tools/cvector-generator/pca.hpp b/tools/cvector-generator/pca.hpp index e88bbdde93f..afd3bf6380e 100644 --- a/tools/cvector-generator/pca.hpp +++ b/tools/cvector-generator/pca.hpp @@ -290,7 +290,7 @@ static void power_iteration( ggml_gallocr_free(allocr); // TODO @ngxson : The output vector is randomly inverted - // Solution: https://github.com/ggerganov/llama.cpp/pull/8069#issuecomment-2185328171 + // Solution: https://github.com/ggml-org/llama.cpp/pull/8069#issuecomment-2185328171 } static void run_pca( diff --git a/tools/export-lora/export-lora.cpp b/tools/export-lora/export-lora.cpp index f038019b007..41f426208f8 100644 --- a/tools/export-lora/export-lora.cpp +++ b/tools/export-lora/export-lora.cpp @@ -190,7 +190,7 @@ struct lora_merge_ctx { gguf_set_val_u32(ctx_out, "general.file_type", LLAMA_FTYPE_MOSTLY_F16); // check if all lora adapters have the same tensors - // TODO: remove this when we can support merging subset of adapters. Ref: https://github.com/ggerganov/llama.cpp/pull/8607#discussion_r1686027777 + // TODO: remove this when we can support merging subset of adapters. Ref: https://github.com/ggml-org/llama.cpp/pull/8607#discussion_r1686027777 static const char * err_no_subset_adapter = "Input adapters do not have the same list of tensors. This is not yet supported. Please merge the adapter one-by-one instead of merging all at once."; if (adapters.size() > 1) { for (size_t i = 1; i < adapters.size(); ++i) { diff --git a/tools/fit-params/fit-params.cpp b/tools/fit-params/fit-params.cpp index f9d9cb34c7d..0176be06e78 100644 --- a/tools/fit-params/fit-params.cpp +++ b/tools/fit-params/fit-params.cpp @@ -36,7 +36,7 @@ int main(int argc, char ** argv) { LOG_INF("%s: printing fitted CLI arguments to stdout...\n", __func__); common_log_flush(common_log_main()); - printf("-c %" PRIu32 " -ngl %" PRIu32, cparams.n_ctx, mparams.n_gpu_layers); + printf("-c %" PRIu32 " -ngl %" PRIi32, cparams.n_ctx, mparams.n_gpu_layers); size_t nd = llama_max_devices(); while (nd > 1 && mparams.tensor_split[nd - 1] == 0.0f) { diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 751440af323..56ac7c377eb 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(mtmd models/glm4v.cpp models/internvl.cpp models/kimivl.cpp + models/jinaclip2.cpp models/llama4.cpp models/llava.cpp models/minicpmv.cpp diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index 4c7f7504cfc..62145d6dc98 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -31,6 +31,7 @@ struct clip_graph { const float eps; const float kq_scale; const clip_flash_attn_type flash_attn_type; + norm_type block_norm_t = NORM_TYPE_NORMAL; ggml_context_ptr ctx0_ptr; ggml_context * ctx0; diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index dd693623a26..f001bdfb16e 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -36,6 +36,8 @@ // vision-specific #define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities #define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_IMAGE_MIN_PIXELS "clip.vision.image_min_pixels" +#define KEY_IMAGE_MAX_PIXELS "clip.vision.image_max_pixels" #define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" @@ -44,6 +46,7 @@ #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers" +#define KEY_VISION_ROPE_THETA "clip.vision.rope_theta" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" @@ -75,6 +78,7 @@ #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" #define TN_ATTN_V "%s.blk.%d.attn_v.%s" #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" +#define TN_ATTN_LN "%s.blk.%d.attn_ln.%s" #define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s" #define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" @@ -225,6 +229,7 @@ enum projector_type { PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_MUSIC_FLAMINGO, + PROJECTOR_TYPE_JINACLIP2, // JinaCLIP v2 PROJECTOR_TYPE_LFM2, PROJECTOR_TYPE_KIMIVL, PROJECTOR_TYPE_LIGHTONOCR, @@ -261,6 +266,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_KIMIVL, "kimivl"}, { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, + { PROJECTOR_TYPE_JINACLIP2, "jinaclip2"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, { PROJECTOR_TYPE_LFM2A, "lfm2a"}, diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index d4ff9151bb0..4c01a8ad544 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -117,6 +117,9 @@ struct clip_layer { ggml_tensor * k_norm = nullptr; ggml_tensor * q_norm = nullptr; + ggml_tensor * attn_out_norm_w = nullptr; + ggml_tensor * attn_out_norm_b = nullptr; + // layernorm 1 ggml_tensor * ln_1_w = nullptr; ggml_tensor * ln_1_b = nullptr; @@ -125,6 +128,8 @@ struct clip_layer { ggml_tensor * ff_up_b = nullptr; ggml_tensor * ff_gate_w = nullptr; ggml_tensor * ff_gate_b = nullptr; + ggml_tensor * ffn_hidden_norm_w = nullptr; + ggml_tensor * ffn_hidden_norm_b = nullptr; ggml_tensor * ff_down_w = nullptr; ggml_tensor * ff_down_b = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9b076e0c562..611259776fc 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -292,6 +292,8 @@ ggml_tensor * clip_graph::build_vit( ggml_tensor * learned_pos_embd, std::function add_pos ) { + block_norm_t = norm_t; + if (learned_pos_embd) { inp = ggml_add(ctx0, inp, learned_pos_embd); cb(inp, "pos_embed", -1); @@ -489,7 +491,6 @@ ggml_tensor * clip_graph::build_norm( cur = ggml_add(ctx0, cur, mb); cb(cur, "norm_b", il); } - return cur; } @@ -560,6 +561,14 @@ ggml_tensor * clip_graph::build_ffn( } break; } + if (il >= 0 && il < (int) model.layers.size()) { + const auto & layer = model.layers[il]; + if (layer.ffn_hidden_norm_w) { + cur = build_norm(cur, layer.ffn_hidden_norm_w, layer.ffn_hidden_norm_b, block_norm_t, eps, il); + cb(cur, "ffn_hidden_normed", il); + } + } + if (down) { cur = ggml_mul_mat(ctx0, down, cur); } @@ -629,6 +638,14 @@ ggml_tensor * clip_graph::build_attn( cb(cur, "kqv_out", il); + if (il >= 0 && il < (int) model.layers.size()) { + const auto & layer = model.layers[il]; + if (layer.attn_out_norm_w) { + cur = build_norm(cur, layer.attn_out_norm_w, layer.attn_out_norm_b, block_norm_t, eps, il); + cb(cur, "kqv_out_normed", il); + } + } + if (wo) { cur = ggml_mul_mat(ctx0, wo, cur); } @@ -813,6 +830,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_JINACLIP2: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_QWEN2A: @@ -1005,6 +1026,8 @@ struct clip_model_loader { hparams.minicpmv_query_num = 64; } else if (hparams.minicpmv_version == 6) { hparams.minicpmv_query_num = 64; + } else if (hparams.minicpmv_version == 100045) { + hparams.minicpmv_query_num = 64; } else { hparams.minicpmv_query_num = 96; } @@ -1198,6 +1221,11 @@ struct clip_model_loader { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); set_llava_uhd_res_candidates(model, 3); } break; + case PROJECTOR_TYPE_JINACLIP2: + { + hparams.rope_theta = 10000.0f; + get_f32(KEY_VISION_ROPE_THETA, hparams.rope_theta, false); + } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_GLMA: @@ -1356,6 +1384,7 @@ struct clip_model_loader { layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false); layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false); layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false); + layer.attn_out_norm_w = get_tensor(string_format(TN_ATTN_LN, prefix, il, "weight"), false); layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false); layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false); layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias @@ -1366,6 +1395,7 @@ struct clip_model_loader { layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false); layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false); layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false); + layer.attn_out_norm_b = get_tensor(string_format(TN_ATTN_LN, prefix, il, "bias"), false); layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false); layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false); @@ -1374,6 +1404,8 @@ struct clip_model_loader { layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false); layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false); layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false); + layer.ffn_hidden_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"), false); + layer.ffn_hidden_norm_b = get_tensor(string_format(TN_FFN_NORM, prefix, il, "bias"), false); layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight")); layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false); @@ -1783,6 +1815,9 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); } break; + case PROJECTOR_TYPE_JINACLIP2: + { + } break; case PROJECTOR_TYPE_LFM2A: { for (int i : {0, 2, 3, 5, 6}) { @@ -3018,6 +3053,41 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->grid_y = inst.grid_size.height; } break; + case PROJECTOR_TYPE_JINACLIP2: + { + clip_image_u8 processed_image; + const int sz = params.image_size; + + const int in_w = img->nx; + const int in_h = img->ny; + if (in_w <= 0 || in_h <= 0) { + LOG_ERR("%s: invalid input image size %dx%d\n", __func__, in_w, in_h); + return false; + } + + int out_w = 0, out_h = 0; + if (in_w < in_h) { + out_w = sz; + out_h = std::max(1, (int) std::round((double) in_h * sz / in_w)); + } else { + out_h = sz; + out_w = std::max(1, (int) std::round((double) in_w * sz / in_h)); + } + + clip_image_u8 resized_keep_ratio; + img_tool::resize(*img, resized_keep_ratio, clip_image_size{out_w, out_h}, img_tool::RESIZE_ALGO_BICUBIC); + + const int x0 = std::max(0, (resized_keep_ratio.nx - sz) / 2); + const int y0 = std::max(0, (resized_keep_ratio.ny - sz) / 2); + const int crop_w = std::min(sz, resized_keep_ratio.nx); + const int crop_h = std::min(sz, resized_keep_ratio.ny); + img_tool::crop(resized_keep_ratio, processed_image, x0, y0, crop_w, crop_h); + + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(processed_image, *img_f32, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(img_f32)); + } break; + case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: { @@ -3181,6 +3251,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { // do nothing } break; + case PROJECTOR_TYPE_JINACLIP2: + { + n_patches = 1; + } break; case PROJECTOR_TYPE_LDP: case PROJECTOR_TYPE_LDPV2: case PROJECTOR_TYPE_GLM_EDGE: @@ -3209,6 +3283,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } else if (params.minicpmv_version == 6) { // MiniCPM-V 4.5 n_patches = 64; + } else if (params.minicpmv_version == 100045) { + // MiniCPM-o 4.5 + n_patches = 64; } else { GGML_ABORT("Unknown minicpmv version"); } @@ -3608,6 +3685,54 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } set_input_i32("positions", positions); } break; + case PROJECTOR_TYPE_JINACLIP2: + { + std::vector positions(n_pos); + for (int i = 0; i < n_pos; i++) { + positions[i] = i; + } + set_input_i32("positions", positions); + + const int n_patches = n_pos - 1; + const int n_patches_per_col = image_size_width / patch_size; + + std::vector pos_data(n_pos, 0); + + for (int i = 0; i < n_patches; ++i) { + const int idx = i + 1; + pos_data[idx] = i / n_patches_per_col; + } + set_input_i32("pos_h", pos_data); + + std::fill(pos_data.begin(), pos_data.end(), 0); + for (int i = 0; i < n_patches; ++i) { + const int idx = i + 1; + pos_data[idx] = i % n_patches_per_col; + } + set_input_i32("pos_w", pos_data); + + int pt_seq_len = 16; + if (patch_size > 0) { + const int cand = (int) llroundf(224.0f / (float) patch_size); + if (cand > 0) { + pt_seq_len = cand; + } + } + const float s = (float) pt_seq_len / (float) n_patches_per_col; + const int d_head_local = hparams.n_embd / hparams.n_head; + const int half_local = d_head_local / 2; + std::vector rope_c_first(half_local); + std::vector rope_c_second(half_local); + const float odd = std::pow(hparams.rope_theta, (float) -2.0f / (float) d_head_local); + + for (int k = 0; k < half_local; ++k) { + rope_c_first[k] = 1.0f / s; + rope_c_second[k] = 1.0f / (s * odd); + } + + set_input_f32("rope_c_first", rope_c_first); + set_input_f32("rope_c_second", rope_c_second); + } break; case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: case PROJECTOR_TYPE_LDP: @@ -3732,6 +3857,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: return ctx->model.mm_2_w->ne[1]; + case PROJECTOR_TYPE_JINACLIP2: + return ctx->model.hparams.projection_dim; case PROJECTOR_TYPE_MLP_NORM: return ctx->model.mm_3_b->ne[0]; case PROJECTOR_TYPE_MINICPMV: diff --git a/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py index bb2cc4e4ea5..944037e703e 100644 --- a/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +++ b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py @@ -501,7 +501,7 @@ def bytes_to_unicode(): default_image_std = [0.5, 0.5, 0.5] ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) -ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4; MiniCPM-V 4.0 use 5; MiniCPM-o-4.0 use 6', default=2) +ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4; MiniCPM-V 4.0 use 5; MiniCPM-o-4.0 use 6; MiniCPM-o-4.5 use 100045', default=2) # with proper args = ap.parse_args() @@ -610,6 +610,9 @@ def bytes_to_unicode(): elif minicpmv_version == 6: emb_dim = 4096 block_count = 27 + elif minicpmv_version == 100045: + emb_dim = 4096 + block_count = 27 default_vision_config = { "hidden_size": 1152, @@ -637,6 +640,10 @@ def bytes_to_unicode(): default_vision_config["model_type"] = "siglip_vision_model" vision_config = SiglipVisionConfig(**default_vision_config) model = SiglipVisionTransformer(vision_config) +elif minicpmv_version == 100045: + default_vision_config["model_type"] = "siglip_vision_model" + vision_config = SiglipVisionConfig(**default_vision_config) + model = SiglipVisionTransformer(vision_config) processor = None # if model.attn_pool is not None: diff --git a/tools/mtmd/models/jinaclip2.cpp b/tools/mtmd/models/jinaclip2.cpp new file mode 100644 index 00000000000..755ed9fa0f8 --- /dev/null +++ b/tools/mtmd/models/jinaclip2.cpp @@ -0,0 +1,110 @@ +#include "models.h" + +ggml_cgraph * clip_graph_jinaclip2::build() { + GGML_ASSERT(model.class_embedding != nullptr && "JinaCLIP2 requires a CLS token"); + + const int n_pos = n_patches + 1; + + GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported"); + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + GGML_ASSERT(d_head % 2 == 0); + ggml_tensor * rope_c_first = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_head / 2); + ggml_set_name(rope_c_first, "rope_c_first"); + ggml_set_input(rope_c_first); + + ggml_tensor * rope_c_second = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_head / 2); + ggml_set_name(rope_c_second, "rope_c_second"); + ggml_set_input(rope_c_second); + + ggml_tensor * inp = build_inp(); + inp = ggml_concat(ctx0, model.class_embedding, inp, 1); + inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions)); + + auto apply_rope_2d = [&](ggml_tensor * cur) -> ggml_tensor * { + + ggml_tensor * cur_in = ggml_permute(ctx0, cur, 0, 2, 1, 3); + + const int64_t n_dim = cur_in->ne[0]; + const int64_t seq = cur_in->ne[1]; + const int64_t nhead = cur_in->ne[2]; + GGML_ASSERT(seq == n_pos); + GGML_ASSERT(n_dim % 2 == 0); + + const int64_t half = n_dim / 2; + + ggml_tensor * cls = ggml_view_3d(ctx0, cur_in, n_dim, 1, nhead, cur_in->nb[1], cur_in->nb[2], 0); + ggml_tensor * patches = ggml_view_3d(ctx0, cur_in, n_dim, seq - 1, nhead, cur_in->nb[1], cur_in->nb[2], cur_in->nb[1]); + const int64_t n_pos_patches = seq - 1; + const int64_t pos_offset = 1; + + // select positions + ggml_tensor * pos_a = ggml_view_1d(ctx0, pos_h, n_pos_patches, pos_offset * (int64_t) ggml_element_size(pos_h)); + ggml_tensor * pos_b = ggml_view_1d(ctx0, pos_w, n_pos_patches, pos_offset * (int64_t) ggml_element_size(pos_w)); + + ggml_tensor * first = ggml_view_3d(ctx0, patches, + half, nhead, n_pos_patches, + patches->nb[2], patches->nb[1], 0); + ggml_tensor * first_rot = ggml_rope_ext( + ctx0, + first, + pos_a, + rope_c_first, + half, + 0, 0, hparams.rope_theta, + 1.0f, + 0.0f, 1.0f, 0.0f, 0.0f); + first = ggml_view_3d(ctx0, first_rot, + half, n_pos_patches, nhead, + first_rot->nb[2], first_rot->nb[1], 0); + + ggml_tensor * second = ggml_view_3d(ctx0, patches, + half, nhead, n_pos_patches, + patches->nb[2], patches->nb[1], + half * (int64_t) ggml_element_size(patches)); + ggml_tensor * second_rot = ggml_rope_ext( + ctx0, + second, + pos_b, + rope_c_second, + half, + 0, 0, hparams.rope_theta, + 1.0f, + 0.0f, 1.0f, 0.0f, 0.0f); + second = ggml_view_3d(ctx0, second_rot, + half, n_pos_patches, nhead, + second_rot->nb[2], second_rot->nb[1], 0); + + ggml_tensor * patches_out = ggml_concat(ctx0, first, second, 0); + ggml_tensor * out_seq = ggml_concat(ctx0, cls, patches_out, 1); + return ggml_permute(ctx0, out_seq, 0, 2, 1, 3); + }; + + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return apply_rope_2d(cur); + }; + + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + nullptr, + add_pos); + + ggml_tensor * cls = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], 0); + ggml_set_name(cls, "cls_view"); + ggml_build_forward_expand(gf, cls); + + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 9970980c7bc..d960673afb1 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -52,6 +52,11 @@ struct clip_graph_kimivl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_jinaclip2 : clip_graph { + clip_graph_jinaclip2(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_cogvlm : clip_graph { clip_graph_cogvlm(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 054c7faa6af..6dd25540ca0 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -40,7 +40,8 @@ static void show_additional_info(int /*argc*/, char ** argv) { LOG( "Experimental CLI for multimodal\n\n" "Usage: %s [options] -m --mmproj --image --audio