diff --git a/.github/workflows/build-wheels-release.yml b/.github/workflows/build-wheels-release.yml index bf914c00..e23ba2e3 100644 --- a/.github/workflows/build-wheels-release.yml +++ b/.github/workflows/build-wheels-release.yml @@ -215,19 +215,19 @@ jobs: run: | mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" choco install unzip -y - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-11.8.89-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-11.8.89-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-11.8.89-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-11.11.3.6-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-11.8.86-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-11.8.86-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-11.8.86-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-11.8.87-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-11.8.89-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusparse/windows-x86_64/libcusparse-windows-x86_64-11.7.5.86-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusolver/windows-x86_64/libcusolver-windows-x86_64-11.4.1.48-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcurand/windows-x86_64/libcurand-windows-x86_64-10.3.0.86-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcufft/windows-x86_64/libcufft-windows-x86_64-10.9.0.58-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-11.8.89-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-11.8.89-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-11.8.89-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-11.11.3.6-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-11.8.86-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-11.8.86-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-11.8.86-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-11.8.87-archive.zip" + curl -fL -O 
"https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-11.8.89-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusparse/windows-x86_64/libcusparse-windows-x86_64-11.7.5.86-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusolver/windows-x86_64/libcusolver-windows-x86_64-11.4.1.48-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcurand/windows-x86_64/libcurand-windows-x86_64-10.3.0.86-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcufft/windows-x86_64/libcufft-windows-x86_64-10.9.0.58-archive.zip" unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\cuda_cudart-windows-x86_64-11.8.89-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\cuda_nvcc-windows-x86_64-11.8.89-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" /E /I /H /Y @@ -246,25 +246,25 @@ jobs: echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 echo "CUDA_PATH_V11_8=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - + - name: Install Windows CUDA 12.1 if: runner.os == 'Windows' && contains(matrix.cuda, '12.1') run: | mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1" choco install unzip -y - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.1.105-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.1.105-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.1.105-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.1.3.1-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.1.105-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.1.105-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.1.105-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.1.105-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.1.109-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusparse/windows-x86_64/libcusparse-windows-x86_64-12.1.0.106-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusolver/windows-x86_64/libcusolver-windows-x86_64-11.4.5.107-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcurand/windows-x86_64/libcurand-windows-x86_64-10.3.2.106-archive.zip" - 
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcufft/windows-x86_64/libcufft-windows-x86_64-11.0.2.54-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.1.105-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.1.105-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.1.105-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.1.3.1-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.1.105-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.1.105-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.1.105-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.1.105-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.1.109-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusparse/windows-x86_64/libcusparse-windows-x86_64-12.1.0.106-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusolver/windows-x86_64/libcusolver-windows-x86_64-11.4.5.107-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcurand/windows-x86_64/libcurand-windows-x86_64-10.3.2.106-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcufft/windows-x86_64/libcufft-windows-x86_64-11.0.2.54-archive.zip" unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1" xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\cuda_cudart-windows-x86_64-12.1.105-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\cuda_nvcc-windows-x86_64-12.1.105-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1" /E /I /H /Y @@ -289,19 +289,19 @@ jobs: run: | mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" choco install unzip -y - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.4.131-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.4.5.8-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.4.127-archive.zip" - curl -O 
"https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusparse/windows-x86_64/libcusparse-windows-x86_64-12.3.1.170-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusolver/windows-x86_64/libcusolver-windows-x86_64-11.6.1.9-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcurand/windows-x86_64/libcurand-windows-x86_64-10.3.5.147-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcufft/windows-x86_64/libcufft-windows-x86_64-11.2.1.3-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.4.127-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.4.131-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.4.127-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.4.5.8-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.4.127-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.4.127-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.4.127-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.4.127-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.4.127-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusparse/windows-x86_64/libcusparse-windows-x86_64-12.3.1.170-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusolver/windows-x86_64/libcusolver-windows-x86_64-11.6.1.9-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcurand/windows-x86_64/libcurand-windows-x86_64-10.3.5.147-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcufft/windows-x86_64/libcufft-windows-x86_64-11.2.1.3-archive.zip" unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cudart-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvcc-windows-x86_64-12.4.131-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y @@ -326,19 +326,19 @@ jobs: run: | mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8" choco install unzip -y - curl -O 
"https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.8.57-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.8.61-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.8.61-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.8.3.14-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.8.55-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.8.55-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.8.55-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.8.57-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.8.55-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusparse/windows-x86_64/libcusparse-windows-x86_64-12.5.7.53-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusolver/windows-x86_64/libcusolver-windows-x86_64-11.7.2.55-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcurand/windows-x86_64/libcurand-windows-x86_64-10.3.9.55-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcufft/windows-x86_64/libcufft-windows-x86_64-11.3.3.41-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.8.57-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.8.61-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.8.61-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.8.3.14-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.8.55-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.8.55-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.8.55-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.8.57-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.8.55-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusparse/windows-x86_64/libcusparse-windows-x86_64-12.5.7.53-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcusolver/windows-x86_64/libcusolver-windows-x86_64-11.7.2.55-archive.zip" + curl -fL -O 
"https://developer.download.nvidia.com/compute/cuda/redist/libcurand/windows-x86_64/libcurand-windows-x86_64-10.3.9.55-archive.zip" + curl -fL -O "https://developer.download.nvidia.com/compute/cuda/redist/libcufft/windows-x86_64/libcufft-windows-x86_64-11.3.3.41-archive.zip" unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8" xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\cuda_cudart-windows-x86_64-12.8.57-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8" /E /I /H /Y xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\cuda_nvcc-windows-x86_64-12.8.61-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8" /E /I /H /Y @@ -357,7 +357,7 @@ jobs: echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 echo "CUDA_PATH_V12_8=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - + # TODO: Find specific sub-packages - name: Install Linux CUDA ${{ matrix.cuda }} uses: Jimver/cuda-toolkit@v0.2.23 diff --git a/.github/workflows/build-wheels-release_torch27_only.yml b/.github/workflows/build-wheels-release_torch27_only.yml index a17a5b80..6f303660 100644 --- a/.github/workflows/build-wheels-release_torch27_only.yml +++ b/.github/workflows/build-wheels-release_torch27_only.yml @@ -68,58 +68,58 @@ jobs: # Windows 2022 CUDA # Python 3.10 - - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '11.8.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.1.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '11.8.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.1.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '11.8.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.1.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '11.8.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.4.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.8.1', rocm: '', torch: '2.7.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX' } - - # Python 3.11 - - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '11.8.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.1.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '11.8.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.1.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 
8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '11.8.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.1.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '11.8.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.4.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.8.1', rocm: '', torch: '2.7.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX' } - - # Python 3.12 - - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '11.8.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.1.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '11.8.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.1.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '11.8.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.1.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '11.8.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.4.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.8.1', rocm: '', torch: '2.7.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX' } - - # Python 3.13 - - { artname: 'wheel', os: windows-2022, pyver: '3.13', cuda: '11.8.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.13', cuda: '12.4.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } - - { artname: 'wheel', os: windows-2022, pyver: '3.13', cuda: '12.8.1', rocm: '', torch: '2.7.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '11.8.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.1.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '11.8.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.1.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '11.8.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.1.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: 
windows-2022, pyver: '3.10', cuda: '11.8.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.4.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.10', cuda: '12.8.1', rocm: '', torch: '2.7.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX' } + + # # Python 3.11 + # - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '11.8.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.1.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '11.8.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.1.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '11.8.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.1.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '11.8.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.4.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.11', cuda: '12.8.1', rocm: '', torch: '2.7.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX' } + + # # Python 3.12 + # - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '11.8.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.1.0', rocm: '', torch: '2.3.1', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '11.8.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.1.0', rocm: '', torch: '2.4.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '11.8.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.1.0', rocm: '', torch: '2.5.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '11.8.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.4.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.12', cuda: '12.8.1', rocm: '', torch: '2.7.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX' } + + # # Python 3.13 + # - { artname: 'wheel', os: windows-2022, pyver: '3.13', cuda: '11.8.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 'wheel', os: windows-2022, pyver: '3.13', cuda: '12.4.0', rocm: '', torch: '2.6.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX' } + # - { artname: 
'wheel', os: windows-2022, pyver: '3.13', cuda: '12.8.1', rocm: '', torch: '2.7.0', cudaarch: '6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX' } # # Ubuntu 20.04 ROCm - # # ROCm 5.6 - # - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.10', cuda: '', rocm: '5.6', torch: '2.2.2', cudaarch: '' } - # - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.11', cuda: '', rocm: '5.6', torch: '2.2.2', cudaarch: '' } + # ROCm 5.6 + - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.10', cuda: '', rocm: '5.6', torch: '2.2.2', cudaarch: '' } + - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.11', cuda: '', rocm: '5.6', torch: '2.2.2', cudaarch: '' } - # # ROCm 6.0 - # - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.10', cuda: '', rocm: '6.0', torch: '2.3.1', cudaarch: '' } - # - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.11', cuda: '', rocm: '6.0', torch: '2.3.1', cudaarch: '' } - # - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.12', cuda: '', rocm: '6.0', torch: '2.3.1', cudaarch: '' } + # ROCm 6.0 + - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.10', cuda: '', rocm: '6.0', torch: '2.3.1', cudaarch: '' } + - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.11', cuda: '', rocm: '6.0', torch: '2.3.1', cudaarch: '' } + - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.12', cuda: '', rocm: '6.0', torch: '2.3.1', cudaarch: '' } - # # ROCm 6.1 - # - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.10', cuda: '', rocm: '6.1', torch: '2.4.0', cudaarch: '' } - # - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.11', cuda: '', rocm: '6.1', torch: '2.4.0', cudaarch: '' } - # - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.12', cuda: '', rocm: '6.1', torch: '2.4.0', cudaarch: '' } + # ROCm 6.1 + - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.10', cuda: '', rocm: '6.1', torch: '2.4.0', cudaarch: '' } + - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.11', cuda: '', rocm: '6.1', torch: '2.4.0', cudaarch: '' } + - { artname: 'wheel', os: ubuntu-22.04, pyver: '3.12', cuda: '', rocm: '6.1', torch: '2.4.0', cudaarch: '' } # # sdist # - { artname: 'sdist', os: ubuntu-22.04, pyver: '3.11', cuda: '', rocm: '', torch: '2.3.1', cudaarch: '' } diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 300fe0ca..f318dd42 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -53,6 +53,10 @@ ["block_sparse_moe.experts.*.w2"], ["block_sparse_moe.experts.*.w3"], ["block_sparse_moe.gate"]] +layer_keys_qwen3moe_mlp = [["mlp.experts.*.gate_proj"], + ["mlp.experts.*.up_proj"], + ["mlp.experts.*.down_proj"], + ["mlp.gate"]] layer_keys_dbrx_mlp = [["block_sparse_moe.experts.*.v1", "block_sparse_moe.experts.v1"], ["block_sparse_moe.experts.*.w1", "block_sparse_moe.experts.w1"], ["block_sparse_moe.experts.*.w2", "block_sparse_moe.experts.w2"], @@ -428,6 +432,39 @@ class Params: self.lm.attention_bias_qkv = True self.lm.supports_tp = True + # Qwen3 + + if arch_string == "Qwen3ForCausalLM": + arch_recognized = True + self.lm.layer_keys += \ + layer_keys_llama_norms + \ + layer_keys_llama_attn + \ + layer_keys_llama_mlp + self.lm.expect_keys += \ + expect_keys_llama + self.lm.supports_tp = True + self.lm.default_use_qk_norm = True + + # Qwen3MoE + + if arch_string == "Qwen3MoeForCausalLM": + arch_recognized = True + self.lm.layer_keys += \ + layer_keys_llama_norms + \ + layer_keys_llama_attn + \ + layer_keys_qwen3moe_mlp + self.lm.expect_keys += \ + expect_keys_llama + self.lm.supports_tp = True + self.lm.default_use_qk_norm = True + self.lm.keys.update({ + "mlp_gate": 
".mlp.experts.*.gate_proj", + "mlp_up": ".mlp.experts.*.up_proj", + "mlp_down": ".mlp.experts.*.down_proj", + "mlp_expert_gate": ".mlp.gate" + }) + self.lm.is_moe = True + # Qwen2-VL (2, 2.5) if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]: diff --git a/exllamav2/config.py b/exllamav2/config.py index ec5abb84..260f2287 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -319,9 +319,12 @@ def prepare(self, no_tensors: bool = False): default_intermediate_size, opt_subkey = "text_config", ) - self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts"], None) + self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts", "num_experts"], None) self.num_experts_per_token = read(read_config, int,["num_experts_per_tok", "ffn_config->moe_top_k"], None) + if self.arch.lm.is_moe: + self.intermediate_size = read(read_config, int, ["moe_intermediate_size"], self.intermediate_size) + # Logit/embedding/residual scale self.logit_scale = read(read_config, float, "logit_scale", 1) diff --git a/exllamav2/conversion/adaptivegptq.py b/exllamav2/conversion/adaptivegptq.py index 22382ba7..1c8e84a0 100644 --- a/exllamav2/conversion/adaptivegptq.py +++ b/exllamav2/conversion/adaptivegptq.py @@ -229,7 +229,10 @@ def prepare(self, no_h_inv = False): with torch.inference_mode(): - self.hessian /= self.num_batches + if self.hessian is None or self.num_batches == 0: + self.hessian = torch.eye(self.rows, device = self.device, dtype = torch.float) + else: + self.hessian /= self.num_batches diagonal = torch.diag(self.hessian) # Prepare weights diff --git a/exllamav2/exllamav2_ext/cpp/sampling.cpp b/exllamav2/exllamav2_ext/cpp/sampling.cpp index d4b88197..2fc1d99b 100644 --- a/exllamav2/exllamav2_ext/cpp/sampling.cpp +++ b/exllamav2/exllamav2_ext/cpp/sampling.cpp @@ -38,7 +38,7 @@ void apply_rep_penalty_cpu // { // if (g_rep_mask) free(g_rep_mask); // g_vocab_size = vocab_size; -// g_rep_mask = (bool*) malloc(g_vocab_size * sizeof(bool)); +// g_rep_mask = (bool*) calloc(1, g_vocab_size * sizeof(bool)); // } // memset(g_rep_mask, 0, g_vocab_size * sizeof(bool)); bool* g_rep_mask = (bool*) calloc(vocab_size, sizeof(bool)); @@ -655,7 +655,7 @@ int tfs_cpu int nc = sort_descending(num_candidates, temp_probs, temp_indices, num_candidates); - float* derivative = (float*) malloc(nc * sizeof(float)); + float* derivative = (float*) calloc(1, nc * sizeof(float)); float dsum = 0.0f; for (int i = 0; i < nc - 2; i++) { @@ -759,9 +759,9 @@ int typical_cpu int r_candidates = pre_sort_descending(num_candidates, temp_probs, temp_indices); - float* temp = (float*) malloc(r_candidates * sizeof(float)); - int* entropy_dev_order = (int*) malloc(r_candidates * sizeof(int)); - int* temp_indices_2 = (int*) malloc(r_candidates * sizeof(int)); + float* temp = (float*) calloc(1, r_candidates * sizeof(float)); + int* entropy_dev_order = (int*) calloc(1, r_candidates * sizeof(int)); + int* temp_indices_2 = (int*) calloc(1, r_candidates * sizeof(int)); float neg_entropy = 0.0f; for (int i = 0; i < r_candidates; i++) diff --git a/exllamav2/exllamav2_ext/cuda/cache.cu b/exllamav2/exllamav2_ext/cuda/cache.cu index 53ec1cb2..f1a8db32 100644 --- a/exllamav2/exllamav2_ext/cuda/cache.cu +++ b/exllamav2/exllamav2_ext/cuda/cache.cu @@ -1,4 +1,7 @@ #include "cache.cuh" +#include +#include +#include #include "quant/qdq_util.cuh" #include "util.cuh" @@ -165,8 +168,8 @@ __global__ void fp16_to_q_kv_paged_kernel int page = 
diff --git a/exllamav2/conversion/adaptivegptq.py b/exllamav2/conversion/adaptivegptq.py index 22382ba7..1c8e84a0 100644 --- a/exllamav2/conversion/adaptivegptq.py +++ b/exllamav2/conversion/adaptivegptq.py @@ -229,7 +229,10 @@ def prepare(self, no_h_inv = False): with torch.inference_mode(): - self.hessian /= self.num_batches + if self.hessian is None or self.num_batches == 0: + self.hessian = torch.eye(self.rows, device = self.device, dtype = torch.float) + else: + self.hessian /= self.num_batches diagonal = torch.diag(self.hessian) # Prepare weights diff --git a/exllamav2/exllamav2_ext/cpp/sampling.cpp b/exllamav2/exllamav2_ext/cpp/sampling.cpp index d4b88197..2fc1d99b 100644 --- a/exllamav2/exllamav2_ext/cpp/sampling.cpp +++ b/exllamav2/exllamav2_ext/cpp/sampling.cpp @@ -38,7 +38,7 @@ void apply_rep_penalty_cpu // { // if (g_rep_mask) free(g_rep_mask); // g_vocab_size = vocab_size; -// g_rep_mask = (bool*) malloc(g_vocab_size * sizeof(bool)); +// g_rep_mask = (bool*) calloc(1, g_vocab_size * sizeof(bool)); // } // memset(g_rep_mask, 0, g_vocab_size * sizeof(bool)); bool* g_rep_mask = (bool*) calloc(vocab_size, sizeof(bool)); @@ -655,7 +655,7 @@ int tfs_cpu int nc = sort_descending(num_candidates, temp_probs, temp_indices, num_candidates); - float* derivative = (float*) malloc(nc * sizeof(float)); + float* derivative = (float*) calloc(1, nc * sizeof(float)); float dsum = 0.0f; for (int i = 0; i < nc - 2; i++) { @@ -759,9 +759,9 @@ int typical_cpu int r_candidates = pre_sort_descending(num_candidates, temp_probs, temp_indices); - float* temp = (float*) malloc(r_candidates * sizeof(float)); - int* entropy_dev_order = (int*) malloc(r_candidates * sizeof(int)); - int* temp_indices_2 = (int*) malloc(r_candidates * sizeof(int)); + float* temp = (float*) calloc(1, r_candidates * sizeof(float)); + int* entropy_dev_order = (int*) calloc(1, r_candidates * sizeof(int)); + int* temp_indices_2 = (int*) calloc(1, r_candidates * sizeof(int)); float neg_entropy = 0.0f; for (int i = 0; i < r_candidates; i++) diff --git a/exllamav2/exllamav2_ext/cuda/cache.cu b/exllamav2/exllamav2_ext/cuda/cache.cu index 53ec1cb2..f1a8db32 100644 --- a/exllamav2/exllamav2_ext/cuda/cache.cu +++ b/exllamav2/exllamav2_ext/cuda/cache.cu @@ -1,4 +1,7 @@ #include "cache.cuh" +#include <ATen/Tensor.h> +#include <c10/cuda/CUDAGuard.h> +#include <ATen/cuda/CUDAContext.h> #include "quant/qdq_util.cuh" #include "util.cuh" @@ -165,8 +168,8 @@ __global__ void fp16_to_q_kv_paged_kernel int page = block_table[pages_per_seq * y + x]; int seqlen = cache_seqlens[y]; - int vx_a = page_size * x; - int px_a = seqlen - vx_a; + int vx_a = (int64_t)page_size * x; + int px_a = (int64_t)seqlen - vx_a; int px_b = px_a + q_len; if (dim % BLOCKSIZE_Q) @@ -346,7 +349,7 @@ __global__ void q_to_fp16_kv_paged_kernel int seqlen = cache_seqlens[y]; if (!seqlen) return; - int vx_a = page_size * x; + int vx_a = (int64_t)page_size * x; int vx_b = min(vx_a + page_size, seqlen); if (dim < BLOCKSIZE_Q) @@ -492,3 +495,82 @@ void array_q_to_fp16_kv_cuda dim, offset, stride ); } + +#define NUM_THREADS 512 +#define NUM_BLOCKS 128 +#define CEIL_DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +__global__ __launch_bounds__(NUM_THREADS) +void cache_rotate_kernel +( + uint8_t* __restrict__ cache, + const uint32_t* __restrict__ order, + uint8_t* __restrict__ temp, + size_t page_size, + size_t rotate_len +) +{ + // Chunk for current CTA + size_t block_size = CEIL_DIVIDE(page_size, gridDim.x); + size_t block_beg = blockIdx.x * block_size; + size_t block_end = min(block_beg + block_size, page_size); + block_size = block_end - block_beg; + if (!block_size) return; + + // Rotate pages + auto copy = [&](uint8_t* dst, uint8_t* src) + { + for (int offset = threadIdx.x * 16; offset < block_size; offset += NUM_THREADS * 16) + *((uint4*) (dst + offset)) = *((uint4*) (src + offset)); + }; + + int i; + copy(temp + block_beg, cache + page_size * (uint64_t) order[0] + block_beg); + for (i = 0; i < rotate_len - 1; ++i) + copy(cache + page_size * (uint64_t) order[i] + block_beg, cache + page_size * (uint64_t) order[i + 1] + block_beg); + copy(cache + page_size * (uint64_t) order[i] + block_beg, temp + block_beg); +} + +/* +Reorder cache pages +- cache, paged cache, shape (num_pages, ...), any dtype, contiguous +- order, sequence to rotate, shape (n,), dtype int (read as uint32 by the kernel) +- temp, temp storage, sized as one cache page + +Performs: + +temp <- page[order[0]] +for a, b in pairwise(order): + page[a] <- page[b] +page[order[-1]] <- temp +*/ + +void cache_rotate +( + const at::Tensor& cache, + const at::Tensor& order, + const at::Tensor& temp +) +{ + const at::cuda::OptionalCUDAGuard device_guard(cache.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK(cache.dim() > 1, "cache argument must have dim >= 2"); + TORCH_CHECK(order.dim() == 1, "order argument must have dim == 1"); +// TORCH_CHECK_DTYPE(order, kInt); + + size_t num_pages = cache.size(0); + size_t page_size = cache.nbytes() / num_pages; + size_t rotate_len = order.size(0); + + TORCH_CHECK(temp.nbytes() == page_size, "temp tensor incorrect size"); + + cache_rotate_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>> ( (uint8_t*) cache.data_ptr(), (const uint32_t*) order.data_ptr(), (uint8_t*) temp.data_ptr(), page_size, rotate_len ); } diff --git a/exllamav2/exllamav2_ext/cuda/cache.cuh b/exllamav2/exllamav2_ext/cuda/cache.cuh index 4a647c26..aec1e47a 100644 --- a/exllamav2/exllamav2_ext/cuda/cache.cuh +++ b/exllamav2/exllamav2_ext/cuda/cache.cuh @@ -6,6 +6,8 @@ #include <cstdint> #include <cstdio> +#include <ATen/Tensor.h> + void array_fp16_to_fp8_cuda ( cudaStream_t stream, @@ -100,4 +102,11 @@ void array_q_to_fp16_kv_paged_cuda // void array_fp16_to_fp8_ref_cuda(const half* pIn, unsigned char *pOut, int size); // void array_fp8_to_fp16_ref_cuda(const unsigned char* pIn, half* pOut, int size); +void cache_rotate +( + const at::Tensor& cache, + const at::Tensor& order, + const at::Tensor& temp +); + #endif
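For readability, the rotation cache_rotate_kernel performs is equivalent to this host-side Torch sequence (a reference sketch only, not part of the extension; the function name here is hypothetical):

import torch

def cache_rotate_reference(cache: torch.Tensor, order: torch.Tensor) -> None:
    # cache: (num_pages, ...) paged KV cache; order: (n,) page indices, int dtype
    temp = cache[order[0]].clone()      # temp <- page[order[0]]
    idx = order.tolist()
    for a, b in zip(idx, idx[1:]):      # page[a] <- page[b]
        cache[a].copy_(cache[b])
    cache[idx[-1]].copy_(temp)          # page[order[-1]] <- temp

The kernel does the same thing in place: each thread block walks one 16-byte-aligned slice of every page in the cycle, using the page-sized temp buffer to break the loop.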
diff --git a/exllamav2/exllamav2_ext/cuda/q_matrix.cu b/exllamav2/exllamav2_ext/cuda/q_matrix.cu index 40350e87..bac12e7e 100644 --- a/exllamav2/exllamav2_ext/cuda/q_matrix.cu +++ b/exllamav2/exllamav2_ext/cuda/q_matrix.cu @@ -603,9 +603,18 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx, cudaStream_t stream) return false; } + // Zero out the allocated memory + size_t mem_size = (height / 8) * width * sizeof(uint32_t); + err = cudaMemset(cuda_new_qweight, 0, mem_size); + if (err != cudaSuccess) { + printf("CUDA memset failed: %s\n", cudaGetErrorString(err)); + cudaFree(cuda_new_qweight); // Free the allocated memory in case of error + return false; + } + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) calloc(1, height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) calloc(1, height * sizeof(uint32_t)); // Group histogram diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cu b/exllamav2/exllamav2_ext/cuda/q_mlp.cu index 8e9cda2e..f37fd249 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cu +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cu @@ -324,9 +324,15 @@ void QMoEMLP::forward_ // half* lora_temp ) { - if (num_experts != 4 && num_experts != 8 && num_experts != 16) + if (rows > MAX_Q_GEMM_WEIGHTS) { - printf(" ## num_experts must be 4, 8 or 16\n"); + printf(" ## rows > %i not implemented\n", MAX_Q_GEMM_WEIGHTS); + DBGI(rows); + } + + if (num_experts != 4 && num_experts != 8 && num_experts != 16 && num_experts != 128) + { + printf(" ## num_experts must be 4, 8, 16 or 128\n"); return; } @@ -354,36 +360,33 @@ void QMoEMLP::forward_ &beta_, temp_logits, num_experts); - // Compute softmax filter to and normalize top-k outputs + // Select activation kernel - dim3 blockDim, gridDim; - blockDim.x = WARPS; - blockDim.y = 1; - gridDim.x = 1; - gridDim.y = DIVIDE(rows, WARPS); - if (num_experts == 4) - softmax4_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token); - else if (num_experts == 8) - softmax8_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token); - else if (num_experts == 16) - softmax16_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token); + int intermediate_size = w1[0]->width; + fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu); // For small no. rows, execute all kernels but pass the routing weights. Rows with a weight of zero will skip dot // product accum and kernels launched with only zero-weights will exit prematurely.
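Spelled out, the gating the comment describes looks like this (an illustrative sketch of the semantics, not the CUDA implementation; all names below are hypothetical):

import torch

def moe_forward_reference(x, experts, routing_weights):
    # x: (rows, dim); routing_weights: (rows, num_experts), zero outside the top-k
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        w = routing_weights[:, e:e + 1]   # per-row weight passed to each expert GEMM
        out = out + w * expert(x)         # rows with w == 0 contribute nothing
    return out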
- if (rows <= MAX_Q_GEMM_WEIGHTS) + if (num_experts == 4 || num_experts == 8 || num_experts == 16) { - int intermediate_size = w1[0]->width; - fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu); + dim3 blockDim, gridDim; + blockDim.x = WARPSIZE; + blockDim.y = 1; + gridDim.x = 1; + gridDim.y = DIVIDE(rows, WARPSIZE); + if (num_experts == 4) + softmax4_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token); + else if (num_experts == 8) + softmax8_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token); + else if (num_experts == 16) + softmax16_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token); for (int i = 0; i < num_experts; i++) { gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); -// apply_loras_cuda(cublas_handle, w1_lora[i], loras, w1[i], temp_state, temp_a, lora_temp, rows); -// apply_loras_cuda(cublas_handle, w3_lora[i], loras, w3[i], temp_state, temp_b, lora_temp, rows); - blockDim.x = THREADS_X; blockDim.y = THREADS_Y; gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1); @@ -391,17 +394,43 @@ void QMoEMLP::forward_ kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts); gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true); - -// apply_loras_cuda(cublas_handle, w2_lora[i], loras, w2[i], temp_a, x, lora_temp, rows); } - } + } - // Gather larger number of rows in separate batches according to which experts they trigger, evaluate each MLP - // only on the affected rows and scale by routing weights while adding back directly onto the residual hidden state + // For very large number of experts (Qwen3 etc.) copy to CPU, synchronize and only launch top K experts. This is + // not optimal but the kernel launch overhead is very severe otherwise. Really needs a graph - else + else if (num_experts == 128) { - printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS); - DBGI(rows); + dim3 blockDim, gridDim; + blockDim.x = WARPSIZE; + blockDim.y = 1; + gridDim.x = 1; + gridDim.y = DIVIDE(rows, WARPSIZE); + softmax128_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token); + + half* h_logits; + h_logits = (half*) malloc(128 * sizeof(half)); + cudaMemcpyAsync(h_logits, temp_logits, 128 * sizeof(half), cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + for (int i = 0; i < num_experts; i++) { + uint16_t w = *reinterpret_cast<uint16_t*>(&h_logits[i]); + if (!w) continue; + + gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); + gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); + + blockDim.x = THREADS_X; + blockDim.y = THREADS_Y; + gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1); + gridDim.y = DIVIDE(rows, THREADS_Y); + kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts); + + gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true); + } + + free(h_logits); } } diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp_softmax.cuh b/exllamav2/exllamav2_ext/cuda/q_mlp_softmax.cuh index 2cf1b597..224d4a31 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp_softmax.cuh +++ b/exllamav2/exllamav2_ext/cuda/q_mlp_softmax.cuh @@ -1,5 +1,5 @@ -#define WARPS 32 +#define WARPSIZE 32 __global__ void softmax16_topk_norm_kernel ( @@ -8,7 +8,7 @@ __global__ void softmax16_topk_norm_kernel const int topk ) { - int row = blockIdx.y * WARPS + threadIdx.x; + int row = blockIdx.y * WARPSIZE + threadIdx.x; if (row >= rows) return; // Softmax @@ -122,7 +122,7 @@ __global__ void softmax8_topk_norm_kernel const int topk ) { - int row = blockIdx.y * WARPS + threadIdx.x; + int row = blockIdx.y * WARPSIZE + threadIdx.x; if (row >= rows) return; // Softmax @@ -206,7 +206,7 @@ __global__ void softmax4_topk_norm_kernel const int topk ) { - int row = blockIdx.y * WARPS + threadIdx.x; + int row = blockIdx.y * WARPSIZE + threadIdx.x; if (row >= rows) return; // Softmax @@ -268,3 +268,101 @@ __global__ void softmax4_topk_norm_kernel logits_int2.y = l23.as_uint32; *row_ptr = logits_int2; } + +__global__ void softmax128_topk_norm_kernel +( + half* __restrict__ x, + const int rows, + const int topk +) +{ + const int row = blockIdx.y * WARPSIZE + threadIdx.x; + if (row >= rows) return; + + #if defined(USE_ROCM) + float f[128]; + #else + register float f[128]; + #endif + + int4* row_ptr = reinterpret_cast<int4*>(x + row * 128); + + #pragma unroll + for (int v = 0; v < 16; ++v) // 16 × 8 halfs = 128 halfs + { + int4 v4 = row_ptr[v]; + + half2_uint32 h0(v4.x), h1(v4.y), h2(v4.z), h3(v4.w); + + const int base = v * 8; + f[base + 0] = __low2float (h0.as_half2); + f[base + 1] = __high2float(h0.as_half2); + f[base + 2] = __low2float (h1.as_half2); + f[base + 3] = __high2float(h1.as_half2); + f[base + 4] = __low2float (h2.as_half2); + f[base + 5] = __high2float(h2.as_half2); + f[base + 6] = __low2float (h3.as_half2); + f[base + 7] = __high2float(h3.as_half2); + } + + float maxf = -FLT_MAX; + #pragma unroll + for (int i = 0; i < 128; ++i) maxf = fmaxf(maxf, f[i]); + + float sum = 0.f; + #pragma unroll + for (int i = 0; i < 128; ++i) + { + float e = __expf(f[i] - maxf); + f[i] = e; + sum += e; + } + + constexpr float epsilon = 1e-8f; + const float isum = 1.f / (sum + 128.0f * epsilon); + + #pragma unroll + for (int i = 0; i < 128; ++i) f[i] = f[i] * isum + epsilon; + + float remaining = 1.0f; + for (int drop = 0; drop < 128 - topk; ++drop) + { + float minv = 1.0f; + int mini = -1; + #pragma unroll + for (int j = 0; j < 128; ++j) + { + if (f[j] > 0.0f && f[j] < minv) + { + minv = f[j]; + mini = j; + } + } + remaining -= f[mini]; + f[mini] = 0.0f; + } + + const float inv_remaining = 1.f / remaining; + #pragma unroll + for (int i = 0; i < 128; ++i) f[i] *= inv_remaining; + + #pragma unroll + for (int v = 0; v < 16; ++v) + { + const int base = v * 8; + + half2_uint32 h0, h1, h2, h3; + h0.as_half2 = __floats2half2_rn(f[base + 0], f[base + 1]); + h1.as_half2 = __floats2half2_rn(f[base + 2], f[base + 3]); + h2.as_half2 = __floats2half2_rn(f[base + 4], f[base + 5]); + h3.as_half2 = __floats2half2_rn(f[base + 6], f[base + 7]); + + int4 v4; + v4.x = h0.as_uint32; + v4.y = h1.as_uint32; + v4.z = h2.as_uint32; +
v4.w = h3.as_uint32; + + row_ptr[v] = v4; + } +}
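The same math in scalar NumPy form, as a readability aid (a sketch, modulo the kernel's incremental epsilon bookkeeping; not part of the patch):

import numpy as np

def softmax_topk_norm_reference(logits: np.ndarray, topk: int, eps: float = 1e-8) -> np.ndarray:
    # Softmax over the full 128-wide row of router logits
    e = np.exp(logits - logits.max())
    p = e / (e.sum() + logits.size * eps) + eps
    # Drop the (size - topk) smallest probabilities, as the kernel's drop loop does
    p[np.argsort(p)[:logits.size - topk]] = 0.0
    # Renormalize the surviving top-k weights to sum to 1
    return p / p.sum()

The kernel keeps the whole 128-float row in registers, reads and writes it as sixteen int4 vectors, and tracks the surviving probability mass in remaining instead of re-summing.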
diff --git a/exllamav2/exllamav2_ext/cuda/util.cu b/exllamav2/exllamav2_ext/cuda/util.cu index 4f385791..8f7834ae 100644 --- a/exllamav2/exllamav2_ext/cuda/util.cu +++ b/exllamav2/exllamav2_ext/cuda/util.cu @@ -2,7 +2,7 @@ void print_global_mem(const half* ptr, int rows, int columns, int stride) { - half* temp = (half*) malloc(rows * columns * sizeof(half)); + half* temp = (half*) calloc(1, rows * columns * sizeof(half)); cudaDeviceSynchronize(); cudaMemcpyAsync(temp, ptr, rows * columns * sizeof(half), cudaMemcpyDeviceToHost); diff --git a/exllamav2/exllamav2_ext/ext_bindings.cpp b/exllamav2/exllamav2_ext/ext_bindings.cpp index c93ffc2f..ecbb65b5 100644 --- a/exllamav2/exllamav2_ext/ext_bindings.cpp +++ b/exllamav2/exllamav2_ext/ext_bindings.cpp @@ -22,6 +22,8 @@ #include "ext_element.h" #include "ext_tp.h" +#include "cuda/cache.cuh" + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // quant @@ -95,6 +97,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("count_match", &count_match, "count_match"); // m.def("array_fp16_to_fp8_ref", &array_fp16_to_fp8_ref, "array_fp16_to_fp8_ref"); // m.def("array_fp8_to_fp16_ref", &array_fp8_to_fp16_ref, "array_fp8_to_fp16_ref"); + m.def("cache_rotate", &cache_rotate, "cache_rotate"); // hadamard diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 38e7b439..961ca6a9 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -173,6 +173,105 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, uint64_t> ( float norm ) { + // --- Mode-Specific Parameters --- + enum Mode { MODE_RELAXED, MODE_BALANCED, MODE_UNIFORM, MODE_AGGRESSIVE, MODE_CUSTOM }; + // --- Mode Selection --- + Mode mode = MODE_UNIFORM; // Default mode; can be changed to another mode or MODE_CUSTOM + + // Define a struct to hold parameters for different modes + struct ModeParams { + float bpw_penalty_scale; + float min_bpw_base; + float opportunistic_temp; + float error_floor; + float targeted_redistribution_max_err_increase; + float high_bpw_donor_threshold; + float bpw_balance_factor; + float low_error_threshold; + int redistribution_iterations; + int opportunistic_iterations; + int num_options_to_explore_per_layer; + int bpw_smoothing_passes; + float bpw_smoothing_threshold; + float targeted_redistribution_bpw_threshold; + }; + + // Define the parameter sets for each mode + const std::vector<ModeParams> mode_params = { + // MODE_RELAXED: Minimize error first + {0.1f, 3.0f, 0.05f, 0.0f, 1.2f, 5.0f, 0.1f, 0.0009f, 25, 15000, 3, 5, 0.5f, 3.3f}, + + // MODE_BALANCED: Balanced trade-off between BPW uniformity and error + {0.6f, 3.3f, 0.12f, 0.0005f, 1.5f, 5.5f, 1.8f, 0.002f, 50, 30000, 8, 8, 0.75f, 3.6f}, + + // MODE_UNIFORM: Strong emphasis on BPW uniformity + {0.8f, 3.5f, 0.12f, 0.0005f, 1.6f, 6.0f, 3.0f, 0.001f, 80, 40000, 8, 10, 0.8f, 3.7f}, + + // MODE_AGGRESSIVE: Aggressively avoids low BPW, potentially higher error + {1.0f, 3.8f, 0.15f, 0.001f, 1.7f, 6.5f, 4.0f, 0.001f, 100, 50000, 8, 12, 0.9f, 3.9f}, + + // MODE_CUSTOM: User-defined parameters, will be overwritten if using custom mode + {0.8f, 5.0f, 0.12f, 0.0005f, 1.5f, 6.0f, 3.0f, 0.001f, 80, 40000, 8, 10, 0.8f, 5.5f}, + }; + + ModeParams params; + if (mode == MODE_CUSTOM) + { + params = {0.7f, 3.3f, 0.11f, 0.0002f, 1.35f, 5.7f, 2.0f, 0.001f, 70, 35000, 8, 9, 0.8f, 3.6f}; // Example custom parameters, you should change this + } else { + params = mode_params[mode]; + } + + // --- Parameter Application --- + // (Consolidated parameters are grouped together) + + // Penalty-related parameters + const float bpw_penalty_scale = params.bpw_penalty_scale; + const float min_bpw_base = params.min_bpw_base; + const float bpw_balance_factor = params.bpw_balance_factor; + + // Redistribution-related parameters + const int redistribution_iterations = params.redistribution_iterations; + const float targeted_redistribution_bpw_threshold = params.targeted_redistribution_bpw_threshold; + const float targeted_redistribution_max_err_increase = params.targeted_redistribution_max_err_increase; + const float high_bpw_donor_threshold = params.high_bpw_donor_threshold; + const int num_options_to_explore_per_layer = params.num_options_to_explore_per_layer; + + // Opportunistic optimization parameters + const int opportunistic_iterations = params.opportunistic_iterations; + const float initial_opportunistic_temp = params.opportunistic_temp; + const float low_error_threshold = params.low_error_threshold; + + // Other parameters + const float error_floor = params.error_floor; + const int bpw_smoothing_passes = params.bpw_smoothing_passes; + const float bpw_smoothing_threshold = params.bpw_smoothing_threshold; + + // --- Dynamic Minimum BPW --- + auto calculate_dynamic_min_bpw = [&](float target_bpw, float temp_ratio) { + float scaled_min_bpw = min_bpw_base + 0.75f * (target_bpw - min_bpw_base); + return min_bpw_base + temp_ratio * (scaled_min_bpw - min_bpw_base); + }; + + // --- Calculate BPW --- + auto calculate_bpw = [&](const std::tuple<uint64_t, float>& option) { + return 8.0f * std::get<0>(option) / 1024.0f; + }; + + // --- Calculate BPW stats --- + auto calculate_bpw_stats = [&](const std::vector<std::tuple<uint64_t, float>>& sol) { + int num_slots = sol.size(); + std::vector<float> current_bpws(num_slots); + for (int i = 0; i < num_slots; ++i) { + current_bpws[i] = calculate_bpw(sol[i]); + } + float bpw_mean = std::accumulate(current_bpws.begin(), current_bpws.end(), 0.0f) / num_slots; + float bpw_sq_sum = std::inner_product(current_bpws.begin(), current_bpws.end(), current_bpws.begin(), 0.0f); + float bpw_variance = bpw_sq_sum / num_slots - bpw_mean * bpw_mean; + return std::make_pair(bpw_mean, std::sqrt(std::max(0.0f, bpw_variance))); + }; + + // --- Simulated Annealing --- int num_slots = slots.size(); std::random_device rd; @@ -181,48 +280,453 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, uint64_t> std::vector<int> solution_idx(num_slots); uint64_t current_cost = 0; - float current_sum = 0.0; + float current_max_exp_error = 0; - float temp = initial_temp; + float temp = initial_temp * 2.5f; int iterations_outer = static_cast<int>(std::log(min_temp / temp) / std::log(cooling_factor)); - - for (int i = 0; i < num_slots; ++i) - { - solution[i] = slots[i][0]; - current_cost += std::get<0>(slots[i][0]); - current_sum += powf(std::get<1>(slots[i][0]), norm); + float target_bpw = max_cost * 8.0f / 1024.0f / num_slots; + + // --- Balanced Initialization --- + for (int i = 0; i < num_slots; ++i) { + int best_idx = 0; + float best_score = -1e10f; + for (int j = 0; j < slots[i].size(); ++j) { + float bpw = calculate_bpw(slots[i][j]); + float error = std::get<1>(slots[i][j]); + // Favor options with BPW close to target and relatively high error + float score = -std::abs(bpw - target_bpw) + error * bpw_balance_factor; + if (score > best_score) { + best_score = score; + best_idx = j; + } + } + solution[i] = slots[i][best_idx]; + current_cost += std::get<0>(slots[i][best_idx]); + current_max_exp_error = std::max(current_max_exp_error, std::get<1>(slots[i][best_idx])); } - for (int j = 0; j < iterations_outer; ++j) - {
- for (int k = 0; k < iterations; ++k) - { - int i = std::uniform_int_distribution<>(0, num_slots - 1)(gen); // target slot - int n = std::uniform_int_distribution<>(0, slots[i].size() - 1)(gen); // target option + for (int j = 0; j < iterations_outer; ++j) { + float temp_ratio = temp / (initial_temp * 2.5f); + float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); + + for (int k = 0; k < iterations; ++k) { + int i = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + int n = std::uniform_int_distribution<>(0, slots[i].size() - 1)(gen); auto new_option = slots[i][n]; auto old_option = solution[i]; uint64_t delta_cost = std::get<0>(new_option) - std::get<0>(old_option); - float delta_e = powf(std::get<1>(new_option), norm) - powf(std::get<1>(old_option), norm); + float delta_e = std::get<1>(new_option) - std::get<1>(old_option); + + float new_max_exp_error = current_max_exp_error; + if (std::get<1>(old_option) == current_max_exp_error) { + new_max_exp_error = std::get<1>(new_option); + for (int slot_idx = 0; slot_idx < num_slots; slot_idx++) { + if (slot_idx == i) continue; + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[slot_idx])); + } + } else { + new_max_exp_error = std::max(current_max_exp_error, std::get<1>(new_option)); + } + + // Enhanced Dynamic BPW Penalty (applied uniformly to all layers) + float bpw_new = calculate_bpw(new_option); + float bpw_penalty = 0.0f; + if (bpw_new < min_bpw_limit) { + // Clear formula for BPW penalty + bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio) * bpw_balance_factor; + bpw_penalty = bpw_penalty * bpw_penalty; // Squared penalty + } - if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) - { - if (delta_e < 0 || - std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_e / temp)) - { + if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) { + if (delta_e + bpw_penalty < 0 || + std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_e + bpw_penalty) / temp)) { solution[i] = new_option; solution_idx[i] = n; - current_sum += delta_e; current_cost += delta_cost; + current_max_exp_error = new_max_exp_error; } } } temp *= cooling_factor; }
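The acceptance test in the loop above is the standard Metropolis rule with the BPW penalty folded into the energy delta; in sketch form (Python, names hypothetical):

import math, random

def metropolis_accept(delta_e: float, bpw_penalty: float, temp: float) -> bool:
    # Always accept improvements; accept regressions with probability exp(-delta / T)
    delta = delta_e + bpw_penalty
    return delta < 0 or random.random() < math.exp(-delta / temp)

Because temp is multiplied by cooling_factor after every outer iteration, the chance of accepting a worse option decays toward zero and the search settles into a local minimum.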
+    // --- Error-Weighted Bit Redistribution ---
+    for (int r = 0; r < redistribution_iterations; ++r) {
+        float temp_ratio = temp / (initial_temp * 2.5f);
+        float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio);
+
+        // Calculate BPW statistics and a dynamic bpw_threshold
+        auto [bpw_mean, bpw_stddev] = calculate_bpw_stats(solution);
+        float bpw_threshold = std::max(min_bpw_limit, bpw_mean - bpw_stddev * bpw_balance_factor);
+
+        std::vector<int> low_bpw_indices;
+        std::vector<int> high_bpw_indices;
+        std::vector<float> high_bpw_errors;
+
+        for (int i = 0; i < num_slots; ++i) {
+            float bpw = calculate_bpw(solution[i]);
+            if (bpw < bpw_threshold) {
+                low_bpw_indices.push_back(i);
+            }
+            if (bpw > high_bpw_donor_threshold) {
+                high_bpw_indices.push_back(i);
+                high_bpw_errors.push_back(std::get<1>(solution[i]));
+            }
+        }
+
+        if (high_bpw_indices.empty()) continue;
+
+        std::discrete_distribution<int> high_idx_dist(high_bpw_errors.begin(), high_bpw_errors.end());
+
+        bool improved = false;
+        for (int low_idx : low_bpw_indices) {
+            int high_idx = high_bpw_indices[high_idx_dist(gen)];
+
+            int best_low_new_idx = -1;
+            float best_low_new_error = 1e10f;
+            for (int n = 0; n < slots[low_idx].size(); ++n) {
+                if (calculate_bpw(slots[low_idx][n]) > calculate_bpw(solution[low_idx])) {
+                    if (std::get<1>(slots[low_idx][n]) < best_low_new_error) {
+                        best_low_new_error = std::get<1>(slots[low_idx][n]);
+                        best_low_new_idx = n;
+                    }
+                }
+            }
+
+            int best_high_new_idx = -1;
+            float best_high_new_error = 1e10f;
+            for (int n = 0; n < slots[high_idx].size(); ++n) {
+                if (calculate_bpw(slots[high_idx][n]) < calculate_bpw(solution[high_idx])) {
+                    if (std::get<1>(slots[high_idx][n]) < best_high_new_error) {
+                        best_high_new_error = std::get<1>(slots[high_idx][n]);
+                        best_high_new_idx = n;
+                    }
+                }
+            }
+
+            if (best_low_new_idx != -1 && best_high_new_idx != -1) {
+                auto new_low_option = slots[low_idx][best_low_new_idx];
+                auto new_high_option = slots[high_idx][best_high_new_idx];
+
+                uint64_t new_cost = current_cost - std::get<0>(solution[low_idx]) - std::get<0>(solution[high_idx])
+                                  + std::get<0>(new_low_option) + std::get<0>(new_high_option);
+
+                if (new_cost <= max_cost) {
+                    float new_max_exp_error = std::get<1>(new_low_option);
+                    for (int i = 0; i < num_slots; i++) {
+                        if (i == low_idx) continue;
+                        if (i == high_idx) {
+                            new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_high_option));
+                        } else {
+                            new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[i]));
+                        }
+                    }
+
+                    if (std::get<1>(new_low_option) < error_floor || std::get<1>(new_high_option) < error_floor) continue;
+
+                    if (new_max_exp_error < current_max_exp_error * (1 + 0.1f * bpw_balance_factor)) {
+                        solution[low_idx] = new_low_option;
+                        solution_idx[low_idx] = best_low_new_idx;
+                        solution[high_idx] = new_high_option;
+                        solution_idx[high_idx] = best_high_new_idx;
+                        current_cost = new_cost;
+                        current_max_exp_error = new_max_exp_error;
+                        improved = true;
+                    }
+                }
+            }
+        }
+    }
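In the redistribution pass, donor layers are drawn with probability proportional to their current quantization error via `std::discrete_distribution`. The equivalent weighted draw in Python, shown as a sketch rather than part of the patch:

```python
import random

def pick_donor(high_bpw_indices: list[int], high_bpw_errors: list[float]) -> int:
    # Weighted draw: a layer's chance of being picked as donor is
    # proportional to its current error, as with std::discrete_distribution
    return random.choices(high_bpw_indices, weights=high_bpw_errors, k=1)[0]
```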
+    // --- Enhanced Opportunistic Optimization with Simulated Annealing ---
+    float current_sum_log_err = 0;
+    for (int i = 0; i < num_slots; ++i) {
+        current_sum_log_err += log(std::get<1>(solution[i]));
+    }
+
+    float best_sum_log_err = current_sum_log_err;
+    std::vector<std::tuple<uint64_t, float>> best_solution = solution;
+    std::vector<int> best_solution_idx = solution_idx;
+    float local_temp = initial_opportunistic_temp;
+
+    for (int i = 0; i < opportunistic_iterations; ++i) {
+        float temp_ratio = temp / (initial_temp * 2.5f);
+        float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio);
+
+        // Select a slot to adjust
+        int target_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen);
+
+        // Calculate the global average BPW
+        float global_bpw_sum = 0;
+        for (int j = 0; j < num_slots; ++j) {
+            global_bpw_sum += calculate_bpw(solution[j]);
+        }
+        float global_bpw_avg = global_bpw_sum / num_slots;
+
+        // Candidate solution, adjusting the BPW of the target slot
+        std::vector<std::tuple<uint64_t, float>> new_solution = solution;
+        std::vector<int> new_solution_idx = solution_idx;
+        float new_sum_log_err = current_sum_log_err;
+        uint64_t new_cost = current_cost;
+
+        float current_bpw = calculate_bpw(solution[target_slot]);
+
+        // Calculate the average error
+        float avg_error = 0;
+        for (int k = 0; k < num_slots; ++k) {
+            avg_error += std::get<1>(solution[k]);
+        }
+        avg_error /= num_slots;
+
+        // Error ratio for the target slot
+        float error_ratio = std::get<1>(solution[target_slot]) / avg_error;
+
+        // Enhanced adjustment factor, more sensitive to the error ratio
+        float adjustment = 0.5f + 0.5f * error_ratio;
+
+        // Adjust BPW towards the target, weighted by error, with a bias towards higher BPW
+        if (current_bpw < global_bpw_avg + adjustment) {
+            // Search for a higher-BPW option
+            for (int n = 0; n < slots[target_slot].size(); ++n) {
+                auto new_option = slots[target_slot][n];
+                float new_option_bpw = calculate_bpw(new_option);
+                if (new_option_bpw > current_bpw && new_option_bpw <= current_bpw + adjustment) {
+                    if (new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option) <= max_cost) {
+                        if (std::get<1>(new_option) < error_floor) continue;
+                        new_cost = new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option);
+                        new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[target_slot])) + log(std::get<1>(new_option));
+                        new_solution[target_slot] = new_option;
+                        new_solution_idx[target_slot] = n;
+                        break;
+                    }
+                }
+            }
+        } else if (current_bpw > global_bpw_avg) {
+            // Search for a lower-BPW option
+            for (int n = slots[target_slot].size() - 1; n >= 0; --n) {
+                auto new_option = slots[target_slot][n];
+                float new_option_bpw = calculate_bpw(new_option);
+                if (new_option_bpw < current_bpw && new_option_bpw >= current_bpw - adjustment) {
+                    if (new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option) <= max_cost) {
+                        if (std::get<1>(new_option) < error_floor) continue;
+                        new_cost = new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option);
+                        new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[target_slot])) + log(std::get<1>(new_option));
+                        new_solution[target_slot] = new_option;
+                        new_solution_idx[target_slot] = n;
+                        break;
+                    }
+                }
+            }
+        }
+
+        // Calculate the new max exp error
+        float new_max_exp_error = 0;
+        for (int j = 0; j < num_slots; ++j) {
+            new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j]));
+        }
+
+        // Acceptance criterion with error-equalization focus
+        bool accept = false;
+        float delta_sum_log_err = new_sum_log_err - current_sum_log_err;
+
+        // Dampen the penalty for low errors, but less aggressively
+        float error_factor = 1.0f;
+        if (current_max_exp_error < low_error_threshold) {
+            error_factor = 0.25f;
+        }
+
+        if (new_cost <= max_cost) {
+            if (delta_sum_log_err * error_factor < 0 ||
+                std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err * error_factor / local_temp)) {
+                accept = true;
+                for (int j = 0; j < num_slots; ++j) {
+                    if (calculate_bpw(new_solution[j]) < min_bpw_limit) {
+                        accept = false;
+                        break;
+                    }
+                }
+            }
+        }
+
+        if (accept) {
+            solution = new_solution;
+            solution_idx = new_solution_idx;
+            current_sum_log_err = new_sum_log_err;
+            current_cost = new_cost;
+            current_max_exp_error = new_max_exp_error;
+
+            if (current_sum_log_err < best_sum_log_err) {
+                best_sum_log_err = current_sum_log_err;
+                best_solution = solution;
+                best_solution_idx = solution_idx;
+            }
+        }
+
+        local_temp *= 0.95f;
+    }
+
+    // Use the best solution found during opportunistic optimization
+    solution = best_solution;
+    solution_idx = best_solution_idx;
+    current_sum_log_err = best_sum_log_err;
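The opportunistic phase tracks the sum of log errors rather than the max or the plain sum. Minimizing that sum is equivalent to minimizing the geometric mean of the per-layer errors, which weights relative (percentage) changes in every layer equally regardless of scale. A short illustration:

```python
import math

def sum_log_err(errors: list[float]) -> float:
    return sum(math.log(e) for e in errors)

def geometric_mean_err(errors: list[float]) -> float:
    # exp(mean(log e_i)): ranking solutions by sum_log_err and by
    # geometric mean produces the same order
    return math.exp(sum_log_err(errors) / len(errors))
```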
+    // --- Enhanced BPW Smoothing (Post-processing) ---
+    for (int pass = 0; pass < bpw_smoothing_passes; ++pass) {
+        for (int i = 1; i < num_slots - 1; ++i) {
+            float current_bpw = calculate_bpw(solution[i]);
+            float prev_bpw = calculate_bpw(solution[i - 1]);
+            float next_bpw = calculate_bpw(solution[i + 1]);
+            float avg_neighbor_bpw = (prev_bpw + next_bpw) / 2.0f;
+
+            if (current_bpw < avg_neighbor_bpw - bpw_smoothing_threshold) {
+                // Find a higher-BPW option for the current slot
+                for (int n = 0; n < slots[i].size(); ++n) {
+                    auto new_option = slots[i][n];
+                    if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= avg_neighbor_bpw) {
+                        if (current_cost - std::get<0>(solution[i]) + std::get<0>(new_option) <= max_cost) {
+                            if (std::get<1>(new_option) < error_floor) continue;
+                            float new_max_err = 0;
+                            for (int j = 0; j < num_slots; ++j) {
+                                if (j == i) {
+                                    new_max_err = std::max(new_max_err, std::get<1>(new_option));
+                                } else {
+                                    new_max_err = std::max(new_max_err, std::get<1>(solution[j]));
+                                }
+                            }
+
+                            if (new_max_err < current_max_exp_error * (1 + 0.1f * bpw_balance_factor)) {
+                                current_cost = current_cost - std::get<0>(solution[i]) + std::get<0>(new_option);
+                                solution[i] = new_option;
+                                solution_idx[i] = n;
+                                current_max_exp_error = new_max_err;
+                                break;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // --- Enhanced Targeted Bit Redistribution (Post-processing) ---
+    for (int iter = 0; iter < num_slots * 3; ++iter) {
+        // Build a global pool of donor indices
+        std::vector<int> donor_indices;
+        std::vector<float> donor_errors;
+        for (int j = 0; j < num_slots; ++j) {
+            if (calculate_bpw(solution[j]) > high_bpw_donor_threshold && std::get<1>(solution[j]) < low_error_threshold) {
+                donor_indices.push_back(j);
+                donor_errors.push_back(std::get<1>(solution[j]));
+            }
+        }
+
+        if (donor_indices.empty()) continue;
+
+        std::discrete_distribution<int> donor_dist(donor_errors.begin(), donor_errors.end());
+
+        for (int i = 0; i < num_slots; ++i) {
+            float current_bpw = calculate_bpw(solution[i]);
+            if (current_bpw < targeted_redistribution_bpw_threshold) {
+                int donor_idx = donor_indices[donor_dist(gen)];
+
+                std::vector<int> higher_bpw_options;
+                for (int n = 0; n < slots[i].size(); ++n) {
+                    if (calculate_bpw(slots[i][n]) > current_bpw) {
+                        higher_bpw_options.push_back(n);
+                    }
+                }
+
+                std::shuffle(higher_bpw_options.begin(), higher_bpw_options.end(), gen);
+                int options_to_explore = std::min((int)higher_bpw_options.size(), num_options_to_explore_per_layer);
+
+                for (int option_idx = 0; option_idx < options_to_explore; ++option_idx) {
+                    int best_new_idx = higher_bpw_options[option_idx];
+                    auto new_option = slots[i][best_new_idx];
+
+                    if (std::get<1>(new_option) < error_floor) continue;
+
+                    int best_donor_new_idx = -1;
+                    float best_donor_new_error = 1e10f;
+                    for (int n = 0; n < slots[donor_idx].size(); ++n) {
+                        if (calculate_bpw(slots[donor_idx][n]) < calculate_bpw(solution[donor_idx])) {
+                            if (std::get<1>(slots[donor_idx][n]) < best_donor_new_error) {
+                                best_donor_new_error = std::get<1>(slots[donor_idx][n]);
+                                best_donor_new_idx = n;
+                            }
+                        }
+                    }
+
+                    if (best_donor_new_idx != -1) {
+                        auto donor_new_option = slots[donor_idx][best_donor_new_idx];
+
+                        if (std::get<1>(donor_new_option) < error_floor) continue;
+
+                        uint64_t new_cost = current_cost - std::get<0>(solution[i]) - std::get<0>(solution[donor_idx])
+                                          + std::get<0>(new_option) + std::get<0>(donor_new_option);
+
+                        if (new_cost <= max_cost) {
+                            float new_max_err = std::get<1>(new_option);
+                            for (int j = 0; j < num_slots; ++j) {
+                                if (j == i) continue;
+                                if (j == donor_idx) {
+                                    new_max_err = std::max(new_max_err, std::get<1>(donor_new_option));
+                                } else {
+                                    new_max_err = std::max(new_max_err, std::get<1>(solution[j]));
+                                }
+                            }
+
+                            if (new_max_err < current_max_exp_error * targeted_redistribution_max_err_increase) {
+                                current_cost = new_cost;
+                                solution[i] = new_option;
+                                solution_idx[i] = best_new_idx;
+                                solution[donor_idx] = donor_new_option;
+                                solution_idx[donor_idx] = best_donor_new_idx;
+                                current_max_exp_error = new_max_err;
+                                break;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
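The smoothing pass only raises a slot when it sits noticeably below the average of its two neighbours, which keeps the BPW profile from developing isolated dips. The trigger condition in isolation, as a standalone sketch:

```python
def needs_smoothing(bpw: list[float], i: int, threshold: float) -> bool:
    # True when slot i sits more than `threshold` bits below the average
    # of its immediate neighbours (valid for 0 < i < len(bpw) - 1)
    avg_neighbor = (bpw[i - 1] + bpw[i + 1]) / 2.0
    return bpw[i] < avg_neighbor - threshold
```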
+    // --- Final Cost Check and Rollback (if necessary) ---
+    if (current_cost > max_cost) {
+        std::vector<std::tuple<float, float, int>> bpw_error_indices(num_slots);
+        for (int i = 0; i < num_slots; ++i) {
+            float bpw = calculate_bpw(solution[i]);
+            float error = std::get<1>(solution[i]);
+            float penalty = (bpw < targeted_redistribution_bpw_threshold) ? 1000.0f : 0.0f;
+            bpw_error_indices[i] = {error + penalty, bpw, i};
+        }
+        std::sort(bpw_error_indices.begin(), bpw_error_indices.end(), std::greater<std::tuple<float, float, int>>());
+
+        for (const auto& tuple : bpw_error_indices) {
+            int i = std::get<2>(tuple);
+            for (int n = slots[i].size() - 1; n >= 0; --n) {
+                if (calculate_bpw(slots[i][n]) < calculate_bpw(solution[i])) {
+                    if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) {
+                        uint64_t delta_cost = std::get<0>(slots[i][n]) - std::get<0>(solution[i]);
+                        current_cost += delta_cost;
+                        solution[i] = slots[i][n];
+                        solution_idx[i] = n;
+                        break;
+                    }
+                }
+            }
+            if (current_cost <= max_cost) break;
+        }
+    }
+
+    // Calculate final max error and sum of log errors
     float max_err = 0.0f;
-    for (int i = 0; i < num_slots; ++i)
+    float sum_log_err = 0.0;
+    for (int i = 0; i < num_slots; ++i) {
         max_err = std::max(max_err, std::get<1>(solution[i]));
+        sum_log_err += log(std::get<1>(solution[i]));
+    }
 
-    return { solution, solution_idx, current_sum, current_cost, max_err };
+    return { solution, solution_idx, sum_log_err, current_cost, max_err };
 }
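The rollback visits slots in descending order of error plus a flat 1000.0 penalty for slots already below the target BPW threshold, then steps each one down to cheaper options until the budget is met. The ordering alone, as a hypothetical standalone helper mirroring the tuple sort above:

```python
def downgrade_order(errors: list[float], bpws: list[float],
                    bpw_threshold: float) -> list[int]:
    # Highest (error + penalty) first, matching the std::greater sort
    # over (error + penalty, bpw, index) tuples
    return sorted(range(len(errors)),
                  key=lambda i: errors[i]
                  + (1000.0 if bpws[i] < bpw_threshold else 0.0),
                  reverse=True)
```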
diff --git a/exllamav2/exllamav2_ext/ext_stloader.cpp b/exllamav2/exllamav2_ext/ext_stloader.cpp
index 2b0b4c1e..0a4ce540 100644
--- a/exllamav2/exllamav2_ext/ext_stloader.cpp
+++ b/exllamav2/exllamav2_ext/ext_stloader.cpp
@@ -31,7 +31,7 @@ void stloader_read
     }
     else
     {
-        load_buffer = (uint8_t*) malloc(size);
+        load_buffer = (uint8_t*) calloc(1, size);
         TORCH_CHECK(load_buffer, "Can't allocate buffer for tensor");
         cuda_buffer = (uint8_t*) target.data_ptr();
         cudaSetDevice(device.value().index());
diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py
index 07a121fc..c292eb4b 100644
--- a/exllamav2/generator/dynamic.py
+++ b/exllamav2/generator/dynamic.py
@@ -1352,27 +1352,34 @@ def defrag_cache(self):
 
         if not self.paged: return
 
+        # Defragment once job queue is empty after touching all the cache pages
         if self.access_serial < self.last_defrag_serial + self.max_pages: return
         self.last_defrag_serial = self.access_serial
 
         assert not self.referenced_pages
 
-        @dataclass
         class CacheNode:
             page: CachePage | None
-            parent: CachePage | None = None
-            children: set[CacheNode] = None
-            left_page: int = len(self.all_pages)
+            parent: CacheNode | None
+            children: set[CacheNode] | None
+            children_sorted: deque[CacheNode] | None
+            left_page: int = 0
 
             def __init__(self, page_):
                 self.page = page_
-                if self.page:
-                    self.left_page = page_.access_serial
+                self.parent = None
                 self.children = set()
+                self.children_sorted = None
+                self.left_page = page_.access_serial if page_ else 0
 
             def __hash__(self): return id(self)
             def __eq__(self, other): return self is other
 
+            def presort(self, recursive = True):
+                self.children_sorted = deque(sorted(self.children, key = lambda x: x.left_page))
+                if recursive:
+                    for c in self.children:
+                        c.presort()
+
         # Build a tree of the current cache
 
@@ -1393,28 +1400,50 @@ def __eq__(self, other):
 
         # Remove oldest branch until tree is empty
 
+        root_node.presort()
+        shift_counts = {}
+
         new_page_index = 0
         while root_node.children:
-            oldest = min(root_node.children, key = lambda x: x.left_page)
+            oldest = root_node.children_sorted[0]
             node = oldest
             skipped_nodes = set()
            while True:
                 node.page.new_page_index = new_page_index
+                shift = node.page.new_page_index - node.page.page_index
+                if shift in shift_counts:
+                    shift_counts[shift] += 1
+                else:
+                    shift_counts[shift] = 1
                 new_page_index += 1
                 if not node.children: break
-                next_node = min(node.children, key = lambda x: x.left_page)
-                skipped_nodes |= set([n for n in node.children if n != next_node])
+                next_node = node.children_sorted[0]
+                if len(node.children_sorted) > 1:
+                    skipped_nodes |= set([n for n in node.children if n != next_node])
                 node = next_node
             root_node.children.remove(oldest)
+            root_node.children_sorted.popleft()
             root_node.children |= skipped_nodes
+            if len(skipped_nodes):
+                root_node.presort(False)
+
+        # Adjust overall shift to minimize page copies
+
+        shift_adjust = max(shift_counts, key = shift_counts.get)
 
         # Order of operations
 
         defrag_map = {}
         for page in self.all_pages:
+            page.new_page_index = (page.new_page_index - shift_adjust + self.max_pages) % self.max_pages
             if page.page_index != page.new_page_index:
                 defrag_map[page.new_page_index] = page.page_index
 
+        # Don't bother if less than 10% of cache is fragmented
+
+        if len(defrag_map) <= self.max_pages // 10:
+            return
+
         # Shuffle pages
 
         cache_tensors = self.cache.all_tensors()
@@ -1435,12 +1464,11 @@ def __eq__(self, other):
                 source = defrag_map[target]
                 del defrag_map[target]
 
-            rotation = [r * self.page_size for r in rotation]
+            rotation = torch.tensor(rotation, dtype = torch.int)
 
             for cache, buffer in zip(cache_tensors, defrag_buffers):
-                buffer[:, :, :, :].copy_(cache[:, rotation[0] : rotation[0] + self.page_size, :, :])
-                for a, b in pairwise(rotation):
-                    cache[:, a : a + self.page_size, :, :].copy_(cache[:, b : b + self.page_size, :, :])
-                cache[:, rotation[-1] : rotation[-1] + self.page_size, :, :].copy_(buffer[:, :, :, :])
+                rotation = rotation.to(cache.device)
+                cache = cache.view(cache.shape[1] // self.page_size, -1)
+                ext_c.cache_rotate(cache, rotation, buffer)
 
         # Update page table
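The new `ext_c.cache_rotate` call replaces the Python loop that rotated one cycle of pages through a single page-sized staging buffer. The displaced loop's semantics, kept here as a reference sketch over a plain list:

```python
def rotate_cycle(pages: list, cycle: list[int]) -> None:
    # Stage the first page, shift every other page one step along the
    # cycle, then drop the staged page into the final slot; only one
    # page-sized temporary is ever needed
    staged = pages[cycle[0]]
    for dst, src in zip(cycle, cycle[1:]):
        pages[dst] = pages[src]
    pages[cycle[-1]] = staged
```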
diff --git a/exllamav2/model.py b/exllamav2/model.py
index e487341a..9cecc965 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -205,32 +205,31 @@ def set_device_map(
 
         reserve_bytes_attn = [0 for a in allocation]
         fixed_bytes = [0 for a in allocation]
 
-        current_idx = 0
-        for idx, module in enumerate(self.modules):
+        # Start from the last device index
+        current_idx = len(allocation_bytes) - 1
+        for idx, module in reversed(list(enumerate(self.modules))):
 
             # Special case for token embeddings on CPU
 
             if isinstance(module, ExLlamaV2Embedding) and embed_cpu:
-
                 module.set_device_idx(-1)
                 continue
 
             # Special case for attention
 
             attn_bytes_current = 0
-            if isinstance(module, ExLlamaV2Attention): attn_bytes_current = module.temp_attn_size()
-
-            # Advance current_idx until module fits in allocation
+            if isinstance(module, ExLlamaV2Attention):
+                attn_bytes_current = module.temp_attn_size()
+
+            # Move current_idx backward until module fits in allocation
 
             footprint = module.weight_footprint()   # Footprint, in bytes
             scratch = module.scratch_space()        # Scratch space required by module
 
             while True:
-                assert current_idx < len(allocation_bytes), "Insufficient space in device allocation"
+                assert current_idx >= 0, "Insufficient space in device allocation"
                 dev_scratch = max(scratch, reserve_bytes[current_idx])
                 dev_scratch_attn = max(attn_bytes_current, reserve_bytes_attn[current_idx])
-                if footprint + dev_scratch + dev_scratch_attn <= allocation_bytes[current_idx]: break
-                current_idx += 1
+                if footprint + dev_scratch + dev_scratch_attn <= allocation_bytes[current_idx]:
+                    break
+                current_idx -= 1
 
             # Size for fixed tensors
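The device map now fills from the last device backward while iterating modules in reverse, so when a module does not fit, the search walks toward device 0 instead of away from it. A simplified sketch of that packing, using footprints only (the real code also reserves scratch and attention space per device):

```python
def assign_devices_reversed(footprints: list[int], budgets: list[int]) -> list[int]:
    # Pack modules onto devices starting from the last device,
    # walking the module list back to front
    device = len(budgets) - 1
    assignment = [0] * len(footprints)
    for idx in reversed(range(len(footprints))):
        while True:
            assert device >= 0, "Insufficient space in device allocation"
            if footprints[idx] <= budgets[device]:
                break
            device -= 1
        budgets[device] -= footprints[idx]
        assignment[idx] = device
    return assignment
```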
diff --git a/exllamav2/moe_mlp.py b/exllamav2/moe_mlp.py
index 02ebbfd0..03559d92 100644
--- a/exllamav2/moe_mlp.py
+++ b/exllamav2/moe_mlp.py
@@ -167,7 +167,7 @@ def scratch_space_fixed(self) -> int:
 
     def scratch_space(self) -> int:
 
-        assert self.model.config.intermediate_size >= self.model.config.hidden_size
+        # assert self.model.config.intermediate_size >= self.model.config.hidden_size
         return self.temp_state_size() + \
                self.temp_gathered_state_size() + \
                self.temp_a_size() + \
@@ -235,7 +235,7 @@ def forward(
 
         # TODO: LoRA currently uses the Torch codepath. Needs conditional (early-exit) kernels with output scaling
         # for the LoRA matmuls in order to work with the C++ path
 
-        if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16] or (loras is not None and len(loras) > 0):
+        if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16, 128] or (loras is not None and len(loras) > 0):
             return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)
 
         # if loras is None or self.temp_lora_size == 0:
diff --git a/exllamav2/tokenizer/__init__.py b/exllamav2/tokenizer/__init__.py
index 18a67114..1e2b8ed8 100644
--- a/exllamav2/tokenizer/__init__.py
+++ b/exllamav2/tokenizer/__init__.py
@@ -1,5 +1,4 @@
 from exllamav2.version import __version__
 from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase
-from exllamav2.tokenizer.spm import ExLlamaV2TokenizerSPM
 from exllamav2.tokenizer.hf import ExLlamaV2TokenizerHF
diff --git a/exllamav2/tokenizer/spm.py b/exllamav2/tokenizer/spm.py
deleted file mode 100644
index 4c59f5b1..00000000
--- a/exllamav2/tokenizer/spm.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from __future__ import annotations
-from typing import List, Union
-from sentencepiece import SentencePieceProcessor
-from exllamav2.tokenizer.base import ExLlamaV2TokenizerBase
-
-# Wrapper for SentencePiece
-
-class ExLlamaV2TokenizerSPM(ExLlamaV2TokenizerBase):
-
-    vocab: list[str] | None
-
-    def __init__(self, tokenizer_model: str):
-        super().__init__()
-        self.vocab = None
-        self.spm = SentencePieceProcessor(model_file = tokenizer_model)
-
-    def unk_id(self) -> int or None: return self.spm.unk_id()
-    def pad_id(self) -> int or None: return self.spm.pad_id()
-    def bos_id(self) -> int or None: return self.spm.bos_id()
-    def eos_id(self) -> int or None: return self.spm.eos_id()
-    def unk_token(self) -> str or None: return None
-    def pad_token(self) -> str or None: return None
-    def bos_token(self) -> str or None: return None
-    def eos_token(self) -> str or None: return None
-
-    def space_char(self): return "▁"
-    def newline_char(self): return "\n"
-
-    def enumerate_tokens(self):
-        if self.vocab is not None: return enumerate(self.vocab)
-        self.vocab = []
-        for i in range(self.vocab_size()):
-            p = self.spm.id_to_piece(i)
-            if all(c == self.space_char() for c in p):
-                d = " " * len(p)
-            else:
-                d = self.spm.decode(i)
-            if p.startswith(self.space_char()) and not d.startswith(" "): d = " " + d
-            self.vocab.append(d)
-        return enumerate(self.vocab)
-
-    def id_to_piece(self, idx: int) -> str:
-        return self.spm.id_to_piece(idx)
-
-    def piece_to_id(self, text: str) -> int:
-        return self.spm.piece_to_id(text)
-
-    def vocab_size(self) -> int:
-        return self.spm.vocab_size()
-
-    def decode(self, ids: List[int]) -> str:
-        text = self.spm.decode(ids)
-        return text
-
-    def encode(self, text: list or str) -> list:
-        encoding = self.spm.EncodeAsIds(text)
-        return encoding
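With the SPM wrapper deleted, models that ship only a SentencePiece tokenizer.model must now provide a tokenizer.json for the HF Tokenizers path. For reference, loading one directly through the underlying library looks like the sketch below; the `Tokenizer.from_file` call is the real library API, while ExLlamaV2 itself goes through its ExLlamaV2TokenizerHF wrapper instead:

```python
import os
from tokenizers import Tokenizer

def load_hf_tokenizer(model_dir: str) -> Tokenizer:
    path = os.path.join(model_dir, "tokenizer.json")
    if not os.path.exists(path):
        raise FileNotFoundError("No supported tokenizer found.")
    return Tokenizer.from_file(path)
```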
diff --git a/exllamav2/tokenizer/tokenizer.py b/exllamav2/tokenizer/tokenizer.py
index 0d867e7b..611534d8 100644
--- a/exllamav2/tokenizer/tokenizer.py
+++ b/exllamav2/tokenizer/tokenizer.py
@@ -5,7 +5,6 @@
 import os, json, re
 from exllamav2.tokenizer import (
     ExLlamaV2TokenizerBase,
-    ExLlamaV2TokenizerSPM,
     ExLlamaV2TokenizerHF
 )
 import threading
@@ -93,13 +92,12 @@ def __init__(
             Defer initialization of some data structures to speed up loading
 
         :param force_json:
-            No effect from v0.2.3. tokenizer.json is now preferred over tokenizer.model by default.
-            If True and no tokenizer.json is present in the model directory, will emit a warning before
-            falling back to SPM
+            No effect from v0.2.3. tokenizer.json is now preferred over tokenizer.model by default. From v0.3.1
+            tokenizer.model is not used at all
 
         :param force_spm:
-            Use only tokenizer.model (SentencePiece) even if tokenizer.model (HF Tokenizers)
-            is available
+            Deprecated. SentencePiece is abandoned and no longer supported. All SPM tokenizers should
+            still load correctly via the Tokenizers library
         """
 
         self.config = config
@@ -123,33 +121,31 @@ def __init__(
 
         # Detect tokenizer model type and initialize
 
-        path_spm = os.path.join(self.config.model_dir, "tokenizer.model")
+        assert not force_spm, "tokenizer.py: force_spm is deprecated. SentencePiece is no longer supported."
         path_hf = os.path.join(self.config.model_dir, "tokenizer.json")
 
-        if os.path.exists(path_hf) and not force_spm:
-            self.tokenizer_model = ExLlamaV2TokenizerHF(path_hf)
-        elif os.path.exists(path_spm):
-            if force_json:
-                print(" !! Warning: Tokenizer loading with force_json = True but no tokenizer.json found, falling back to tokenizer.model")
-            self.tokenizer_model = ExLlamaV2TokenizerSPM(path_spm)
-        else:
+        if not os.path.exists(path_hf):
             raise FileNotFoundError("No supported tokenizer found.")
+        self.tokenizer_model = ExLlamaV2TokenizerHF(path_hf)
 
         # Attempt to load added tokens from tokenizer.json
 
         self.extended_piece_to_id = {}
         self.unspecial_piece_to_id = {}
 
         tokenizer_json_path = os.path.join(self.config.model_dir, "tokenizer.json")
-        if os.path.exists(tokenizer_json_path):
-            with open(tokenizer_json_path, encoding = "utf8") as f:
-                tokenizer_json = json.load(f)
-                if "added_tokens" in tokenizer_json:
-                    for v in tokenizer_json["added_tokens"]:
-                        if v["special"]:
-                            self.extended_piece_to_id[v["content"]] = v["id"]
-                        else:
-                            self.unspecial_piece_to_id[v["content"]] = v["id"]
+        if not os.path.exists(tokenizer_json_path):
+            raise ValueError(" ## Model does not include a tokenizer.json file. SentencePiece-only tokenizers are no longer supported")
+
+        with open(tokenizer_json_path, encoding = "utf8") as f:
+            tokenizer_json = json.load(f)
+            if "added_tokens" in tokenizer_json:
+                for v in tokenizer_json["added_tokens"]:
+                    if v["special"]:
+                        self.extended_piece_to_id[v["content"]] = v["id"]
+                    else:
+                        self.unspecial_piece_to_id[v["content"]] = v["id"]
 
         # Attempt to load tokenizer_config.json
diff --git a/exllamav2/version.py b/exllamav2/version.py
index 467ab302..5a271749 100644
--- a/exllamav2/version.py
+++ b/exllamav2/version.py
@@ -1 +1 @@
-__version__ = "0.2.9"
\ No newline at end of file
+__version__ = "0.3.1"
\ No newline at end of file
diff --git a/exllamav2/vlm/vision_tower.py b/exllamav2/vlm/vision_tower.py
index 07f40dd9..8781e3bc 100644
--- a/exllamav2/vlm/vision_tower.py
+++ b/exllamav2/vlm/vision_tower.py
@@ -42,6 +42,8 @@ def __init__(
         km = self.archparams.keys
         self.modules = []
 
+        self.tp_context = None
+
         # Preprocessor
 
         if cfg.vision_model_type == "pixtral":
diff --git a/requirements.txt b/requirements.txt
index d37b1375..ef82e1ad 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,6 @@ setuptools
 fastparquet
 torch>=2.2.0
 safetensors>=0.4.3
-sentencepiece>=0.1.97
 pygments
 websockets
 regex
diff --git a/setup.py b/setup.py
index 96912466..67e4fa1b 100644
--- a/setup.py
+++ b/setup.py
@@ -128,7 +128,6 @@
         "fastparquet",
         "torch>=2.2.0",
         "safetensors>=0.3.2",
-        "sentencepiece>=0.1.97",
         "pygments",
         "websockets",
         "regex",