diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 78c99355b..019a4f6ab 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -12,9 +12,15 @@
     matmul_cublas,
     mm_cublas,
 )
+from .cextension import lib
 from .nn import modules
-from .optim import adam
+if lib and lib.compiled_with_cuda:
+    from .backends import register_backend
+    from .backends.cuda import CUDABackend
+    from .optim import adam
+
+    register_backend("cuda", CUDABackend())
 
 __pdoc__ = {
     "libbitsandbytes": False,
     "optim.optimizer.Optimizer8bit": False,
diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py
new file mode 100644
index 000000000..30f08073a
--- /dev/null
+++ b/bitsandbytes/backends/__init__.py
@@ -0,0 +1,15 @@
+from typing import Dict
+
+from bitsandbytes.backends.base import Backend
+
+backends: Dict[str, Backend] = {}
+
+
+def register_backend(backend_name: str, backend_instance: Backend):
+    backends[backend_name.lower()] = backend_instance
+
+
+def ensure_backend_is_available(device_type: str):
+    """Check if a backend is available for the given device type."""
+    if device_type.lower() not in backends:
+        raise NotImplementedError(f"Device backend for {device_type} is currently not supported.")
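Reviewer note (editorial, not part of the patch): a minimal sketch of how the new registry is meant to be consumed. The _StubBackend name is hypothetical and only illustrates the contract; a real backend should subclass Backend and implement all of its abstract methods.

    from bitsandbytes.backends import backends, ensure_backend_is_available, register_backend

    class _StubBackend:  # stand-in for a concrete Backend subclass
        pass

    register_backend("CPU", _StubBackend())  # keys are lower-cased, so this is stored as "cpu"
    ensure_backend_is_available("cpu")       # passes: a backend is registered for this device type
    print(backends["cpu"])                   # the instance that functional.py will dispatch to

    try:
        ensure_backend_is_available("xpu")   # no backend registered for this device type
    except NotImplementedError as exc:
        print(exc)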
diff --git a/bitsandbytes/backends/base.py b/bitsandbytes/backends/base.py
new file mode 100644
index 000000000..8232d17c1
--- /dev/null
+++ b/bitsandbytes/backends/base.py
@@ -0,0 +1,133 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+
+import torch
+
+from bitsandbytes.utils import QuantState
+
+
+class Backend(ABC):
+    """Base class for device backends that implement their own 8-bit and 4-bit functions."""
+
+    @abstractmethod
+    def double_quant(
+        self,
+        A,
+        col_stats=None,
+        row_stats=None,
+        out_col=None,
+        out_row=None,
+        threshold=0.0,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def transform(
+        self,
+        A,
+        to_order,
+        from_order="row",
+        out=None,
+        transpose=False,
+        state=None,
+        ld=None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
+        raise NotImplementedError
+
+    @abstractmethod
+    def mm_dequant(
+        self,
+        A,
+        quant_state,
+        row_stats,
+        col_stats,
+        out=None,
+        new_row_stats=None,
+        new_col_stats=None,
+        bias=None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def extract_outliers(self, A, SA, idx):
+        raise NotImplementedError
+
+    @abstractmethod
+    def quantize_4bit(
+        self,
+        A: torch.Tensor,
+        absmax: Optional[torch.Tensor] = None,
+        out: Optional[torch.Tensor] = None,
+        blocksize=64,
+        compress_statistics=False,
+        quant_type="fp4",
+        quant_storage=torch.uint8,
+    ) -> Tuple[torch.Tensor, QuantState]:
+        """
+        Quantize tensor A in blocks of 4-bit values.
+
+        Quantizes tensor A by dividing it into blocks which are independently quantized to FP4 or NF4.
+
+        Parameters
+        ----------
+        A : torch.Tensor
+            The input tensor.
+        absmax : torch.Tensor
+            The absmax values.
+        out : torch.Tensor
+            The output tensor.
+        blocksize : int
+            The blocksize used in quantization.
+        quant_type : str
+            The 4-bit quantization data type {fp4, nf4}
+
+        Returns
+        -------
+        torch.Tensor:
+            Tensor with packed 4-bit values.
+        QuantState:
+            The quantization state to undo the quantization.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def dequantize_4bit(
+        self,
+        A: torch.Tensor,
+        quant_state: Optional[QuantState] = None,
+        absmax: Optional[torch.Tensor] = None,
+        out: Optional[torch.Tensor] = None,
+        blocksize: int = 64,
+        quant_type="fp4",
+    ) -> torch.Tensor:
+        """
+        Dequantizes FP4 or NF4 blockwise quantized values.
+
+        Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
+
+        Parameters
+        ----------
+        A : torch.Tensor
+            The input tensor (packed 4-bit values).
+        quant_state : QuantState
+            object with quantization stats, incl. absmax values, original tensor shape and original dtype.
+        absmax : torch.Tensor
+            The absmax values.
+        out : torch.Tensor
+            Dequantized output tensor.
+        blocksize : int
+            The blocksize used in quantization.
+        quant_type : str
+            The 4-bit quantization data type {fp4, nf4}
+
+
+        Returns
+        -------
+        torch.Tensor:
+            Dequantized tensor.
+        """
+        raise NotImplementedError
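Reviewer note (editorial, not part of the patch): the abstract base makes the expected shape of a future non-CUDA backend explicit. A sketch, assuming a hypothetical NPU device; the class name and the stubbed bodies are illustrative only:

    import torch

    from bitsandbytes.backends import register_backend
    from bitsandbytes.backends.base import Backend

    class NPUBackend(Backend):
        # Every abstract method must be implemented before the class can be
        # instantiated; unsupported operations can simply raise until kernels exist.
        def double_quant(self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
            raise NotImplementedError
        def transform(self, A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
            raise NotImplementedError
        def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
            raise NotImplementedError
        def mm_dequant(self, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
            raise NotImplementedError
        def extract_outliers(self, A, SA, idx):
            raise NotImplementedError
        def quantize_4bit(self, A, absmax=None, out=None, blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8):
            raise NotImplementedError  # device-specific kernel call would go here
        def dequantize_4bit(self, A, quant_state=None, absmax=None, out=None, blocksize=64, quant_type="fp4"):
            raise NotImplementedError  # device-specific kernel call would go here

    register_backend("npu", NPUBackend())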
diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py
new file mode 100644
index 000000000..c76bcaebd
--- /dev/null
+++ b/bitsandbytes/backends/cuda.py
@@ -0,0 +1,528 @@
+import ctypes as ct
+from typing import Optional, Tuple
+
+import torch
+
+from bitsandbytes.cextension import lib
+from bitsandbytes.functional import (
+    CUBLAS_Context,
+    coo_zeros,
+    dequantize_blockwise,
+    dtype2bytes,
+    get_4bit_type,
+    get_colrow_absmax,
+    get_ptr,
+    get_transform_buffer,
+    is_on_gpu,
+    post_call,
+    pre_call,
+    prod,
+    quantize_blockwise,
+)
+from bitsandbytes.utils import QuantState
+
+from .base import Backend
+
+
+class CUDABackend(Backend):
+    def double_quant(self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
+        device = A.device
+        assert A.dtype == torch.half
+        assert device.type == "cuda"
+        prev_device = pre_call(A.device)
+
+        cols = A.shape[-1]
+        if len(A.shape) == 3:
+            rows = A.shape[0] * A.shape[1]
+        else:
+            rows = A.shape[0]
+
+        if row_stats is None or col_stats is None:
+            row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+
+        if out_col is None:
+            out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
+        if out_row is None:
+            out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
+
+        coo_tensor = None
+        ptrA = get_ptr(A)
+        ptrColStats = get_ptr(col_stats)
+        ptrRowStats = get_ptr(row_stats)
+        ptrOutCol = get_ptr(out_col)
+        ptrOutRow = get_ptr(out_row)
+
+        is_on_gpu([A, col_stats, row_stats, out_col, out_row])
+        if threshold > 0.0:
+            nnz = nnz_row_ptr[-1].item()
+            if nnz > 0:
+                coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
+                ptrRowIdx = get_ptr(coo_tensor.rowidx)
+                ptrColIdx = get_ptr(coo_tensor.colidx)
+                ptrVal = get_ptr(coo_tensor.values)
+                ptrRowPtr = get_ptr(nnz_row_ptr)
+
+                lib.cdouble_rowcol_quant(
+                    ptrA,
+                    ptrRowStats,
+                    ptrColStats,
+                    ptrOutCol,
+                    ptrOutRow,
+                    ptrRowIdx,
+                    ptrColIdx,
+                    ptrVal,
+                    ptrRowPtr,
+                    ct.c_float(threshold),
+                    ct.c_int32(rows),
+                    ct.c_int32(cols),
+                )
+                val, idx = torch.sort(coo_tensor.rowidx)
+                coo_tensor.rowidx = val
+                coo_tensor.colidx = coo_tensor.colidx[idx]
+                coo_tensor.values = coo_tensor.values[idx]
+            else:
+                lib.cdouble_rowcol_quant(
+                    ptrA,
+                    ptrRowStats,
+                    ptrColStats,
+                    ptrOutCol,
+                    ptrOutRow,
+                    None,
+                    None,
+                    None,
+                    None,
+                    ct.c_float(0.0),
+                    ct.c_int32(rows),
+                    ct.c_int32(cols),
+                )
+        else:
+            lib.cdouble_rowcol_quant(
+                ptrA,
+                ptrRowStats,
+                ptrColStats,
+                ptrOutCol,
+                ptrOutRow,
+                None,
+                None,
+                None,
+                None,
+                ct.c_float(threshold),
+                ct.c_int32(rows),
+                ct.c_int32(cols),
+            )
+        post_call(prev_device)
+
+        return out_row, out_col, row_stats, col_stats, coo_tensor
+
+    def transform(self, A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
+        prev_device = pre_call(A.device)
+        if state is None:
+            state = (A.shape, from_order)
+        else:
+            from_order = state[1]
+
+        if out is None:
+            out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
+        else:
+            new_state = (state[0], to_order)  # (shape, order)
+
+        shape = state[0]
+        if len(shape) == 2:
+            dim1 = ct.c_int32(shape[0])
+            dim2 = ct.c_int32(shape[1])
+        else:
+            dim1 = ct.c_int32(shape[0] * shape[1])
+            dim2 = ct.c_int32(shape[2])
+
+        is_on_gpu([A, out])
+        if to_order == "col32":
+            if transpose:
+                lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
+            else:
+                lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
+
+        elif to_order == "col_turing":
+            if transpose:
+                lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
+            else:
+                lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
+
+        elif to_order == "col_ampere":
+            if transpose:
+                lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
+            else:
+                lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
+
+        elif to_order == "row":
+            if from_order == "col_turing":
+                lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
+            elif from_order == "col_ampere":
+                lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
+
+        else:
+            raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")
+
+        post_call(prev_device)
+
+        return out, new_state
+
+    def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
+        shapeA = SA[0]
+        shapeB = SB[0]
+        dimsA = len(shapeA)
+        dimsB = len(shapeB)
+
+        assert dimsB == 2, "Only two dimensional matrices are supported for argument B"
+        if dimsA == 2:
+            m = shapeA[0]
+        elif dimsA == 3:
+            m = shapeA[0] * shapeA[1]
+
+        rows = n = shapeB[0]
+        assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}"
+
+        # if the tensor is empty, return a transformed empty tensor with the right dimensions
+        if shapeA[0] == 0 and dimsA == 2:
+            return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
+        elif shapeA[1] == 0 and dimsA == 3:
+            return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
+
+        if dimsA == 2 and out is None:
+            out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
+        elif dimsA == 3 and out is None:
+            out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")
+
+        assert dimsB != 3, "len(B.shape)==3 not supported"
+        assert A.device.type == "cuda"
+        assert B.device.type == "cuda"
+        assert A.dtype == torch.int8
+        assert B.dtype == torch.int8
+        assert out.dtype == dtype
+        assert SA[1] == "col32"
+        assert SB[1] in ["col_turing", "col_ampere"]
+        assert Sout[1] == "col32"
+        assert (
+            shapeA[-1] == shapeB[-1]
+        ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
+
+        formatB = SB[1]
+        prev_device = A.device
+        torch.cuda.set_device(A.device)
+
+        ptr = CUBLAS_Context.get_instance().get_context(A.device)
+        ptrA = get_ptr(A)
+        ptrB = get_ptr(B)
+        ptrC = get_ptr(out)
+
+        k = shapeA[-1]
+        lda = ct.c_int32(m * 32)
+        if formatB == "col_turing":
+            # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
+            # n = rows
+            ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
+        else:
+            # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
+            # n = rows
+            ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
+
+        ldc = ct.c_int32(m * 32)
+        m = ct.c_int32(m)
+        n = ct.c_int32(n)
+        k = ct.c_int32(k)
+
+        has_error = 0
+        ptrRowScale = get_ptr(None)
+        is_on_gpu([A, B, out])
+
+        if formatB == "col_turing":
+            if dtype == torch.int32:
+                has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+            else:
+                has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+
+        elif formatB == "col_ampere":
+            if dtype == torch.int32:
+                has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+            else:
+                has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+
+        if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
+            raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")
+
+        if has_error:
+            print(
+                f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}"
+            )
+            raise Exception("cublasLt ran into an error!")
+
+        torch.cuda.set_device(prev_device)
+
+        return out, Sout
+
+    def mm_dequant(
+        self, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None
+    ):
+        assert A.dtype == torch.int32
+        if bias is not None:
+            assert bias.dtype == torch.float16
+        out_shape = quant_state[0]
+        if len(out_shape) == 3:
+            out_shape = (out_shape[0] * out_shape[1], out_shape[2])
+
+        if out is None:
+            out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
+        if new_row_stats is None:
+            new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+        if new_col_stats is None:
+            new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+        assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
+        assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
+
+        prev_device = pre_call(A.device)
+        ptrA = get_ptr(A)
+        ptrOut = get_ptr(out)
+        ptrRowStats = get_ptr(row_stats)
+        ptrColStats = get_ptr(col_stats)
+        ptrNewRowStats = get_ptr(new_row_stats)
+        ptrNewColStats = get_ptr(new_col_stats)
+        ptrBias = get_ptr(bias)
+        numRows = ct.c_int32(out_shape[0])
+        numCols = ct.c_int32(out_shape[1])
+
+        is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
+        lib.cdequant_mm_int32_fp16(
+            ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols
+        )
+        post_call(prev_device)
+
+        return out
+
+    def extract_outliers(self, A, SA, idx):
+        shapeA = SA[0]
+        formatA = SA[1]
+        assert formatA in ["col_turing", "col_ampere"]
+        assert A.device.type == "cuda"
+
+        out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+
+        idx_size = ct.c_int32(idx.numel())
+        rows = ct.c_int32(shapeA[0])
+        cols = ct.c_int32(shapeA[1])
+        ptrA = get_ptr(A)
+        ptrIdx = get_ptr(idx)
+        ptrOut = get_ptr(out)
+
+        prev_device = pre_call(A.device)
+
+        if formatA == "col_turing":
+            lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+        elif formatA == "col_ampere":
+            lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+
+        post_call(prev_device)
+
+        return out
+
+    def quantize_4bit(
+        self,
+        A: torch.Tensor,
+        absmax: Optional[torch.Tensor] = None,
+        out: Optional[torch.Tensor] = None,
+        blocksize=64,
+        compress_statistics=False,
+        quant_type="fp4",
+        quant_storage=torch.uint8,
+    ) -> Tuple[torch.Tensor, QuantState]:
+        if A.device.type != "cuda":
+            raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
+        if quant_type not in ["fp4", "nf4"]:
+            raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
+
+        n = A.numel()
+        input_shape = A.shape
+
+        if absmax is None:
+            blocks = n // blocksize
+            blocks += 1 if n % blocksize > 0 else 0
+            absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+
+        if out is None:
+            mod = dtype2bytes[quant_storage] * 2
+            out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)
+
+        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+
+        prev_device = pre_call(A.device)
+        is_on_gpu([A, out, absmax])
+
+        if A.dtype == torch.float32:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp32_fp4(
+                    get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)
+                )
+            else:
+                lib.cquantize_blockwise_fp32_nf4(
+                    get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)
+                )
+
+        elif A.dtype == torch.float16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp16_fp4(
+                    get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)
+                )
+            else:
+                lib.cquantize_blockwise_fp16_nf4(
+                    get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)
+                )
+
+        elif A.dtype == torch.bfloat16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_bf16_fp4(
+                    get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)
+                )
+            else:
+                lib.cquantize_blockwise_bf16_nf4(
+                    get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)
+                )
+
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+
+        post_call(prev_device)
+
+        code = get_4bit_type(quant_type, device=A.device)
+
+        if compress_statistics:
+            offset = absmax.mean()
+            absmax -= offset
+            qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
+            del absmax
+            state = QuantState(
+                absmax=qabsmax,
+                shape=input_shape,
+                dtype=A.dtype,
+                blocksize=blocksize,
+                code=code,
+                quant_type=quant_type,
+                offset=offset,
+                state2=state2,
+            )
+
+        else:
+            state = QuantState(
+                absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type
+            )
+
+        return out, state
+
+    def dequantize_4bit(
+        self,
+        A: torch.Tensor,
+        quant_state: Optional[QuantState] = None,
+        absmax: Optional[torch.Tensor] = None,
+        out: Optional[torch.Tensor] = None,
+        blocksize: int = 64,
+        quant_type="fp4",
+    ) -> torch.Tensor:
+        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+            raise ValueError(
+                f"The blocksize of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
+            )
+
+        if quant_type not in ["fp4", "nf4"]:
+            raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
+
+        if quant_state is None:
+            assert absmax is not None and out is not None
+
+            quant_state = QuantState(
+                absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type
+            )
+        else:
+            absmax = quant_state.absmax
+
+        if quant_state.nested:
+            absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+            absmax += quant_state.offset
+            if absmax.dtype != torch.float32:
+                absmax = absmax.float()
+
+        if out is None:
+            out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
+
+        n = out.numel()
+
+        prev_device = pre_call(A.device)
+        is_on_gpu([A, absmax, out])
+
+        if out.dtype == torch.float32:
+            if quant_state.quant_type == "fp4":
+                lib.cdequantize_blockwise_fp32_fp4(
+                    get_ptr(None),
+                    get_ptr(A),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int(quant_state.blocksize),
+                    ct.c_int(n),
+                )
+            else:
+                lib.cdequantize_blockwise_fp32_nf4(
+                    get_ptr(None),
+                    get_ptr(A),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int(quant_state.blocksize),
+                    ct.c_int(n),
+                )
+
+        elif out.dtype == torch.float16:
+            if quant_state.quant_type == "fp4":
+                lib.cdequantize_blockwise_fp16_fp4(
+                    get_ptr(None),
+                    get_ptr(A),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int(quant_state.blocksize),
+                    ct.c_int(n),
+                )
+            else:
+                lib.cdequantize_blockwise_fp16_nf4(
+                    get_ptr(None),
+                    get_ptr(A),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int(quant_state.blocksize),
+                    ct.c_int(n),
+                )
+
+        elif out.dtype == torch.bfloat16:
+            if quant_state.quant_type == "fp4":
+                lib.cdequantize_blockwise_bf16_fp4(
+                    get_ptr(None),
+                    get_ptr(A),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int(quant_state.blocksize),
+                    ct.c_int(n),
+                )
+            else:
+                lib.cdequantize_blockwise_bf16_nf4(
+                    get_ptr(None),
+                    get_ptr(A),
+                    get_ptr(absmax),
+                    get_ptr(out),
+                    ct.c_int(quant_state.blocksize),
+                    ct.c_int(n),
+                )
+
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+
+        post_call(prev_device)
+
+        is_transposed = A.shape[0] == 1
+
+        if is_transposed:
+            return out.t()
+        else:
+            return out
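Reviewer note (editorial, not part of the patch): end-to-end, nothing changes for callers; the public functions in bitsandbytes.functional now route to this class. A quick round-trip sketch, assuming a CUDA build of the library and an available CUDA device:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
    packed, state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")  # dispatches to CUDABackend.quantize_4bit
    A_hat = F.dequantize_4bit(packed, quant_state=state)                # dispatches to CUDABackend.dequantize_4bit
    assert A_hat.shape == A.shape and A_hat.dtype == A.dtype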
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index bb6a04892..6bb02944d 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -6,13 +6,14 @@
 from functools import reduce  # Required in Python 3
 import itertools
 import operator
-from typing import Any, Dict, Optional, Tuple
+from typing import Optional, Tuple
 
 import numpy as np
 import torch
 from torch import Tensor
 
-from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
+from bitsandbytes.backends import backends, ensure_backend_is_available
+from bitsandbytes.utils import QuantState
 
 from .cextension import lib
 
@@ -618,180 +619,8 @@
     return out
 
 
-class QuantState:
-    """container for quantization state components to work with Params4bit and similar classes"""
-
-    valid_quant_types = ("fp4", "nf4")
-    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
-    valid_qs_keys = [
-        "absmax",
-        "quant_map",
-        "nested_absmax",
-        "nested_quant_map",
-        "quant_state",
-        "quant_type",
-        "blocksize",
-        "dtype",
-        "shape",
-        "nested_blocksize",
-        "nested_dtype",
-        "nested_offset",
-    ]
-
-    def __init__(
-        self,
-        absmax,
-        shape=None,
-        code=None,
-        blocksize=None,
-        quant_type=None,
-        dtype=None,
-        offset=None,
-        state2=None,
-    ):
-        self.absmax = absmax
-        self.shape = shape
-        self.code = code
-        self.dtype = dtype
-        self.blocksize = blocksize
-        self.quant_type = quant_type
-        self.offset = offset
-        self.state2 = state2
-        self.nested = state2 is not None
-
-    def __get_item__(self, idx):
-        """
-        ensures compatibility with older quant state scheme with nested lists.
-        assumes the following layout:
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
-        state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
-        """
-        if self.nested:
-            list_repr = [
-                self.absmax,
-                self.shape,
-                self.dtype,
-                self.blocksize,
-                [self.offset, self.state2],
-                self.quant_type,
-            ]
-        else:
-            list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
-        return list_repr[idx]
-
-    @classmethod
-    def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState":
-        """
-        unpacks components of state_dict into QuantState
-        where necessary, convert into strings, torch.dtype, ints, etc.
-
-        qs_dict: based on state_dict, with only relevant keys, striped of prefixes.
-
-        item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
-        """
-
-        # unpacking tensor with non-tensor components
-        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
-        if not len(qs_key) and "quant_type" not in qs_dict:
-            raise ValueError("Expected packed or unpacked quant_state items, found neither")
-        elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
-            raise ValueError(
-                f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
-            )
-
-        # unpacking minor and non-tensor quant state items if necessary
-        if len(qs_key) == 1:
-            first_qs_key = qs_key[0]
-            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))
-
-        qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()}  # strip prefixes
-        assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
-
-        if "nested_absmax" in qs_dict:
-            offset = torch.tensor(float(qs_dict["nested_offset"])).to(device)
-            state2 = cls(
-                absmax=qs_dict["nested_absmax"].to(device),
-                blocksize=qs_dict["nested_blocksize"],
-                code=qs_dict["nested_quant_map"].to(device),
-                dtype=getattr(torch, qs_dict["nested_dtype"]),
-            )
-        else:
-            offset, state2 = None, None
-
-        quant_state = cls(
-            quant_type=qs_dict["quant_type"],
-            absmax=qs_dict["absmax"].to(device),
-            blocksize=qs_dict["blocksize"],
-            code=qs_dict["quant_map"].to(device),
-            dtype=getattr(torch, qs_dict["dtype"]),
-            shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None,
-            offset=offset,
-            state2=state2,
-        )
-        return quant_state
-
-    def as_dict(self, packed=False):
-        """
-        returns dict of tensors and strings to use in serialization via _save_to_state_dict()
-        param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
-        """
-        qs_dict = {
-            "quant_type": self.quant_type,
-            "absmax": self.absmax,
-            "blocksize": self.blocksize,
-            "quant_map": self.code,
-            "dtype": str(self.dtype).strip("torch."),
-            "shape": tuple(self.shape),
-        }
-        if self.nested:
-            qs_dict.update(
-                {
-                    "nested_absmax": self.state2.absmax,
-                    "nested_blocksize": self.state2.blocksize,
-                    "nested_quant_map": self.state2.code.clone(),  # un-shared to avoid restoring it after shared tensors are removed by safetensors
-                    "nested_dtype": str(self.state2.dtype).strip("torch."),
-                    "nested_offset": self.offset.item(),
-                },
-            )
-        if not packed:
-            return qs_dict
-
-        # packed format allows serialization of non-tensor components, critical for saving in safetensors format
-        qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
-        non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
-        qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
-        return qs_packed_dict
-
-    def to(self, device):
-        # make sure the quantization state is on the right device
-        self.absmax = self.absmax.to(device)
-        if self.nested:
-            self.offset = self.offset.to(device)
-            self.state2.absmax = self.state2.absmax.to(device)
-            self.state2.code = self.state2.code.to(device)
-
-    def __eq__(self, other):
-        if not isinstance(other, QuantState):
-            return False
-
-        return (
-            torch.allclose(self.absmax, other.absmax, atol=1e-6)
-            and self.shape == other.shape
-            and torch.allclose(self.code, other.code, atol=1e-6)
-            and self.dtype == other.dtype
-            and self.blocksize == other.blocksize
-            and self.quant_type == other.quant_type
-            and (
-                self.offset == other.offset
-                if self.offset is not None and other.offset is not None
-                else self.offset is other.offset
-            )
-            and (
-                self.state2 == other.state2
-                if self.state2 is not None and other.state2 is not None
-                else self.state2 is other.state2
-            )
-        )
+# Keep the name exported here for backwards compatibility as F.QuantState
+QuantState = QuantState
 
 
 def quantize_blockwise(
@@ -1150,117 +979,16 @@
     tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
     """
-    if A.device.type != "cuda":
-        raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
-    if quant_type not in ["fp4", "nf4"]:
-        raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
-
-    n = A.numel()
-    input_shape = A.shape
-
-    if absmax is None:
-        blocks = n // blocksize
-        blocks += 1 if n % blocksize > 0 else 0
-        absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
-
-    if out is None:
-        mod = dtype2bytes[quant_storage] * 2
-        out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)
-
-    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
-
-    prev_device = pre_call(A.device)
-    is_on_gpu([A, out, absmax])
-
-    if A.dtype == torch.float32:
-        if quant_type == "fp4":
-            lib.cquantize_blockwise_fp32_fp4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int32(blocksize),
-                ct.c_int(n),
-            )
-        else:
-            lib.cquantize_blockwise_fp32_nf4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int32(blocksize),
-                ct.c_int(n),
-            )
-    elif A.dtype == torch.float16:
-        if quant_type == "fp4":
-            lib.cquantize_blockwise_fp16_fp4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int32(blocksize),
-                ct.c_int(n),
-            )
-        else:
-            lib.cquantize_blockwise_fp16_nf4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int32(blocksize),
-                ct.c_int(n),
-            )
-    elif A.dtype == torch.bfloat16:
-        if quant_type == "fp4":
-            lib.cquantize_blockwise_bf16_fp4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int32(blocksize),
-                ct.c_int(n),
-            )
-        else:
-            lib.cquantize_blockwise_bf16_nf4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int32(blocksize),
-                ct.c_int(n),
-            )
-    else:
-        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
-    post_call(A.device)
-
-    code = get_4bit_type(quant_type, device=A.device)
-
-    if compress_statistics:
-        offset = absmax.mean()
-        absmax -= offset
-        qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
-        del absmax
-        state = QuantState(
-            absmax=qabsmax,
-            shape=input_shape,
-            dtype=A.dtype,
-            blocksize=blocksize,
-            code=code,
-            quant_type=quant_type,
-            offset=offset,
-            state2=state2,
-        )
-    else:
-        state = QuantState(
-            absmax=absmax,
-            shape=input_shape,
-            dtype=A.dtype,
-            blocksize=blocksize,
-            code=code,
-            quant_type=quant_type,
-        )
-
-    return out, state
+    ensure_backend_is_available(A.device.type)
+    return backends[A.device.type].quantize_4bit(
+        A,
+        absmax=absmax,
+        out=out,
+        blocksize=blocksize,
+        compress_statistics=compress_statistics,
+        quant_type=quant_type,
+        quant_storage=quant_storage,
+    )
 
 
 def dequantize_fp4(
@@ -1317,106 +1045,10 @@
     torch.Tensor:
         Dequantized tensor.
     """
-    if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
-        raise ValueError(
-            f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
-        )
-    if quant_type not in ["fp4", "nf4"]:
-        raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
-
-    if quant_state is None:
-        assert absmax is not None and out is not None
-
-        quant_state = QuantState(
-            absmax=absmax,
-            shape=out.shape,
-            dtype=out.dtype,
-            blocksize=blocksize,
-            quant_type=quant_type,
-        )
-
-    else:
-        absmax = quant_state.absmax
-
-    if quant_state.nested:
-        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
-        absmax += quant_state.offset
-        if absmax.dtype != torch.float32:
-            absmax = absmax.float()
-
-    if out is None:
-        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
-
-    n = out.numel()
-
-    device = pre_call(A.device)
-    is_on_gpu([A, absmax, out])
-    if out.dtype == torch.float32:
-        if quant_state.quant_type == "fp4":
-            lib.cdequantize_blockwise_fp32_fp4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(quant_state.blocksize),
-                ct.c_int(n),
-            )
-        else:
-            lib.cdequantize_blockwise_fp32_nf4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(quant_state.blocksize),
-                ct.c_int(n),
-            )
-    elif out.dtype == torch.float16:
-        if quant_state.quant_type == "fp4":
-            lib.cdequantize_blockwise_fp16_fp4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(quant_state.blocksize),
-                ct.c_int(n),
-            )
-        else:
-            lib.cdequantize_blockwise_fp16_nf4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(quant_state.blocksize),
-                ct.c_int(n),
-            )
-    elif out.dtype == torch.bfloat16:
-        if quant_state.quant_type == "fp4":
-            lib.cdequantize_blockwise_bf16_fp4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(quant_state.blocksize),
-                ct.c_int(n),
-            )
-        else:
-            lib.cdequantize_blockwise_bf16_nf4(
-                get_ptr(None),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(quant_state.blocksize),
-                ct.c_int(n),
-            )
-    else:
-        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
-    post_call(A.device)
-
-    is_transposed = True if A.shape[0] == 1 else False
-    if is_transposed:
-        return out.t()
-    else:
-        return out
+    ensure_backend_is_available(A.device.type)
+    return backends[A.device.type].dequantize_4bit(
+        A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type
+    )
 
 
 def quantize(
@@ -2253,136 +1885,22 @@ def batched_igemm(
 
 
 def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
-    shapeA = SA[0]
-    shapeB = SB[0]
-    dimsA = len(shapeA)
-    dimsB = len(shapeB)
-    assert dimsB == 2, "Only two dimensional matrices are supported for argument B"
-    if dimsA == 2:
-        m = shapeA[0]
-    elif dimsA == 3:
-        m = shapeA[0] * shapeA[1]
-
-    rows = n = shapeB[0]
-    assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}"
-
-    # if the tensor is empty, return a transformed empty tensor with the right dimensions
-    if shapeA[0] == 0 and dimsA == 2:
-        return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
-    elif shapeA[1] == 0 and dimsA == 3:
-        return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
-
-    if dimsA == 2 and out is None:
-        out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
-    elif dimsA == 3 and out is None:
-        out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")
-
-    assert dimsB != 3, "len(B.shape)==3 not supported"
-    assert A.device.type == "cuda"
-    assert B.device.type == "cuda"
-    assert A.dtype == torch.int8
-    assert B.dtype == torch.int8
-    assert out.dtype == dtype
-    assert SA[1] == "col32"
-    assert SB[1] in ["col_turing", "col_ampere"]
-    assert Sout[1] == "col32"
-    assert (
-        shapeA[-1] == shapeB[-1]
-    ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
-    formatB = SB[1]
-    prev_device = A.device
-    torch.cuda.set_device(A.device)
-
-    ptr = CUBLAS_Context.get_instance().get_context(A.device)
-    ptrA = get_ptr(A)
-    ptrB = get_ptr(B)
-    ptrC = get_ptr(out)
-
-    k = shapeA[-1]
-    lda = ct.c_int32(m * 32)
-    if formatB == "col_turing":
-        # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
-        # n = rows
-        ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
-    else:
-        # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
-        # n = rows
-        ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
-
-    ldc = ct.c_int32(m * 32)
-    m = ct.c_int32(m)
-    n = ct.c_int32(n)
-    k = ct.c_int32(k)
-
-    has_error = 0
-    ptrRowScale = get_ptr(None)
-    is_on_gpu([A, B, out])
-    if formatB == "col_turing":
-        if dtype == torch.int32:
-            has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
-        else:
-            has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
-    elif formatB == "col_ampere":
-        if dtype == torch.int32:
-            has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
-        else:
-            has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
-
-    if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
-        raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")
-
-    if has_error:
-        print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}")
-        raise Exception("cublasLt ran into an error!")
-
-    torch.cuda.set_device(prev_device)
-
-    return out, Sout
+    ensure_backend_is_available(A.device.type)
+    return backends[A.device.type].igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype)
 
 
 def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
-    assert A.dtype == torch.int32
-    if bias is not None:
-        assert bias.dtype == torch.float16
-    out_shape = quant_state[0]
-    if len(out_shape) == 3:
-        out_shape = (out_shape[0] * out_shape[1], out_shape[2])
-
-    if out is None:
-        out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
-    if new_row_stats is None:
-        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
-    if new_col_stats is None:
-        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
-    assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
-    assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
-
-    prev_device = pre_call(A.device)
-    ptrA = get_ptr(A)
-    ptrOut = get_ptr(out)
-    ptrRowStats = get_ptr(row_stats)
-    ptrColStats = get_ptr(col_stats)
-    ptrNewRowStats = get_ptr(new_row_stats)
-    ptrNewColStats = get_ptr(new_col_stats)
-    ptrBias = get_ptr(bias)
-    numRows = ct.c_int32(out_shape[0])
-    numCols = ct.c_int32(out_shape[1])
-
-    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
-    lib.cdequant_mm_int32_fp16(
-        ptrA,
-        ptrRowStats,
-        ptrColStats,
-        ptrOut,
-        ptrNewRowStats,
-        ptrNewColStats,
-        ptrBias,
-        numRows,
-        numCols,
+    ensure_backend_is_available(A.device.type)
+    return backends[A.device.type].mm_dequant(
+        A,
+        quant_state,
+        row_stats,
+        col_stats,
+        out=out,
+        new_row_stats=new_row_stats,
+        new_col_stats=new_col_stats,
+        bias=bias,
     )
-    post_call(prev_device)
-
-    return out
 
 
 def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
@@ -2503,141 +2021,17 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
 
 
 def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
-    device = A.device
-    assert A.dtype == torch.half
-    assert device.type == "cuda"
-    prev_device = pre_call(A.device)
-
-    cols = A.shape[-1]
-    if len(A.shape) == 3:
-        rows = A.shape[0] * A.shape[1]
-    else:
-        rows = A.shape[0]
-
-    if row_stats is None or col_stats is None:
-        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
-
-    if out_col is None:
-        out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
-    if out_row is None:
-        out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
-
-    coo_tensor = None
-    ptrA = get_ptr(A)
-    ptrColStats = get_ptr(col_stats)
-    ptrRowStats = get_ptr(row_stats)
-    ptrOutCol = get_ptr(out_col)
-    ptrOutRow = get_ptr(out_row)
-
-    is_on_gpu([A, col_stats, row_stats, out_col, out_row])
-    if threshold > 0.0:
-        nnz = nnz_row_ptr[-1].item()
-        if nnz > 0:
-            coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
-            ptrRowIdx = get_ptr(coo_tensor.rowidx)
-            ptrColIdx = get_ptr(coo_tensor.colidx)
-            ptrVal = get_ptr(coo_tensor.values)
-            ptrRowPtr = get_ptr(nnz_row_ptr)
-
-            lib.cdouble_rowcol_quant(
-                ptrA,
-                ptrRowStats,
-                ptrColStats,
-                ptrOutCol,
-                ptrOutRow,
-                ptrRowIdx,
-                ptrColIdx,
-                ptrVal,
-                ptrRowPtr,
-                ct.c_float(threshold),
-                ct.c_int32(rows),
-                ct.c_int32(cols),
-            )
-            val, idx = torch.sort(coo_tensor.rowidx)
-            coo_tensor.rowidx = val
-            coo_tensor.colidx = coo_tensor.colidx[idx]
-            coo_tensor.values = coo_tensor.values[idx]
-        else:
-            lib.cdouble_rowcol_quant(
-                ptrA,
-                ptrRowStats,
-                ptrColStats,
-                ptrOutCol,
-                ptrOutRow,
-                None,
-                None,
-                None,
-                None,
-                ct.c_float(0.0),
-                ct.c_int32(rows),
-                ct.c_int32(cols),
-            )
-    else:
-        lib.cdouble_rowcol_quant(
-            ptrA,
-            ptrRowStats,
-            ptrColStats,
-            ptrOutCol,
-            ptrOutRow,
-            None,
-            None,
-            None,
-            None,
-            ct.c_float(threshold),
-            ct.c_int32(rows),
-            ct.c_int32(cols),
-        )
-    post_call(prev_device)
-
-    return out_row, out_col, row_stats, col_stats, coo_tensor
+    ensure_backend_is_available(A.device.type)
+    return backends[A.device.type].double_quant(
+        A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold
+    )
 
 
 def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
-    prev_device = pre_call(A.device)
-    if state is None:
-        state = (A.shape, from_order)
-    else:
-        from_order = state[1]
-    if out is None:
-        out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
-    else:
-        new_state = (state[0], to_order)  # (shape, order)
-
-    shape = state[0]
-    if len(shape) == 2:
-        dim1 = ct.c_int32(shape[0])
-        dim2 = ct.c_int32(shape[1])
-    else:
-        dim1 = ct.c_int32(shape[0] * shape[1])
-        dim2 = ct.c_int32(shape[2])
-
-    is_on_gpu([A, out])
-    if to_order == "col32":
-        if transpose:
-            lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
-        else:
-            lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
-    elif to_order == "col_turing":
-        if transpose:
-            lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
-        else:
-            lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
-    elif to_order == "col_ampere":
-        if transpose:
-            lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
-        else:
-            lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
-    elif to_order == "row":
-        if from_order == "col_turing":
-            lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
-        elif from_order == "col_ampere":
-            lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
-    else:
-        raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")
-
-    post_call(prev_device)
-
-    return out, new_state
+    ensure_backend_is_available(A.device.type)
+    return backends[A.device.type].transform(
+        A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld
+    )
 
 
 def spmm_coo(cooA, B, out=None):
@@ -2899,28 +2293,8 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
 
 
 def extract_outliers(A, SA, idx):
-    shapeA = SA[0]
-    formatA = SA[1]
-    assert formatA in ["col_turing", "col_ampere"]
-    assert A.device.type == "cuda"
-
-    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
-
-    idx_size = ct.c_int32(idx.numel())
-    rows = ct.c_int32(shapeA[0])
-    cols = ct.c_int32(shapeA[1])
-    ptrA = get_ptr(A)
-    ptrIdx = get_ptr(idx)
-    ptrOut = get_ptr(out)
-
-    prev_device = pre_call(A.device)
-    if formatA == "col_turing":
-        lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
-    elif formatA == "col_ampere":
-        lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
-    post_call(prev_device)
-
-    return out
+    ensure_backend_is_available(A.device.type)
+    return backends[A.device.type].extract_outliers(A, SA, idx)
 
 
 def pipeline_test(A, batch_size):
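Reviewer note (editorial, not part of the patch): bitsandbytes.functional keeps its public signatures and only swaps the bodies for backend dispatch, so existing call sites need no changes. A sketch, assuming a CUDA device:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(128, 128, dtype=torch.float16, device="cuda")
    out_row, out_col, row_stats, col_stats, coo = F.double_quant(A, threshold=6.0)  # routed to CUDABackend
    # F.QuantState also remains importable from here; it is re-exported from bitsandbytes.utils.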
"quant_state", + "quant_type", + "blocksize", + "dtype", + "shape", + "nested_blocksize", + "nested_dtype", + "nested_offset", + ] + + def __init__( + self, + absmax, + shape=None, + code=None, + blocksize=None, + quant_type=None, + dtype=None, + offset=None, + state2=None, + ): + self.absmax = absmax + self.shape = shape + self.code = code + self.dtype = dtype + self.blocksize = blocksize + self.quant_type = quant_type + self.offset = offset + self.state2 = state2 + self.nested = state2 is not None + + def __get_item__(self, idx): + """ + ensures compatibility with older quant state scheme with nested lists. + assumes the following layout: + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + """ + if self.nested: + list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + [self.offset, self.state2], + self.quant_type, + ] + else: + list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] + return list_repr[idx] + + @classmethod + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": + """ + unpacks components of state_dict into QuantState + where necessary, convert into strings, torch.dtype, ints, etc. + + qs_dict: based on state_dict, with only relevant keys, striped of prefixes. + + item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. + """ + + # unpacking tensor with non-tensor components + qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] + if not len(qs_key) and "quant_type" not in qs_dict: + raise ValueError("Expected packed or unpacked quant_state items, found neither") + elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", + ) + + # unpacking minor and non-tensor quant state items if necessary + if len(qs_key) == 1: + first_qs_key = qs_key[0] + qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) + + qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes + assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) + + if "nested_absmax" in qs_dict: + offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) + state2 = cls( + absmax=qs_dict["nested_absmax"].to(device), + blocksize=qs_dict["nested_blocksize"], + code=qs_dict["nested_quant_map"].to(device), + dtype=getattr(torch, qs_dict["nested_dtype"]), + ) + else: + offset, state2 = None, None + + quant_state = cls( + quant_type=qs_dict["quant_type"], + absmax=qs_dict["absmax"].to(device), + blocksize=qs_dict["blocksize"], + code=qs_dict["quant_map"].to(device), + dtype=getattr(torch, qs_dict["dtype"]), + shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, + offset=offset, + state2=state2, + ) + return quant_state + + def as_dict(self, packed=False): + """ + returns dict of tensors and strings to use in serialization via _save_to_state_dict() + param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving + """ + qs_dict = { + "quant_type": self.quant_type, + "absmax": self.absmax, + "blocksize": self.blocksize, + "quant_map": self.code, + "dtype": str(self.dtype).strip("torch."), + "shape": tuple(self.shape), + } + if self.nested: + qs_dict.update( + { + "nested_absmax": self.state2.absmax, + 
"nested_blocksize": self.state2.blocksize, + "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + "nested_dtype": str(self.state2.dtype).strip("torch."), + "nested_offset": self.offset.item(), + }, + ) + if not packed: + return qs_dict + + # packed format allows serialization of non-tensor components, critical for saving in safetensors format + qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} + non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} + qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) + return qs_packed_dict + + def to(self, device): + # make sure the quantization state is on the right device + self.absmax = self.absmax.to(device) + if self.nested: + self.offset = self.offset.to(device) + self.state2.absmax = self.state2.absmax.to(device) + self.state2.code = self.state2.code.to(device) + + def __eq__(self, other): + if not isinstance(other, QuantState): + return False + + return ( + torch.allclose(self.absmax, other.absmax, atol=1e-6) + and self.shape == other.shape + and torch.allclose(self.code, other.code, atol=1e-6) + and self.dtype == other.dtype + and self.blocksize == other.blocksize + and self.quant_type == other.quant_type + and ( + self.offset == other.offset + if self.offset is not None and other.offset is not None + else self.offset is other.offset + ) + and ( + self.state2 == other.state2 + if self.state2 is not None and other.state2 is not None + else self.state2 is other.state2 + ) + ) diff --git a/install_cuda.py b/install_cuda.py index a5d09356d..cf7c8ee71 100644 --- a/install_cuda.py +++ b/install_cuda.py @@ -77,9 +77,7 @@ def main(): download_path = "/tmp" # default download path if len(sys.argv) < 2: - print( - "Usage: python install_cuda.py [user/system] [download_path]" - ) + print("Usage: python install_cuda.py [user/system] [download_path]") sys.exit(1) version = sys.argv[1] @@ -100,9 +98,7 @@ def main(): elif version in cuda_versions: install_cuda(version, base_path, download_path) else: - print( - f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}" - ) + print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}") sys.exit(1)