[Draft][triton] Implemented 8bit optimizers on XPU #1692

Open · wants to merge 2 commits into main
@@ -4,167 +4,6 @@
import triton.language as tl


# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
    a_ptr,
    c_ptr,
    quant_ptr,
    absmax_ptr,
    num_paired_elements,
    QUANT_BLOCK: tl.constexpr,
    SPLIT_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * SPLIT_SIZE
    offsets = block_start + tl.arange(0, SPLIT_SIZE)
    mask = offsets < num_paired_elements

    a = tl.load(a_ptr + offsets, mask)
    a = a.to(tl.uint8)

    # apply conversion
    scaled_int8 = tl.load(quant_ptr + a, mask)

    abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK
    abs_offsets = offsets // QUANT_BLOCK
    mask_blocked = offsets < abs_blocks_lim

    absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked)
    # apply scales
    out_dq = scaled_int8 * absmax

    offs = block_start + tl.arange(0, SPLIT_SIZE)
    mask = offs < num_paired_elements
    tl.store(c_ptr + offs, out_dq, mask)


def dequant_int8_blockwise(
    A_nf4: torch.Tensor,
    quant_state_code: torch.Tensor,
    absmax: torch.Tensor,
    out: torch.Tensor,
    quant_blocksize: int = 64,
):
    number_of_paired_elements = A_nf4.numel()

    SPLIT_SIZE = 256
    # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
    grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
    dequant_8bit_kernel[grid](
        A_nf4,
        out,
        quant_state_code,
        absmax,
        number_of_paired_elements,
        quant_blocksize,
        SPLIT_SIZE,
    )
    return out
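
For context, a minimal host-side sketch of how dequant_int8_blockwise might be called; the "xpu" device string, the linear placeholder code table, and the tensor sizes are assumptions for illustration, not part of this diff:

import torch

# Hypothetical smoke test for dequant_int8_blockwise (defined above).
blocksize = 64
n = 1024
A_q = torch.randint(0, 256, (n,), dtype=torch.uint8, device="xpu")  # uint8 code indices
code = torch.linspace(-1.0, 1.0, 256, device="xpu")                 # placeholder 256-entry code table
absmax = torch.rand(n // blocksize, device="xpu") + 0.5             # one scale per quantization block
out = torch.empty(n, dtype=torch.float32, device="xpu")

dequant_int8_blockwise(A_q, code, absmax, out, quant_blocksize=blocksize)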


# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_blockwise_kernel(
    A_ptr,
    code_ptr,
    absmax_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    CODE_SIZE: tl.constexpr,
    SPLIT_NUM_BLOCKS: tl.constexpr,
):
    block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
    thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)

    offsets = block_start_idx * BLOCK_SIZE + thread_idx
    mask = offsets < n_elements

    A = tl.load(A_ptr + offsets, mask=mask, other=0.0)

    # To process several blocks at once, reshape to (SPLIT_NUM_BLOCKS, BLOCK_SIZE)
    A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE))

    # Calculate absmax for each block
    absmax = tl.max(tl.abs(A_reshaped), axis=1)
    tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)

    A_normalized = A_reshaped / absmax[:, None]
    A_normalized = tl.clamp(A_normalized, -1.0, 1.0)

    lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
    upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)

    for _ in range(8):  # ceil(log2(CODE_SIZE)) = 8 for a 256-entry code; in the general case this should be a parameter
        pivot = (lower_pivot + upper_pivot) // 2
        val = tl.load(code_ptr + pivot)
        is_higher = A_normalized > val  # code[pivot]
        lower_pivot = tl.where(is_higher, pivot, lower_pivot)
        upper_pivot = tl.where(is_higher, upper_pivot, pivot)

    # Choose the closest of the two remaining levels
    lower_val = tl.load(code_ptr + lower_pivot)
    upper_val = tl.load(code_ptr + upper_pivot)
    lower_dist = tl.abs(A_normalized - lower_val)
    upper_dist = tl.abs(A_normalized - upper_val)
    quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)

    # A brute-force alternative (too slow):
    # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
    # quantized = tl.argmin(diff, axis=2).to(tl.uint8)

    quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
    tl.store(out_ptr + offsets, quantized_flat, mask=mask)
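
The loop above is a branchless binary search over the sorted code table: after eight halving steps, lower_pivot and upper_pivot bracket each normalized value, and the closer of the two levels is taken. A plain PyTorch sketch of the same nearest-level selection, for illustration only (nearest_code_level is not part of this diff):

import torch

def nearest_code_level(a_normalized: torch.Tensor, code: torch.Tensor) -> torch.Tensor:
    # Reference version of the kernel's search: for each value of
    # a_normalized (already scaled into [-1, 1]), return the index of
    # the closest entry in the sorted code table.
    lower = torch.zeros_like(a_normalized, dtype=torch.long)
    upper = torch.full_like(lower, code.numel() - 1)
    for _ in range(8):  # ceil(log2(256)) = 8 halving steps
        pivot = (lower + upper) // 2
        is_higher = a_normalized > code[pivot]
        lower = torch.where(is_higher, pivot, lower)
        upper = torch.where(is_higher, upper, pivot)
    # Pick whichever of the two remaining candidates is closer.
    lower_dist = (a_normalized - code[lower]).abs()
    upper_dist = (a_normalized - code[upper]).abs()
    return torch.where(lower_dist <= upper_dist, lower, upper).to(torch.uint8)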


def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out):
    n = A.numel()

    split_num_blocks = 1
    grid = (triton.cdiv(blocks, split_num_blocks),)
    # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
    quantize_blockwise_kernel[grid](
        A_ptr=A,
        code_ptr=code,
        absmax_ptr=absmax,
        out_ptr=quantized_out,
        n_elements=n,
        BLOCK_SIZE=blocksize,
        CODE_SIZE=code.numel(),
        SPLIT_NUM_BLOCKS=split_num_blocks,
    )

    return quantized_out, absmax
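
A hedged round-trip check of the two wrappers, quantize then dequantize; the "xpu" device, the placeholder code table, and the tolerance are assumptions for illustration:

import torch

blocksize = 64
n = 4096
blocks = n // blocksize
A = torch.randn(n, device="xpu")
code = torch.linspace(-1.0, 1.0, 256, device="xpu")  # placeholder code table
absmax = torch.empty(blocks, device="xpu")
quantized = torch.empty(n, dtype=torch.uint8, device="xpu")

quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized)

out = torch.empty(n, dtype=torch.float32, device="xpu")
dequant_int8_blockwise(quantized, code, absmax, out, quant_blocksize=blocksize)

# With 256 levels per block the reconstruction is close but not exact.
assert torch.allclose(A, out, atol=0.1)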


# Triton implementation of a similar CUDA kernel, to avoid loading code from csrc/kernels.cu::dQuantizeFP4
# @triton.autotune(
# configs=[
@@ -587,7 +426,7 @@ def dequant_nf4_kernel(
    tl.store(c_ptr + offs, out_dq, mask)


-def _dequantize_4bit_impl(
+def dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
@@ -611,7 +450,7 @@ def _dequantize_4bit_impl(
    dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)


-def _dequantize_4bit_impl_passing_code(
+def dequantize_4bit_impl_passing_code(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,