diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 847c7ef7a..7597ed9f9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -162,7 +162,7 @@ jobs: - name: Run tests run: pytest --durations=100 - test-cpu-ipex: + test-cpu-intel: if: github.repository == 'bitsandbytes-foundation/bitsandbytes' needs: build-cpu runs-on: banb-aws-general-8-plus-use1-public-80 @@ -186,7 +186,6 @@ jobs: - name: Install dependencies run: | pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu - pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ pip install -e ".[test]" pip install pytest-cov @@ -196,9 +195,6 @@ jobs: - name: Show environment information run: python -m torch.utils.collect_env - - name: IPEX smoke test - run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);" - - name: Run tests run: pytest --durations=100 @@ -286,15 +282,6 @@ jobs: fail-fast: false matrix: torch_version: ["2.7.1"] #["2.6.0", "2.7.1"] - ipex: [false] - # ipex: [true, false] - # include: - # - torch_version: "2.6.0" - # ipex: true - # ipex_version: "2.6.10+xpu" - # - torch_version: "2.7.1" - # ipex: true - # ipex_version: "2.7.10+xpu" runs-on: group: bandb-itac-bmsprpvc1550-8-1gpu env: @@ -330,10 +317,6 @@ jobs: - name: Install PyTorch run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu - - name: Install IPEX - if: matrix.ipex == true - run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ - - name: Install dependencies run: | pip install -e ".[test]" diff --git a/CMakeLists.txt b/CMakeLists.txt index 770b4ba30..429570443 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) set(HIP_FILES csrc/ops.hip csrc/kernels.hip) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) +set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -64,10 +65,18 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps") set(BUILD_CUDA OFF) set(BUILD_HIP OFF) set(BUILD_MPS ON) +elseif(${COMPUTE_BACKEND} STREQUAL "xpu") + if(APPLE) + message(FATAL_ERROR "XPU is not supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_MPS OFF) + set(BUILD_XPU ON) else() set(BUILD_CUDA OFF) set(BUILD_HIP OFF) set(BUILD_MPS OFF) + set(BUILD_XPU OFF) endif() @@ -217,6 +226,15 @@ elseif(BUILD_MPS) COMMENT "Compiling Metal kernels" VERBATIM) add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +elseif(BUILD_XPU) + list(APPEND SRC_FILES ${XPU_FILES}) + string(APPEND BNB_OUTPUT_NAME "_xpu") + add_compile_definitions(BUILD_XPU) + set(CMAKE_C_COMPILER icx) + set(CMAKE_CXX_COMPILER icpx) + if(WIN32) + set(CMAKE_CXX_COMPILER icx) + endif() else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) @@ -285,6 +303,15 @@ if(BUILD_MPS) 
add_dependencies(bitsandbytes metallib) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") endif() +if(BUILD_XPU) + set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'") + set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;") + + set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20) + target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS}) + target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS}) + +endif() if(WIN32) set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py index a260852f5..9a3ac46ac 100644 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -4,8 +4,6 @@ import torch -from .cextension import ipex_cpu, ipex_xpu - _IS_TORCH_GTE_24 = False if hasattr(torch.library, "register_fake"): @@ -329,22 +327,3 @@ def _( ) torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - - -if ipex_cpu or ipex_xpu: - # Register the dequantize_nf4_ipex implementation - torch.library.define( - "bitsandbytes::dequantize_nf4_ipex", - "(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor", - ) - - @register_fake("bitsandbytes::dequantize_nf4_ipex") - def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - shape: Sequence[int], - dtype: torch.dtype, - ) -> torch.Tensor: - torch._check_is_size(blocksize) - return torch.empty(shape, dtype=dtype, device=A.device) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 80fc86861..cb761fe24 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -8,7 +8,6 @@ from typing_extensions import deprecated import bitsandbytes.functional as F -from bitsandbytes.functional import ipex_cpu, ipex_xpu # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py @@ -320,8 +319,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) output = torch.nn.functional.linear(A, CB, bias) - # to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu] - state.idx = False ctx.state = state ctx.dtype_A = A.dtype ctx.grad_shape = A.shape @@ -425,9 +422,9 @@ def matmul( if threshold > 0.0: state.threshold = threshold # MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU - if state.is_training: - if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu): - return MatMul8bitFp.apply(A, B, out, bias, state) + if state.is_training and A.device.type in ("cpu", "xpu"): + return MatMul8bitFp.apply(A, B, out, bias, state) + return MatMul8bitLt.apply(A, B, out, bias, state) @@ -440,17 +437,6 @@ def matmul_4bit( ): assert quant_state is not None - if A.device.type in ("cpu", "xpu") and A.requires_grad == False: - if 
getattr(quant_state, "ipex", False): - # IPEX CPU will change weight to 4D so don't need transpose - B = B.t() if B.dim() == 2 else B - out = F.gemv_4bit(A, B, out, state=quant_state) - if bias is not None: - out += bias - return out - else: - return MatMul4Bit.apply(A, B, out, bias, quant_state) - if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: warn( diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 5f009ea40..e295cc2a3 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -1,13 +1,14 @@ -from collections.abc import Sequence import ctypes as ct +import logging import torch from bitsandbytes.functional import get_ptr from ..._ops import register_kernel -from ...cextension import lib -from ..utils import ipex_cpu +from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib + +logger = logging.getLogger(__name__) # torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+. # However, we can overflow if we use this without AVX512_VNNI support. @@ -24,97 +25,77 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0]) -@register_kernel("bitsandbytes::quantize_blockwise", "cpu") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - n = A.numel() - - # Only FP32 has c++ kernrl - if A.dtype == torch.float32: - blocks = -(n // -blocksize) - - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(n), - ) - else: - rem = n % blocksize - has_rem = rem > 0 - blocks = n // blocksize + has_rem - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - A_reshaped = A.reshape(n) - A_com = A_reshaped[: n - rem] - A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) - scaled_A = scaled_A.reshape(-1) - if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() - scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) - scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) - - diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) - out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cpu") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - torch._check_is_size(blocksize) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - - # Only FP32 has c++ kernrl - if dtype == torch.float32: - out = torch.empty_like(A, dtype=dtype) - - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), - ) - else: - out = code[A.reshape(-1).int()] - blocks = out.shape[-1] // blocksize - res = out.shape[-1] % blocksize - if res != 0: - out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0) - out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) - out = out[: blocks * blocksize 
+ res] - out = out.reshape(A.shape) - - return out - - -if ipex_cpu: - from bitsandbytes.utils import _reverse_4bit_compress_format - - @register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu") +if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): + + @register_kernel("bitsandbytes::quantize_blockwise", "cpu") + def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + n = A.numel() + + # Only FP32 has c++ kernrl + if A.dtype == torch.float32: + blocks = -(n // -blocksize) + + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(n), + ) + else: + rem = n % blocksize + has_rem = rem > 0 + blocks = n // blocksize + has_rem + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + + diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) + out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) + + return out, absmax + + @register_kernel("bitsandbytes::dequantize_blockwise", "cpu") def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - shape: Sequence[int], - dtype: torch.dtype, + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype ) -> torch.Tensor: - ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2) - A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1) - return torch.ops.bitsandbytes.dequantize_4bit.default( - A, - absmax, - blocksize, - "nf4", - shape, - dtype, - ) + torch._check_is_size(blocksize) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + + # Only FP32 has c++ kernrl + if dtype == torch.float32: + out = torch.empty_like(A, dtype=dtype) + + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + else: + out = code[A.reshape(-1).int()] + blocks = out.shape[-1] // blocksize + res = out.shape[-1] % blocksize + if res != 0: + out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0) + out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) + out = out[: blocks * blocksize + res] + out = out.reshape(A.shape) + + return out diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py old mode 100755 new mode 100644 index 1543f3474..19edd768d --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -3,16 +3,6 @@ from packaging import version import torch -try: - # to support Intel CPU/XPU (IPEX) backend - import intel_extension_for_pytorch as ipex - - ipex_cpu = ipex if ipex._C._has_cpu() else None - ipex_xpu = ipex if ipex._C._has_xpu() else None -except BaseException: - ipex_cpu = None - ipex_xpu = 
None - try: import triton # noqa: F401 import triton.language as tl # noqa: F401 diff --git a/bitsandbytes/backends/xpu/__init__.py b/bitsandbytes/backends/xpu/__init__.py old mode 100755 new mode 100644 diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py old mode 100755 new mode 100644 index 999116c97..1c1422c35 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -1,59 +1,210 @@ from collections.abc import Sequence -import warnings +import ctypes as ct +import logging import torch +from bitsandbytes.functional import _get_tensor_stream, get_ptr + from ..._ops import register_kernel -from ..utils import ipex_xpu, triton_available +from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib +from ..utils import triton_available + +logger = logging.getLogger(__name__) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + if dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) -# _int_mm is available in torch starting from 2.7 version, -# but currently it's don't have xpu implementation. -if ipex_xpu and torch.__version__ >= (2, 7): - @register_kernel("bitsandbytes::int8_linear_matmul", "xpu") - def _(A: torch.Tensor, B: torch.Tensor): - return torch._int_mm( - A.reshape(-1, A.shape[-1]), - B.t(), - ).reshape(*A.shape[:-1], B.shape[0]) +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) -# IPEX should be faster for xpu, so at first checking if it is available. 
-if ipex_xpu: +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + m = ct.c_int32(1) + n = ct.c_int32(shapeB[0]) + k = ct.c_int32(shapeB[1]) - @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu") + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + if A.dtype == torch.float16: + lib.cgemv_4bit_inference_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemv_4bit_inference_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemv_4bit_inference_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + + +# SYCL should be faster for xpu, so at first checking if it is available. +if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): + logger.info("Register sycl bitsandbytes kernels for XPU") + + @register_kernel("bitsandbytes::dequantize_4bit", "xpu") def _( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, + quant_type: str, shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype) + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out @register_kernel("bitsandbytes::dequantize_blockwise", "xpu") + def _( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype + ) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + @register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu") def _( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, + out: torch.Tensor, + ) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + @register_kernel("bitsandbytes::gemv_4bit", "xpu") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, ) -> torch.Tensor: - shape = A.shape - out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device) - # void cdequantize_blockwise_fp32( - # float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) - if dtype == torch.float16: - ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) - elif dtype == torch.bfloat16: - ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) - elif dtype == torch.float32: - ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, 
dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out - return out.reshape(shape) + @register_kernel("bitsandbytes::gemv_4bit.out", "xpu") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, + ) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) elif triton_available: + logger.info("Register triton bitsandbytes kernels for XPU") from ..triton import ops as triton_ops register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise) @@ -64,4 +215,4 @@ def _( register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) else: - warnings.warn("XPU available but no ipex or triton packages found.") + logger.warning("Register pytorch bitsandbytes kernels for XPU because no native library or triton packages found.") diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index bb301e712..c7e407efd 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -283,6 +283,9 @@ def get_native_library() -> BNBNativeLibrary: binary_path = cuda_binary_path + if torch._C._has_xpu: + binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}" + logger.debug(f"Loading bitsandbytes native library from: {binary_path}") # Try to load the library - any errors will propagate up @@ -300,30 +303,27 @@ def get_native_library() -> BNBNativeLibrary: ROCM_GPU_ARCH = get_rocm_gpu_arch() -try: - # to support Intel CPU/GPU (XPU) backend - import intel_extension_for_pytorch as ipex - - ipex_cpu = ipex if ipex._C._has_cpu() else None - ipex_xpu = ipex if ipex._C._has_xpu() else None -except BaseException: - ipex_cpu = None - ipex_xpu = None +HIP_ENVIRONMENT = False +BNB_BACKEND = "CPU" +if torch.version.hip: + HIP_ENVIRONMENT = True + BNB_BACKEND = "ROCm" +elif torch.cuda.is_available(): + BNB_BACKEND = "CUDA" +elif torch._C._has_xpu: + BNB_BACKEND = "XPU" try: - if torch.version.hip: - HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" - else: - HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" - lib = get_native_library() except Exception as e: - error_msg = str(e) - if not (ipex_cpu or ipex_xpu): + if BNB_BACKEND in ("CPU", "XPU"): + lib = ErrorHandlerMockBNBNativeLibrary("XPU/CPU can run without native library.") + else: + error_msg = str(e) logger.error( - f"bitsandbytes library load error: {error_msg}\n If you are using Intel CPU/XPU, please install intel_extension_for_pytorch to enable required ops", + f"bitsandbytes library load error: {error_msg}", exc_info=True, ) - # create a mock with error messaging as fallback - lib = ErrorHandlerMockBNBNativeLibrary(error_msg) + # create a mock with error messaging as fallback + lib = ErrorHandlerMockBNBNativeLibrary(error_msg) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9b446a2de..5cd9eac67 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -13,9 +13,9 @@ from torch import Tensor from typing_extensions import deprecated -from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict +from bitsandbytes.utils import 
pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib +from .cextension import HIP_ENVIRONMENT, lib name2qmap = {} @@ -439,6 +439,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p: # We use the raw stream for performance reasons. + if tensor.device.type == "xpu": + return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index)) return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) @@ -1053,16 +1055,6 @@ def dequantize_4bit( if absmax.dtype != torch.float32: absmax = absmax.float() - # IPEX format is different, we need extra process. - if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4": - return torch.ops.bitsandbytes.dequantize_nf4_ipex( - A, - absmax, - quant_state.blocksize, - quant_state.shape, - quant_state.dtype, - ) - if out is not None: torch.ops.bitsandbytes.dequantize_4bit.out( A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out @@ -1631,25 +1623,6 @@ def gemv_4bit( if state.nested: absmax = dequantize_blockwise(absmax, state.state2) + state.offset - if getattr(state, "ipex", False) and state.quant_type == "nf4": - # compute_dtype: 1 indicates fp16, 2 indicates bf16 - compute_dtype = 2 if A.dtype == torch.bfloat16 else 1 - out = torch.ops.torch_ipex.woq_linear( - A, - B, - "nf4", - state.shape, - state.new_scales, - state.new_zeros, - None, - None, - state.blocksize, - compute_dtype, - 1, - state.compensation, - ) - return out - if out is not None: torch.ops.bitsandbytes.gemv_4bit.out( A, @@ -2336,49 +2309,3 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): C = 127.0 - - -def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor): - quant_state = linear.weight.quant_state - - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - - quant_state.absmax = absmax - quant_state.nested = False - delattr(quant_state, "state2") - - if x.device.type == "cpu" and ipex_cpu: - converted_weight = _reverse_4bit_compress_format(linear.weight.data) - new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight( - converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), - "nf4", - quant_state.shape, # weight shape - quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales - None, # zero_points - None, # bias - None, # batch_size - quant_state.blocksize, - 2, - ) - elif x.device.type == "xpu" and ipex_xpu: - new_weight = _reverse_4bit_compress_format(linear.weight.data) - new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize) - new_zeros = None - compensation = None - new_scales = list(new_scales) - if not linear.training and not x.requires_grad: - new_weight = new_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]) - else: - raise ValueError( - "Please check the device and ipex version. 
The device should be cpu or xpu while ipex version should >= 2.7" - ) - - linear.weight.data = new_weight.data - linear.weight.quant_state.ipex = True - linear.weight.quant_state.new_scales = new_scales - linear.weight.quant_state.new_zeros = new_zeros - linear.weight.quant_state.compensation = compensation diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ba134f52a..464205fa5 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,13 +12,9 @@ import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT -from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu +from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import ( - INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, - OutlierTracer, - _reverse_4bit_compress_format, -) +from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer T = TypeVar("T", bound="torch.nn.Module") @@ -443,7 +439,6 @@ def __init__( self.compute_type_is_set = False if compute_dtype is None else True self.quant_state = None self.quant_storage = quant_storage - self.ipex_linear_is_set = False def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -470,40 +465,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): save weight and bias, then fill state_dict with components of quant_state """ - if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False): - if self.weight.device.type == "cpu": - original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight( - self.weight, "nf4", self.weight.quant_state.shape, 2 - ) - self.weight.data = _reverse_4bit_compress_format(original_weight.data) - elif self.weight.device.type == "xpu": - self.weight.data = _reverse_4bit_compress_format(self.weight.data.reshape(1, -1)) - - self.weight.quant_state.ipex = False - self.ipex_linear_is_set = False - super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias if getattr(self.weight, "quant_state", None) is not None: for k, v in self.weight.quant_state.as_dict(packed=True).items(): destination[prefix + "weight." 
+ k] = v if keep_vars else v.detach() - def set_ipex_linear(self, x: torch.Tensor): - if ( - not getattr(self.weight.quant_state, "ipex", False) - and self.weight.data.dtype == torch.uint8 - and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 - and self.weight.quant_state.quant_type == "nf4" - ): - if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False): - _enable_ipex_fusion(self, x) - def forward(self, x: torch.Tensor): - # Check if ipex fusion can be used - if not self.ipex_linear_is_set and (ipex_cpu or ipex_xpu): - self.set_ipex_linear(x) - self.ipex_linear_is_set = True - fix_4bit_weight_quant_state_from_module(self) # weights are cast automatically as Int8Params, but the bias has to be cast manually @@ -519,8 +487,7 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) - # IPEX CPU will change weight to 4D so don't need transpose - weight = self.weight.t() if self.weight.dim() == 2 else self.weight + weight = self.weight.t() return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) @@ -675,7 +642,7 @@ def to(self, *args, **kwargs): if device is not None and device.type != "meta" and self.data.device.type == "cpu": if device.type != "cpu" or self.data.dtype != torch.int8: return self._quantize(device) - elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu): + elif self.data.dtype == torch.int8 and device.type == "cpu": self.CB = self.data new_param = Int8Params( diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 7920e2188..0828dd295 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -38,14 +38,6 @@ def outlier_hook(module, input): hook.remove() -# convert btw standard 4-bit compression format and ipex compression format -def _reverse_4bit_compress_format(weight: torch.Tensor): - out_1 = (weight & 0xF0) >> 4 - out_2 = (weight & 0xF) << 4 - out = out_1 | out_2 - return out - - class OutlierTracer: _instance = None diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 9c4cab9cc..b5d9afc6b 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -12,6 +12,9 @@ #if BUILD_MPS // #include #endif +#if BUILD_XPU +#include +#endif #include // Compatibility between HIP/CUDA APIs @@ -308,6 +311,90 @@ void spmm_coo_very_sparse_naive_int8( } #endif +#if BUILD_XPU + +void dequantizeBlockwise_fp16( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_fp4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_nf4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_fp4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + 
dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_nf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_fp4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_nf4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void gemv_4bit_inference_fp16( + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void gemv_4bit_inference_bf16( + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream + ); +} + +void gemv_4bit_inference_fp32( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#endif + extern "C" { #if BUILD_CUDA || BUILD_HIP void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); } @@ -658,6 +745,88 @@ void cgemm_4bit_inference_naive_fp32( #endif +#if BUILD_XPU + +void cdequantize_blockwise_fp16_fp4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16_nf4( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_fp4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_nf4( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32_nf4(code, A, 
absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_fp4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_nf4( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cgemv_4bit_inference_fp16( + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemv_4bit_inference_bf16( + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemv_4bit_inference_fp32( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#endif + void cquantize_blockwise_cpu_fp32( float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n ) { diff --git a/csrc/xpu_kernels.cpp b/csrc/xpu_kernels.cpp new file mode 100644 index 000000000..8ee8add98 --- /dev/null +++ b/csrc/xpu_kernels.cpp @@ -0,0 +1,281 @@ +#include "xpu_kernels.h" +#include +#include +#include + +#include + +inline float dDequantizeFP4(unsigned char val) { + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -0.25000000f; + else + return -0.16666667f; + else if ((val & 0b0001) == 1) + return -0.50000000f; + else + return -0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -1.00000000f; + else + return -0.66666667f; + else if ((val & 0b0001) == 1) + return -5.208333333e-03f; + else + return 0.00000000f; + else if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 0.25000000f; + else + return 0.16666667f; + else if ((val & 0b0001) == 1) + return 0.50000000f; + else + return 0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 1.00000000f; + else + return 0.66666667f; + else if ((val & 0b0001) == 1) + return 5.208333333e-03f; + else + return 0.00000000f; +} + +inline float dDequantizeNF4(unsigned char val) { + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) // 1 + if ((val & 0b0010) == 2) // 11 + if ((val & 0b0001) == 1) // 111 + return 1.0f; //*1111 + else + return 0.7229568362236023f; //*1110 + else if ((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; //*1101 + else + return 0.44070982933044434f; //*1100 + else if ((val & 
0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; //*1011 + else + return 0.24611230194568634f; //*1010 + else if ((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; //*1001 + else + return 0.07958029955625534f; //*1000 + + else if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 011 + return 0.0f; //*0111 + else + return -0.09105003625154495f; //*0110 + else if ((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; //*0101 + else + return -0.28444138169288635f; //*0100 + else if ((val & 0b0010) == 2) // 00 + if ((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; //*0011 + else + return -0.5250730514526367f; //*0010 + else if ((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; //*0001 + else + return -1.0f; //*0000 +} + +template +SYCL_EXTERNAL void kDequantizeBlockwise::operator()(sycl::nd_item<1> item) const { + const int base_idx = item.get_group(0) * TILE_SIZE; + size_t local_idx = item.get_local_id(0) * NUM_PER_TH; + float local_abs_max = -FLT_MAX; + int local_load_idx = 0; + int local_store_idx = 0; + + uint8_t qvals[NUM_PER_TH]; + T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)]; + + if (DATA_TYPE > 0) { + local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx); + local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2); + } else { + local_load_idx = sycl::min(TILE_SIZE, n - base_idx); + local_store_idx = local_load_idx; + } + + // Avoid expensive division by the blocksize (as blocksize will always be a + // power-of-2) + local_abs_max = absmax[(base_idx + local_idx) >> (31 - std::countl_zero(blocksize))]; + + if (local_idx + NUM_PER_TH < local_load_idx) { + reinterpret_cast(&)[NUM_PER_TH]>(qvals)[0] = + reinterpret_cast*>(A)[(base_idx + local_idx) / NUM_PER_TH]; + } else { +#pragma unroll NUM_PER_TH + for (int i = 0; i < NUM_PER_TH; i++) { + if (local_idx + i < local_load_idx) { + qvals[i] = A[base_idx + local_idx + i]; + } else { + qvals[i] = (uint8_t)0; + } + } + } + + switch (DATA_TYPE) { + case General8bit: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) + vals[j] = code[qvals[j]] * local_abs_max; + break; + case FP4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeFP4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeFP4(qvals[j] & 0x0F) * local_abs_max; + } + break; + case NF4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; + } + break; + } + + const int local_dst_size = (DATA_TYPE > 0) ? NUM_PER_TH * 2 : NUM_PER_TH; + int local_dst_idx = (DATA_TYPE > 0) ? local_idx * 2 : local_idx; + + if (local_dst_idx + local_dst_size < local_store_idx) { + reinterpret_cast*>( + out + )[(((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx) / local_dst_size] = + reinterpret_cast(&)[local_dst_size]>(vals)[0]; + } else { +#pragma unroll NUM_PER_TH + for (int i = 0; i < local_dst_size; i++) { + if (local_dst_idx + i < local_store_idx) { + out[((DATA_TYPE > 0) ? 
base_idx * 2 : base_idx) + local_dst_idx + i] = vals[i]; + } + } + } +} + +template +SYCL_EXTERNAL void + kgemv_4bit_inference::operator()(sycl::nd_item<1> item) const { + size_t idx = item.get_local_id(); + const int sg_idx = idx / SUBG_SIZE; + const int sg_lane = idx % SUBG_SIZE; + const int num_values_4bit = SUBG_SIZE; + const int row_B = NUM_PER_THREAD * item.get_group().get_group_id() + sg_idx; + const int offset_B = ldb * row_B; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + T local_absmax = T(0.0f); + + if (idx < 16) { + quant_map[idx] = T(datatype[idx]); + } + + item.barrier(sycl::access::fence_space::local_space); + + for (int inner_idx = sg_lane * num_values_4bit; inner_idx < K; inner_idx += SUBG_SIZE * num_values_4bit) { + const int inner_idx_halved = inner_idx / 2; + + // Avoid expensive division by the blocksize (as blocksize will always be a + // power-of-2) + const int absidx = ((2 * offset_B) + inner_idx) >> (31 - std::countl_zero((unsigned int)blocksize)); + local_absmax = absmax[absidx]; + + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + reinterpret_cast(&)[num_values_8bit]>(local_B_4bit)[0] = + reinterpret_cast*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { +#pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + if (BITS == 16) { + reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] = + reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] = + reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[1] = + reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + + } else { +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } + +// accumulate in float for accuracy; +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + local_C += (float)(local_A[k] * local_B[k]); + } + } + } + + local_C = sycl::reduce_over_group(item.get_sub_group(), local_C, sycl::plus<>()); + + if (row_B < N && sg_lane == 0) + out[row_B] = T(local_C); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; + +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; + +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; + 
+template class kgemv_4bit_inference; +template class kgemv_4bit_inference; +template class kgemv_4bit_inference; diff --git a/csrc/xpu_kernels.h b/csrc/xpu_kernels.h new file mode 100644 index 000000000..caa7e6716 --- /dev/null +++ b/csrc/xpu_kernels.h @@ -0,0 +1,52 @@ +#include +#include + +#ifndef xpu_kernels +#define xpu_kernels + +template class kDequantizeBlockwise { + public: + SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; + + kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int n_) + : code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), n(n_) {} + + private: + float* code; + uint8_t* A; + float* absmax; + T* out; + const int blocksize; + const int n; +}; + +template class kgemv_4bit_inference { + public: + SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; + + kgemv_4bit_inference( + int M_, int N_, int K_, T* A_, unsigned char* B_, float* absmax_, const float* datatype_, T* out_, int lda_, + int ldb_, int ldc_, int blocksize_ + ) + : M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_), out(out_), lda(lda_), ldb(ldb_), + ldc(ldc_), blocksize(blocksize_), quant_map() {} + + void sycl_ker_local_memory_creation(sycl::handler& cgh) { quant_map = sycl::local_accessor(16, cgh); } + + private: + int M; + int N; + int K; + T* A; + unsigned char* B; + float* absmax; + const float* datatype; + T* out; + int lda; + int ldb; + int ldc; + int blocksize; + sycl::local_accessor quant_map; +}; + +#endif diff --git a/csrc/xpu_ops.cpp b/csrc/xpu_ops.cpp new file mode 100644 index 000000000..aa6ac808f --- /dev/null +++ b/csrc/xpu_ops.cpp @@ -0,0 +1,102 @@ +#include +#include +#include + +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, sycl::queue* stream +) { + auto& queue = *stream; + const int workgroup_size = 128; + const int num_per_th = 4; + const int tile_size = workgroup_size * num_per_th; + if (DATA_TYPE > 0) { + const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2); + sycl::range<1> local_range{(size_t)workgroup_size}; + sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; + kDequantizeBlockwise kfn(code, A, absmax, out, blocksize / 2, n); + sycl_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn + ); + } else { + const int workgroup_num = (n + tile_size - 1) / tile_size; + sycl::range<1> local_range{(size_t)workgroup_size}; + sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; + kDequantizeBlockwise kfn(code, A, absmax, out, blocksize, n); + sycl_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn + ); + } +} + +template +void gemv_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, sycl::queue* stream +) { + + auto& queue = *stream; + + const size_t GROUP_SIZE = 128; // workgroup_size + const size_t SUBG_SIZE = 32; // subgroup_size + const size_t NUM_PER_THREAD = GROUP_SIZE / SUBG_SIZE; + size_t workgroup_num = (n + NUM_PER_THREAD - 1) / NUM_PER_THREAD; + + kgemv_4bit_inference kfn( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize + ); + + sycl_comp_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), sycl::range<1>(GROUP_SIZE)), queue, kfn + ); +} + 
+//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream +); + +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream +); + +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +); +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, + sycl::queue* stream +); + +template void gemv_4bit_inference( + int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +); +template void gemv_4bit_inference( + int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, + sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +); +template void gemv_4bit_inference( + int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, + int ldc, int blocksize, sycl::queue* stream +); diff --git a/csrc/xpu_ops.h b/csrc/xpu_ops.h new file mode 100644 index 000000000..142d6c161 --- /dev/null +++ b/csrc/xpu_ops.h @@ -0,0 +1,46 @@ +#ifndef xpu_ops_H +#define xpu_ops_H + +#include +#include +#include +#include + +#include +#include + +#include + +template +static inline void sycl_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) + [[sycl::reqd_sub_group_size(subgroup_size)]] { cgh.parallel_for(range, ker); }; + q.submit(cgf); +} + +template +static inline void sycl_comp_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) [[sycl::reqd_sub_group_size(subgroup_size)]] { + ker.sycl_ker_local_memory_creation(cgh); + cgh.parallel_for(range, ker); + }; + q.submit(cgf); +} + +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int n, sycl::queue* stream +); +template +void gemv_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, sycl::queue* stream +); + +#endif diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 
e61ce4655..7396c7dcf 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -138,8 +138,8 @@ We provide an early preview of support for AMD and Intel hardware as part of a d | **Backend** | **Supported Versions** | **Python versions** | **Architecture Support** | **Status** | |-------------|------------------------|---------------------------|-------------------------|------------| | **AMD ROCm** | 6.1+ | 3.10+ | minimum CDNA - `gfx90a`, RDNA - `gfx1100` | Alpha | -| **Intel CPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha | -| **Intel GPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental | +| **Intel CPU** | v2.4.0+ | 3.10+ | Intel CPU | Alpha | +| **Intel GPU** | v2.7.0+ | 3.10+ | Intel GPU | Experimental | | **Ascend NPU** | 2.1.0+ (`torch_npu`) | 3.10+ | Ascend NPU | Experimental | For each supported backend, follow the respective instructions below: @@ -179,7 +179,6 @@ pip install torch --index-url https://download.pytorch.org/whl/rocm6.3/ * A compatible PyTorch version with Intel XPU support is required. It is recommended to use the latest stable release. See [Getting Started on Intel GPU](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) for guidance. -* The [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/) is recommended for performance improvements. @@ -235,27 +234,18 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise -#### Intel CPU + XPU +#### Intel CPU + GPU(XPU) - -If you are using Intel CPU on Linux or Intel XPU on Linux/Windows, please follow the [instruction](https://pytorch-extension.intel.com/) or the following command to install intel_extension_for_pytorch so you can get better performance. - -CPU: `pip install intel_extension_for_pytorch` -XPU: `pip install intel_extension_for_pytorch --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/` - -Install bitsandbytes: -CPU: Need to build CPU C++ codes +CPU needs to build CPU C++ codes, while XPU needs to build sycl codes. +Run `export bnb_device=xpu` if you are using xpu, run `export bnb_device=cpu` if you are using cpu. ``` git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ -cmake -DCOMPUTE_BACKEND=cpu -S . +cmake -DCOMPUTE_BACKEND=$bnb_device -S . make -pip install . -``` -XPU: -``` -pip install git+https://github.com/bitsandbytes-foundation/bitsandbytes.git +pip install -e . 
``` + diff --git a/tests/test_functional.py b/tests/test_functional.py index b84db6502..25844d20f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -142,11 +142,11 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, abserr = sum(diffs) / len(diffs) relerr = sum(reldiffs) / len(reldiffs) if signed: - threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035 + threshold_abserr = 0.0035 assert abserr < 0.0036 assert relerr < 0.015 else: - assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023 + assert abserr < 0.0023 assert relerr < 0.012 assert A2.dtype == dtype @@ -177,8 +177,8 @@ def test_blockwise_cpu_large(self, hidden, blocksize): @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"]) def test_few_bit_quant(self, device, bits, method): - if bits != 8 and (device == "cpu" or (device == "xpu" and F.ipex_xpu)): - pytest.skip("CPU/XPU implementation only supports 8 bits") + if bits != 8 and device == "cpu": + pytest.skip("CPU implementation only supports 8 bits") abserrs = [] relerrs = [] @@ -1238,8 +1238,8 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double max_errs3 = [] # Large number of iterations is excessive and slow on CPU. - # Keep for CUDA for now. - iters = 100 if device == "cuda" else 10 + # Keep for CUDA/XPU for now. + iters = 10 if device == "cpu" else 100 for i in range(iters): if kind == "fc1": @@ -1341,13 +1341,13 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert err1 < 6e-5 assert relerr1 < 2e-4 assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.005 and relratio > 0.995 - assert maxratio < 1.005 and maxratio > 0.995 + assert relratio < 1.005 and relratio > 0.992 + assert maxratio < 1.005 and maxratio > 0.992 elif dtype == torch.float32: if dim <= 512: assert err1 < 5e-8 assert relerr1 < 1e-6 - assert maxerr1 < 1e-7 + assert maxerr1 < 1.05e-7 else: assert err1 < 5e-8 assert relerr1 < 8e-6 @@ -1357,16 +1357,17 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert maxratio < 1.005 and maxratio > 0.995 elif dtype == torch.bfloat16: if dim <= 512: + relerr_thres = 0.013 if hasattr(torch, "xpu") and torch.xpu.is_available() else 0.007 assert err1 < 6e-4 - assert relerr1 < 0.007 + assert relerr1 < relerr_thres assert maxerr1 < 0.015 else: assert err1 < 2e-4 assert relerr1 < 0.002 assert maxerr1 < 0.0012 assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.04 and relratio > 0.96 - assert maxratio < 1.02 and maxratio > 0.98 + assert relratio < 1.05 and relratio > 0.96 + assert maxratio < 1.05 and maxratio > 0.97 @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 86726bd44..0e5f7bc18 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -272,14 +272,11 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): # Test with gradients. Currently only works with threshold=0. # Has a strange regression on Linux aarch64 CPU in torch==2.6.0. - # There is also an issue with torch==2.7.0 on x86-64 with IPEX. 
is_broken_platform = ( device == "cpu" and platform.system() == "Linux" - and ( - (platform.machine() == "aarch64" and (2, 6) <= torch.__version__ < (2, 7)) - or (platform.machine() == "x86_64" and bnb.functional.ipex_cpu) - ) + and platform.machine() == "aarch64" + and (2, 6) <= torch.__version__ < (2, 7) ) if threshold == 0 and not is_broken_platform: diff --git a/tests/test_modules.py b/tests/test_modules.py index 8946522d3..e5682e5c8 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -143,9 +143,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half() @@ -156,9 +155,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device) @@ -167,9 +165,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -189,9 +186,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -211,9 +207,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 diff --git a/tests/test_ops.py b/tests/test_ops.py index 8aa0560fd..3b52bf284 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -5,7 +5,6 @@ import bitsandbytes from bitsandbytes.cextension import HIP_ENVIRONMENT -from bitsandbytes.functional import ipex_xpu from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu # torch.library.opcheck is only available in torch 2.4 and later. @@ -145,10 +144,6 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): assert out.dtype == dtype assert out.device == A.device - # TODO: Enable it - if device == "xpu" and ipex_xpu: - pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check") - opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
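
For reviewers, the snippet below is a minimal usage sketch and is not part of the patch. It shows how the new native XPU path added here would be exercised through the public bitsandbytes API. It assumes a PyTorch build with XPU support and a bitsandbytes library built with `-DCOMPUTE_BACKEND=xpu` (so `libbitsandbytes_xpu` loads); the tensor sizes are illustrative only.

```python
# Illustrative sketch (not part of this diff). Assumptions:
#  * PyTorch with XPU support (torch.xpu.is_available() is True)
#  * bitsandbytes built with -DCOMPUTE_BACKEND=xpu, so the native library loads
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

device = "xpu"

# Blockwise NF4 quantize/dequantize; dequantize_4bit dispatches to the SYCL
# kernel registered in bitsandbytes/backends/xpu/ops.py when the native
# library is available.
W = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)
W_q, quant_state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")
W_dq = F.dequantize_4bit(W_q, quant_state)

# 4-bit linear layer; a single-row input takes the gemv_4bit inference path.
layer = bnb.nn.Linear4bit(
    4096, 4096, bias=False, compute_dtype=torch.bfloat16, quant_type="nf4"
).to(device)
x = torch.randn(1, 4096, dtype=torch.bfloat16, device=device)
with torch.no_grad():
    y = layer(x)

print(W_dq.dtype, y.shape)
```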