Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -4,167 +4,6 @@
import triton.language as tl


# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
    a_ptr,
    c_ptr,
    quant_ptr,
    absmax_ptr,
    num_paired_elements,
    QUANT_BLOCK: tl.constexpr,
    SPLIT_SIZE: tl.constexpr,
):
    """Dequantize one SPLIT_SIZE-wide slice of blockwise-quantized 8-bit data.

    Each program instance handles the element range
    ``[pid * SPLIT_SIZE, (pid + 1) * SPLIT_SIZE)`` clipped to
    ``num_paired_elements``.

    Args:
        a_ptr: Pointer to the quantized input; every element is cast to
            ``uint8`` and used as an index into the table at ``quant_ptr``.
        c_ptr: Pointer to the output buffer, written at the same offsets.
        quant_ptr: Pointer to the lookup table indexed by the input codes.
        absmax_ptr: Pointer to the per-block scales; element ``i`` uses entry
            ``i // QUANT_BLOCK``.
        num_paired_elements: Number of elements to process.
        QUANT_BLOCK: Number of consecutive elements sharing one scale.
        SPLIT_SIZE: Number of elements handled by one program instance.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * SPLIT_SIZE + tl.arange(0, SPLIT_SIZE)
    # One mask serves every load and the store: the previous scale limit
    # (n // QUANT_BLOCK) * QUANT_BLOCK + n % QUANT_BLOCK is identically n.
    mask = offsets < num_paired_elements

    # `other=0` gives masked-off lanes a defined table index instead of an
    # unspecified value.
    codes = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8)

    # Apply conversion: look up the table value for each code.
    levels = tl.load(quant_ptr + codes, mask)

    # Load the scale of the quantization block each element belongs to.
    absmax = tl.load(absmax_ptr + offsets // QUANT_BLOCK, mask, other=0.0)

    # Apply scales and write back only the in-range lanes.
    tl.store(c_ptr + offsets, levels * absmax, mask)


def dequant_int8_blockwise(
    A_nf4: torch.Tensor,
    quant_state_code: torch.Tensor,
    absmax: torch.Tensor,
    out: torch.Tensor,
    quant_blocksize: int = 64,
):
    """Launch ``dequant_8bit_kernel`` over every element of ``A_nf4``.

    Args:
        A_nf4: Quantized input tensor; its ``numel()`` sets the launch size.
        quant_state_code: Lookup table handed to the kernel.
        absmax: Per-block scales handed to the kernel.
        out: Destination tensor the kernel writes into.
        quant_blocksize: Number of elements sharing one ``absmax`` entry.

    Returns:
        The ``out`` tensor that was passed in.
    """
    # Fixed tile width: each program instance covers this many elements.
    split_size = 256
    total = A_nf4.numel()
    launch_grid = (triton.cdiv(total, split_size),)

    dequant_8bit_kernel[launch_grid](
        A_nf4,
        out,
        quant_state_code,
        absmax,
        total,
        quant_blocksize,
        split_size,
    )
    return out


# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_blockwise_kernel(
    A_ptr,
    code_ptr,
    absmax_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    CODE_SIZE: tl.constexpr,
    SPLIT_NUM_BLOCKS: tl.constexpr,
):
    """Blockwise-quantize ``SPLIT_NUM_BLOCKS`` consecutive blocks to 8-bit codes.

    For every block of ``BLOCK_SIZE`` elements the kernel stores the block's
    absolute maximum, normalizes the block by it, and maps each normalized
    value to the index of the closest of two neighbouring table entries found
    by bisection.

    Args:
        A_ptr: Pointer to the input values.
        code_ptr: Pointer to the table with ``CODE_SIZE`` entries. The
            bisection presumably relies on it being sorted ascending
            (TODO confirm against the caller).
        absmax_ptr: Pointer to the per-block absolute maxima (output).
        out_ptr: Pointer to the ``uint8`` codes (output).
        n_elements: Number of valid input elements.
        BLOCK_SIZE: Elements per quantization block.
        CODE_SIZE: Number of entries in the table.
        SPLIT_NUM_BLOCKS: Quantization blocks handled per program instance.
    """
    block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
    thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)

    offsets = block_start_idx * BLOCK_SIZE + thread_idx
    mask = offsets < n_elements

    A = tl.load(A_ptr + offsets, mask=mask, other=0.0)

    # To be able to process several blocks -> (SPLIT_NUM_BLOCKS, BLOCK_SIZE)
    A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE))

    # Calculating absmax for each block
    absmax = tl.max(tl.abs(A_reshaped), axis=1)
    # Only blocks that contain at least one valid element own an absmax slot.
    # Without this mask the last program writes past the end of the buffer
    # when the block count is not a multiple of SPLIT_NUM_BLOCKS.
    block_ids = block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS)
    block_mask = block_ids * BLOCK_SIZE < n_elements
    tl.store(absmax_ptr + block_ids, absmax, mask=block_mask)

    A_normalized = A_reshaped / absmax[:, None]
    A_normalized = tl.clamp(A_normalized, -1.0, 1.0)

    lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
    upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)

    # Eight halvings narrow [0, CODE_SIZE - 1] down to adjacent indices when
    # CODE_SIZE <= 256; in the general case ceil(log2(CODE_SIZE)) should be an
    # input parameter.
    for _ in range(8):
        pivot = (lower_pivot + upper_pivot) // 2
        val = tl.load(code_ptr + pivot)
        is_higher = A_normalized > val  # code[pivot]
        lower_pivot = tl.where(is_higher, pivot, lower_pivot)
        upper_pivot = tl.where(is_higher, upper_pivot, pivot)

    # Choose closest level
    lower_val = tl.load(code_ptr + lower_pivot)
    upper_val = tl.load(code_ptr + upper_pivot)
    lower_dist = tl.abs(A_normalized - lower_val)
    upper_dist = tl.abs(A_normalized - upper_val)
    quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)

    # too slow approach
    # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
    # quantized = tl.argmin(diff, axis=2).to(tl.uint8)

    quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
    tl.store(out_ptr + offsets, quantized_flat, mask=mask)


def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out):
    """Launch ``quantize_blockwise_kernel`` over ``blocks`` quantization blocks.

    Args:
        A: Input tensor; its ``numel()`` is passed as the element count.
        blocksize: Elements per quantization block.
        code: Table tensor; its ``numel()`` is passed as ``CODE_SIZE``.
        blocks: Number of quantization blocks, used to size the launch grid.
        absmax: Tensor receiving the per-block absolute maxima.
        quantized_out: Tensor receiving the 8-bit codes.

    Returns:
        The tuple ``(quantized_out, absmax)``.
    """
    total = A.numel()

    # One quantization block per program instance.
    split_num_blocks = 1
    launch_grid = (triton.cdiv(blocks, split_num_blocks),)

    launcher = quantize_blockwise_kernel[launch_grid]
    launcher(
        A_ptr=A,
        code_ptr=code,
        absmax_ptr=absmax,
        out_ptr=quantized_out,
        n_elements=total,
        BLOCK_SIZE=blocksize,
        CODE_SIZE=code.numel(),
        SPLIT_NUM_BLOCKS=split_num_blocks,
    )

    return quantized_out, absmax


# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4
# @triton.autotune(
# configs=[
Expand Down Expand Up @@ -587,7 +426,7 @@ def dequant_nf4_kernel(
tl.store(c_ptr + offs, out_dq, mask)


def _dequantize_4bit_impl(
def dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
Expand All @@ -611,7 +450,7 @@ def _dequantize_4bit_impl(
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)


def _dequantize_4bit_impl_passing_code(
def dequantize_4bit_impl_passing_code(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
Expand Down
195 changes: 195 additions & 0 deletions bitsandbytes/backends/triton/kernels_8bit_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import torch

import triton
import triton.language as tl


# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
    a_ptr,
    out_ptr,
    code_ptr,
    absmax_ptr,
    n,
    QUANT_BLOCK: tl.constexpr,
    SPLIT_SIZE: tl.constexpr,
):
    """Dequantize one SPLIT_SIZE-wide slice of 8-bit blockwise data.

    The program instance computes its element offsets, delegates the lookup
    and scaling to ``dequant_8bit_blockwise_kernel_util`` and stores the
    result for the offsets below ``n``.

    Args:
        a_ptr: Pointer to the quantized input.
        out_ptr: Pointer to the output buffer.
        code_ptr: Pointer to the lookup table.
        absmax_ptr: Pointer to the per-block scales.
        n: Number of elements to process.
        QUANT_BLOCK: Elements sharing one scale.
        SPLIT_SIZE: Elements handled per program instance.
    """
    first = tl.program_id(axis=0) * SPLIT_SIZE
    offsets = first + tl.arange(0, SPLIT_SIZE)
    in_bounds = offsets < n

    values = dequant_8bit_blockwise_kernel_util(
        a_ptr, offsets, code_ptr, absmax_ptr, in_bounds, QUANT_BLOCK
    )
    tl.store(out_ptr + offsets, values, in_bounds)


def dequant_8bit_blockwise(
    a: torch.Tensor,
    absmax: torch.Tensor,
    quant_state_code: torch.Tensor,
    quant_blocksize: int = 64,
    dtype: torch.dtype = None,
    out: torch.Tensor = None,
):
    """Dequantize ``a`` with ``dequant_8bit_kernel`` and return the result.

    Args:
        a: Quantized input tensor.
        absmax: Per-block scales handed to the kernel.
        quant_state_code: Lookup table handed to the kernel.
        quant_blocksize: Number of elements sharing one ``absmax`` entry.
        dtype: Element type of the output allocated when ``out`` is omitted.
        out: Optional destination tensor; when given it is written in place.

    Returns:
        The output tensor (``out`` if supplied, otherwise a new tensor shaped
        like ``a``).

    Raises:
        ValueError: If both ``out`` and ``dtype`` are ``None``.
    """
    total = a.numel()

    if out is None and dtype is None:
        raise ValueError("If out is None, dtype must be specified")
    if out is None:
        out = torch.empty_like(a, dtype=dtype, device=a.device)

    # Fixed tile width: each program instance covers this many elements.
    split_size = 256
    launch_grid = (triton.cdiv(total, split_size),)
    dequant_8bit_kernel[launch_grid](
        a,
        out,
        quant_state_code,
        absmax,
        total,
        quant_blocksize,
        split_size,
    )
    return out


# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_8bit_blockwise_kernel(
    A_ptr,
    code_ptr,
    absmax_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    CODE_SIZE: tl.constexpr,
    SPLIT_NUM_BLOCKS: tl.constexpr,
):
    """Blockwise-quantize ``SPLIT_NUM_BLOCKS`` consecutive blocks to 8-bit codes.

    Loads the program's slice of the input, delegates the per-block absmax
    and code computation to ``quantize_8bit_blockwise_kernel_util`` and
    stores both results.

    Args:
        A_ptr: Pointer to the input values.
        code_ptr: Pointer to the table with ``CODE_SIZE`` entries.
        absmax_ptr: Pointer to the per-block absolute maxima (output).
        out_ptr: Pointer to the quantized codes (output).
        n_elements: Number of valid input elements.
        BLOCK_SIZE: Elements per quantization block.
        CODE_SIZE: Number of entries in the table.
        SPLIT_NUM_BLOCKS: Quantization blocks handled per program instance.
    """
    block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
    thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)

    offsets = block_start_idx * BLOCK_SIZE + thread_idx
    mask = offsets < n_elements

    # Out-of-range lanes read as 0.0 so they cannot raise a block's absmax.
    A = tl.load(A_ptr + offsets, mask=mask, other=0.0)

    quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS)

    # Only blocks that contain at least one valid element own an absmax slot.
    # Without this mask the last program writes past the end of the buffer
    # when the block count is not a multiple of SPLIT_NUM_BLOCKS.
    block_ids = block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS)
    block_mask = block_ids * BLOCK_SIZE < n_elements
    tl.store(absmax_ptr + block_ids, absmax, mask=block_mask)
    tl.store(out_ptr + offsets, quantized, mask=mask)


def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
    """Blockwise-quantize ``A`` to 8-bit codes with the Triton kernel.

    Args:
        A: Input tensor.
        code: Table tensor; its ``numel()`` is passed as ``CODE_SIZE``.
        blocksize: Elements per quantization block.
        absmax: Optional tensor for the per-block maxima. When omitted, a
            tensor with one entry per block is allocated on ``A``'s device
            with ``A``'s dtype.
        out: Optional tensor for the codes. When omitted, a flat ``uint8``
            tensor with ``A.numel()`` entries is allocated.

    Returns:
        The tuple ``(out, absmax)`` with ``out`` reshaped to ``A.shape``.
    """
    n = A.numel()
    # Ceiling division: the trailing partial block still needs a scale.
    num_blocks = -(n // -blocksize)

    if absmax is None:
        absmax = torch.empty((num_blocks,), device=A.device, dtype=A.dtype)
    if out is None:
        out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    # One quantization block per program instance.
    split_num_blocks = 1
    launch_grid = (triton.cdiv(num_blocks, split_num_blocks),)

    kernel_args = dict(
        A_ptr=A,
        code_ptr=code,
        absmax_ptr=absmax,
        out_ptr=out,
        n_elements=n,
        BLOCK_SIZE=blocksize,
        CODE_SIZE=code.numel(),
        SPLIT_NUM_BLOCKS=split_num_blocks,
    )
    quantize_8bit_blockwise_kernel[launch_grid](**kernel_args)

    return out.reshape(A.shape), absmax


@triton.jit
def quantize_8bit_blockwise_kernel_util(
    a,
    code_ptr,
    CODE_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    N_PER_TH: tl.constexpr,
):
    """Quantize ``N_PER_TH`` blocks of ``BLOCK_SIZE`` values held in ``a``.

    Args:
        a: Flat block of ``N_PER_TH * BLOCK_SIZE`` already-loaded values.
        code_ptr: Pointer to the table with ``CODE_SIZE`` entries. The
            bisection presumably relies on it being sorted ascending
            (TODO confirm against the caller).
        CODE_SIZE: Number of entries in the table.
        BLOCK_SIZE: Elements per quantization block.
        N_PER_TH: Number of quantization blocks contained in ``a``.

    Returns:
        A tuple ``(codes, absmax)``: the flat ``uint8`` table indices and the
        per-block absolute maxima.
    """
    # One row per quantization block -> (N_PER_TH, BLOCK_SIZE)
    blocks = tl.reshape(a, (N_PER_TH, BLOCK_SIZE))

    # Per-block scale, then bring every value into [-1, 1].
    absmax = tl.max(tl.abs(blocks), axis=1)
    normalized = tl.clamp(blocks / absmax[:, None], -1.0, 1.0)

    # Bisection bounds over the table indices.
    lo = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32)
    hi = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)

    # Eight halvings narrow [0, CODE_SIZE - 1] down to adjacent indices when
    # CODE_SIZE <= 256; in the general case ceil(log2(CODE_SIZE)) should be an
    # input parameter.
    for _ in range(8):
        mid = (lo + hi) // 2
        above = normalized > tl.load(code_ptr + mid)  # code[mid]
        lo = tl.where(above, mid, lo)
        hi = tl.where(above, hi, mid)

    # Pick whichever bracketing level is closer; ties go to the lower index.
    dist_lo = tl.abs(normalized - tl.load(code_ptr + lo))
    dist_hi = tl.abs(normalized - tl.load(code_ptr + hi))
    codes = tl.where(dist_lo <= dist_hi, lo, hi).to(tl.uint8)

    # A brute-force argmin over |normalized - code| for all entries was tried
    # and found to be too slow.

    return tl.reshape(codes, (BLOCK_SIZE * N_PER_TH,)), absmax


@triton.jit
def dequant_8bit_blockwise_kernel_util(
    a_ptr,
    offsets,
    code_ptr,
    absmax_ptr,
    mask,
    BLOCK_SIZE: tl.constexpr,
):
    """Return the dequantized values for the elements at ``offsets``.

    Args:
        a_ptr: Pointer to the quantized input; values are cast to ``uint8``
            and used as indices into the table at ``code_ptr``.
        offsets: Element offsets handled by the calling program instance.
        code_ptr: Pointer to the lookup table.
        absmax_ptr: Pointer to the per-block scales; element ``i`` uses entry
            ``i // BLOCK_SIZE``.
        mask: Lanes that may be loaded; masked-off codes and scales read as 0.
        BLOCK_SIZE: Elements sharing one scale.

    Returns:
        The table value of each code multiplied by its block's scale.
    """
    codes = tl.load(a_ptr + offsets, mask, other=0)
    codes = codes.to(tl.uint8)
    levels = tl.load(code_ptr + codes, mask)

    # Every element of a quantization block shares the same scale entry.
    scale_idx = offsets // BLOCK_SIZE
    scales = tl.load(absmax_ptr + scale_idx, mask=mask, other=0.0, eviction_policy="evict_last")

    return levels * scales
Loading