From 82f0d708d82c58be0bdb2be86cae7ccbaab4833f Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 3 Dec 2024 02:23:27 -0500 Subject: [PATCH 01/36] [CPU] add int8 sdpa path for cpu --- test/quantization/test_sfdp_int8_fx_pass.py | 199 ++ torchao/csrc/cpu/sdpa.cpp | 2195 +++++++++++++++++++ torchao/csrc/cpu/toy.cpp | 20 + torchao/ops.py | 55 +- torchao/quantization/__init__.py | 4 + torchao/quantization/sfdp_int8_fx_pass.py | 733 +++++++ 6 files changed, 3205 insertions(+), 1 deletion(-) create mode 100644 test/quantization/test_sfdp_int8_fx_pass.py create mode 100644 torchao/csrc/cpu/sdpa.cpp create mode 100644 torchao/csrc/cpu/toy.cpp create mode 100644 torchao/quantization/sfdp_int8_fx_pass.py diff --git a/test/quantization/test_sfdp_int8_fx_pass.py b/test/quantization/test_sfdp_int8_fx_pass.py new file mode 100644 index 0000000000..a39a98c364 --- /dev/null +++ b/test/quantization/test_sfdp_int8_fx_pass.py @@ -0,0 +1,199 @@ +import torchao + +import contextlib +import functools +import itertools +import math + +import torch +import torch.utils.checkpoint +from torch._dynamo.debug_utils import aot_graph_input_parser +from torch._dynamo.utils import counters +from torch._inductor import config +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import run_and_get_code +from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA + +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( + X86InductorQuantizer, +) +from torchao.quantization.sfdp_int8_fx_pass import _sfdp_init_int8 + +class SelfAttnLikeModule(torch.nn.Module): + def __init__( + self, + input_dim, + has_mask, + num_attention_heads=None, + attention_head_size=None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.softmax = torch.nn.Softmax(dim=-1) + assert num_attention_heads is not None + assert attention_head_size is not None + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size) + self.dropout = torch.nn.Dropout(0) + self.has_mask = has_mask + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute([0, 2, 1, 3]) + + def forward(self, x, mask): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + q = self.transpose_for_scores(q) + k = self.transpose_for_scores(k) + v = self.transpose_for_scores(v) + scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) + if self.has_mask: + scores = scores + mask + attention = self.softmax(scores) + # attention = self.dropout(attention) + context_layer = torch.matmul(attention, v) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + context_layer = context_layer.view( + context_layer.size()[:-2] + (self.all_head_size,) + ) + return self.dense(context_layer) + +def 
_generate_qdq_quantized_model(mod, inputs, quantizer): + with torch.no_grad(): + export_model = capture_pre_autograd_graph(mod, inputs) + prepare_model = prepare_pt2e(export_model, quantizer) + prepare_model(*inputs) + convert_model = convert_pt2e(prepare_model) + torch.ao.quantization.move_exported_model_to_eval(convert_model) + return convert_model + +class TestSDPAPatternRewriterTemplate(TestCase): + def _clone_inputs(self, inputs): + def clone(x): + if not isinstance(x, torch.Tensor): + return x + return x.clone() + + return [clone(x) for x in inputs] + + def _check_common( + self, + dot_prod_attention, + args1=None, + contains=True, + atol=1e-5, + has_fuse_pattern=True, + has_dropout=False, + check_train=True, + override_check_equal=False, + dtype=torch.float, + rtol=1.3e-6, + ): + if args1 is None: + tensor_shape = (4, 2, 16, 32) + args1 = [ + torch.randn(tensor_shape, device=self.device, dtype=dtype), + torch.randn(tensor_shape, device=self.device, dtype=dtype), + torch.randn(tensor_shape, device=self.device, dtype=dtype), + ] + else: + args1 = list(args1) + args2 = self._clone_inputs(args1) + + for training in [False, True] if check_train else [False]: + for x in itertools.chain(args1[:], args2[:]): + if isinstance(x, torch.Tensor) and x.is_floating_point(): + x.requires_grad = training + + dropout_arg = [training] if has_dropout else [] + torch.manual_seed(1234) + result1 = dot_prod_attention(*(args1 + dropout_arg)) + + counters.clear() + torch.manual_seed(1234) + result2, source_code = run_and_get_code( + torch.compile(dot_prod_attention, fullgraph=True), + *(args2 + dropout_arg), + ) + source_code = "\n".join(source_code) + if has_fuse_pattern: + self.assertGreaterEqual(counters["inductor"]["fuse_attention_int8"], 1) + if contains: + # many of the patterns get re-expanded in dispatcher + self.assertIn( + "torchao.scaled_dot_product_int8", + source_code, + ) + + # some tests configured with very low dropout where we still want to check equality + if not has_dropout or override_check_equal: + self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6) + + if training: + result1.sum().backward() + result2.sum().backward() + for arg1, arg2 in zip(args1, args2): + if ( + isinstance(arg1, torch.Tensor) + and arg1.is_floating_point() + and (not has_dropout or override_check_equal) + ): + self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) + + @skipIfRocm + @config.patch({"freezing": True}) + def _test_sdpa_rewriter_int8_1_to_4(self): + # pattern is different for bs=1 + for dtype, has_mask, bs in itertools.product( + [torch.float32], [True, False], [56, 1] + ): + mod = SelfAttnLikeModule( + input_dim=64 * 16, + has_mask=has_mask, + num_attention_heads=16, + attention_head_size=64, + ).eval() + maybe_autocast = ( + torch.cpu.amp.autocast() + if dtype == torch.bfloat16 + else contextlib.nullcontext() + ) + inputs = [ + torch.randn((bs, 384, 64 * 16), device=self.device, dtype=dtype), + torch.randn((bs, 1, 1, 384), device=self.device) if has_mask else None, + ] + with torch.no_grad(), maybe_autocast: + _sfdp_init_int8() + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) + convert_model = _generate_qdq_quantized_model(mod, inputs, quantizer) + self._check_common( + convert_model, args1=inputs, check_train=False, atol=1.0 + ) + +if HAS_CPU: + class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): + device = 
"cpu" + test_sdpa_rewriter_int8_1_to_4_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_int8_1_to_4 + +if __name__ == "__main__": + if IS_LINUX: + run_tests() diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp new file mode 100644 index 0000000000..44aaef1bcc --- /dev/null +++ b/torchao/csrc/cpu/sdpa.cpp @@ -0,0 +1,2195 @@ +// // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// #include +// #include + +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// // #include +// // #include +// #include +// #include + +// #ifndef AT_PER_OPERATOR_HEADERS +// #include +// #else +// #include +// #endif + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include + +namespace torchao { + +namespace { + +template +struct is_reduced_floating_point: + std::integral_constant || + std::is_same_v> { +}; + +template +constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; + +// out = val * a + b +// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), +// take b as a scalar pointer. +template +void _scale_attn_mask_fusion_kernel( + T1* a, + T2* b, + const int& size, + T1* out, + T1& val) { + const auto vec_size1 = at::vec::Vectorized::size(); + const auto vec_size2 = at::vec::Vectorized::size(); + constexpr int64_t T1_n = + (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v) ? 2 : 1; + constexpr int64_t T2_n = 1; + auto vec_scale = at::vec::VectorizedN(val); + int64_t i = 0; + for (; i < size - (size % vec_size2); i += vec_size2) { + auto a_n = at::vec::VectorizedN::loadu(a + i); + at::vec::VectorizedN b_n; + if constexpr(is_b_stride_zero) { + b_n = at::vec::VectorizedN((T1)b[0]); + } else { + b_n = at::vec::VectorizedN::loadu(b + i); + } + auto b_n_convert = at::vec::convert(b_n); + auto res = a_n * vec_scale + b_n_convert; + res.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = a[i]; + T1 tmp1; + if constexpr(is_b_stride_zero) { + tmp1 = (T1)b[0]; + } else { + tmp1 = (T1)b[i]; + } + out[i] = tmp0 * val + tmp1; + } +} + +// 1) out = exp(a - val) +// 2) val = sum(out) +template +void _exp_reduce_sum_fusion_kernel( + T1* a, + const int& size, + T2* out, + T1& val) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + T1 tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(out + i, tmp2); + } + tmp_sum = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return x + y; + }, + vec_tmp_sum); + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 - val; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + out[i] = tmp2; + } + val = tmp_sum; +} + +// 1) out = a * scale +// 2) max = max(out) +template +void _mul_reduce_max_fusion_kernel( + const scalar_t* a, + const scalar_t& scale, + const int& size, + scalar_t* out, + scalar_t& max) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + scalar_t tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + 
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + _store(out + i, tmp1); + } + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + // max = std::max( + // tmp_max, + // at::vec::vec_reduce_all( + // [](vec::Vectorized& x, at::vec::Vectorized& y) { + // return at::vec::maximum(x, y); + // }, + // vec_tmp_max)); + max = std::max(tmp_max, vec_tmp_max.reduce_max()); +} + +template +static scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { + TORCH_CHECK(ptr2 == nullptr); + return ptr; +} + +template , int> = 0> +static scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { + return ptr2; +} + +template +void fill_stub(scalar_t* data, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + Vec data_vec = Vec(val); + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + data_vec.store(data + d); + } + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (; d < size; d++) { + data[d] = val; + } +} + +void reshape_attn_mask_to_4d( + at::Tensor& attn_mask, + int64_t batchSize, + int64_t num_head, + int64_t qSize, + int64_t kvSize) { + // Support mask shapes: + // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) + // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) + // Guaranteed in check_attn_mask_shape + int64_t attn_mask_size_0 = 1; + int64_t attn_mask_size_1 = 1; + if (attn_mask.dim() == 4) { + if (attn_mask.size(0) == batchSize) { + attn_mask_size_0 = batchSize; + } + if (attn_mask.size(1) == num_head) { + attn_mask_size_1 = num_head; + } + } + attn_mask = attn_mask + .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)}) + .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); +} + +// TODO: Use at::native::_store instead when it supports Half. 
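+// Descriptive note (editorial comment, not part of the original patch): the _store
+// overloads below write a vector of lanes to `dst`.
+//  - The generic version stores the vector unchanged (optionally only the first `size` lanes).
+//  - For reduced floating point destinations (BFloat16/Half), a float vector is converted
+//    with at::vec::convert_from_float before storing.
+//  - For uint8_t/int8_t destinations, a float vector is narrowed with at::vec::convert
+//    before storing.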
+template +void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { + src.store(dst, size); +} + +template +typename std::enable_if_t, void> +_store(scalar_t* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_from_float(src, src); + res.store(dst, at::vec::Vectorized::size()); +} + +template +typename std::enable_if_t || std::is_same_v, void> +_store(scalar_t* dst, at::vec::Vectorized src) { + auto res = at::vec::convert(src); + res.store(dst, at::vec::Vectorized::size()); +} + +template +void pad_row_zero( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi) { + auto vec_size = at::vec::Vectorized::size(); + int i = 0; + for (; i < rows - 1; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } + } + + // zero padding + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = at::vec::Vectorized(0); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized(0); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } +} + +template +void pad_row_128_padding( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi, + int padding) { + auto vec_size = at::vec::Vectorized::size(); + int i = 0; + for (; i < rows - padding; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } + } + + // 128 padding + for (; i < rows; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = at::vec::Vectorized(128); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized(128); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } + } +} + +template +void pad_col_zero( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi) { + auto vec_size = at::vec::Vectorized::size(); + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < cols - 1 - ((cols - 1) % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + if (j < cols - 1) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - 1 - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - 1 - j); + *(padding_value_ptr + i * cols + cols - 1) = scalar_t(0); + } + } +} + +template +void pad_col_zero_padding( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi, + int padding) { + auto vec_size = at::vec::Vectorized::size(); + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < cols - padding - ((cols - padding) % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + if (j < cols - padding) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - padding - 
j); + vec_v.store(padding_value_ptr + i * cols + j, cols - padding - j); + *(padding_value_ptr + i * cols + cols - padding) = scalar_t(0); + } + } +} + +/* +1. dequant +2. add mask +3. max reduce for softmax +*/ +template +void _dequant_mask_max_fusion_kernel( + const int32_t* in, + const mask_t* mask_ptr, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldm, // leading dimension mask + const int& ldo, + const int32_t& beta, // zp_a*zp_b*k + const float& alpha, // scale_a*scale_b*scale_sdpa + float* out, + float* sfm_max_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_beta = at::vec::Vectorized(beta); + auto vec_alpha = at::vec::Vectorized(alpha); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + float* tmp_out = out + row * ldo; + const mask_t* mask_data_ptr = mask_ptr + row * ldm; + float tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col); + auto tmp7 = at::vec::convert(tmp6); + auto tmp8 = tmp5 + tmp7; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp8); + _store(tmp_out + col, tmp8); + } + tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); + for (long col = vec_size * (N / vec_size); col < N; col++) { + auto sum_b = sum_b_ptr[col]; + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sum_b; + auto tmp2 = tmp1 - sum_a; + auto tmp3 = tmp2 + beta; + auto tmp4 = (float) tmp3; + auto tmp5 = tmp4 * alpha; + auto tmp6 = mask_data_ptr[col]; + auto tmp7 = (float) tmp6; + auto tmp8 = tmp5 + tmp7; + tmp_max = std::max(tmp_max, tmp8); + tmp_out[col] = tmp8; + } + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); + } +} + +/* +1. dequant +2. 
max reduce for softmax +*/ +void _dequant_max_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta, // zp_a*zp_b*k + const float& alpha, // scale_a*scale_b*scale_sdpa + float* out, + float* sfm_max_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_beta = at::vec::Vectorized(beta); + auto vec_alpha = at::vec::Vectorized(alpha); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + float* tmp_out = out + row * ldo; + float tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp5); + _store(tmp_out + col, tmp5); + } + tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); + for (long col = vec_size * (N / vec_size); col < N; col++) { + auto sum_b = sum_b_ptr[col]; + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sum_b; + auto tmp2 = tmp1 - sum_a; + auto tmp3 = tmp2 + beta; + auto tmp4 = (float) tmp3; + auto tmp5 = tmp4 * alpha; + tmp_max = std::max(tmp_max, tmp5); + tmp_out[col] = tmp5; + } + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); + } +} + +/* +1. Softmax: sub max, exp, sum reduce, div sum +2. quant +3. sum for attention +*/ +template +void _sub_exp_sum_div_quant_sum_fusion_kernel( + const float* in, + const int64_t& M, + const int64_t& N_step, + const int64_t& NSlice, + const int& ldi, + const int& ldo, + const int& kvSize, + const int& rndkvSplitSize, + const int& av_gemm_K, + const int32_t& beta1, // zp_a + const int32_t& beta2, // zp_b + const float& alpha, // scale_a + float* local, + scalar_t* out, + float* sfm_max_ptr, + float* sfm_sum_ptr, + int32_t* sum_a_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + scalar_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + float beta1_float = (float) beta1; + auto vec_beta1 = at::vec::Vectorized(beta1_float); + for (int64_t row = 0; row < M; ++row) { + auto sfm_max = sfm_max_ptr[row]; + auto vec_max = at::vec::Vectorized(sfm_max); + // sub max, exp, sum reduce + const float* qk_block_data = in + row * rndkvSplitSize; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + const float* tmp_in = qk_block_data + l * ldi; + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_out = local + n; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(tmp_out + col, tmp2); + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sfm_max; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + 
tmp_out[col] = tmp2; + } + sfm_sum_ptr[row] += tmp_sum; + } + // div sum, sum for attention + auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; + auto vec_sum_scale = at::vec::Vectorized(sum_scale); + scalar_t* qk_reduced_block_data = out + row * av_gemm_K; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + int32_t tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_in = local + n; + scalar_t* tmp_out = qk_reduced_block_data + l * ldo; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::maximum(tmp3, vec_min_val); + auto tmp5 = at::vec::minimum(tmp4, vec_max_val); + _store(tmp_out + col, tmp5); + auto tmp6 = at::vec::convert(tmp5); + vec_tmp_sum += tmp6; + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 * sum_scale; + auto tmp2 = std::nearbyint(tmp1); + auto tmp3 = tmp2 + beta1_float; + auto tmp4 = std::max(tmp3, min_val); + auto tmp5 = std::min(tmp4, max_val); + tmp_out[col] = tmp5; + auto tmp6 = (int32_t) tmp5; + tmp_sum += tmp6; + } + sum_a_ptr[row] += tmp_sum * beta2; + // set zero + for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { + _store(tmp_out + col, vec_zero); + } + for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { + tmp_out[col] = zero; + } + } + } +} + +template +void _sub_exp_sum_div_quant_fusion_kernel( + const float* in, + const int64_t& M, + const int64_t& N_step, + const int64_t& NSlice, + const int& ldi, + const int& ldo, + const int& kvSize, + const int& rndkvSplitSize, + const int& av_gemm_K, + const int32_t& beta1, // zp_a + const float& alpha, // scale_a + float* local, + scalar_t* out, + float* sfm_max_ptr, + float* sfm_sum_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + scalar_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + float beta1_float = (float) beta1; + auto vec_beta1 = at::vec::Vectorized(beta1_float); + for (int64_t row = 0; row < M; ++row) { + auto sfm_max = sfm_max_ptr[row]; + auto vec_max = at::vec::Vectorized(sfm_max); + // sub max, exp, sum reduce + const float* qk_block_data = in + row * rndkvSplitSize; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + const float* tmp_in = qk_block_data + l * ldi; + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_out = local + n; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(tmp_out + col, tmp2); + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sfm_max; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + tmp_out[col] = tmp2; + } + sfm_sum_ptr[row] += tmp_sum; + } + // div sum, sum for attention + auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; + auto 
vec_sum_scale = at::vec::Vectorized(sum_scale); + scalar_t* qk_reduced_block_data = out + row * av_gemm_K; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + float* tmp_in = local + n; + scalar_t* tmp_out = qk_reduced_block_data + l * ldo; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::maximum(tmp3, vec_min_val); + auto tmp5 = at::vec::minimum(tmp4, vec_max_val); + _store(tmp_out + col, tmp5); + } + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 * sum_scale; + auto tmp2 = std::nearbyint(tmp1); + auto tmp3 = tmp2 + beta1_float; + auto tmp4 = std::max(tmp3, min_val); + auto tmp5 = std::min(tmp4, max_val); + tmp_out[col] = tmp5; + } + // set zero + for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { + _store(tmp_out + col, vec_zero); + } + for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { + tmp_out[col] = zero; + } + } + } +} + +/* +1. dequant +2. quant +*/ +template +void _dequant_quant_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta1, // zp_a*zp_b*k + const int32_t& beta2, // zp_c + const float& alpha, // scale_a*scale_b/scale_c + scalar_t* out) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + auto vec_beta1 = at::vec::Vectorized(beta1); + auto vec_alpha = at::vec::Vectorized(alpha); + float beta2_float = (float) beta2; + auto vec_beta2 = at::vec::Vectorized(beta2_float); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + scalar_t* tmp_out = out + row * ldo; + for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::maximum(tmp7, vec_min_val); + auto tmp9 = at::vec::minimum(tmp8, vec_max_val); + _store(tmp_out + col, tmp9); + } + for (long col = vec_size * (N / vec_size); col < N; col++) { + auto sum_b = sum_b_ptr[col]; + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sum_b; + auto tmp2 = tmp1 - sum_a; + auto tmp3 = tmp2 + beta1; + auto tmp4 = (float) tmp3; + auto tmp5 = tmp4 * alpha; + auto tmp6 = std::nearbyint(tmp5); + auto tmp7 = tmp6 + beta2_float; + auto tmp8 = std::max(tmp7, min_val); + auto tmp9 = std::min(tmp8, max_val); + tmp_out[col] = tmp9; + } + } +} + +template +void _int_sum_b_contiguous_kernel_helper( + const scalar_t* in, + int32_t* out, + const int& N, + const int32_t& scale) { + const int32_t vec_size = at::vec::Vectorized::size(); + int32_t tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + for (long i = 0; i < vec_size * (N / vec_size); i += vec_size) { + auto tmp0 
= at::vec::Vectorized::loadu(in + i); + auto tmp1 = at::vec::convert(tmp0); + vec_tmp_sum = vec_tmp_sum + tmp1; + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long i = vec_size * (N / vec_size); i < N; i++) { + // for (long i = 0; i < N; i++) { + tmp_sum += static_cast(in[i]); + } + out[0] = tmp_sum * scale; +} + +template +void _int_sum_b_contiguous_kernel( + const scalar_t* in, + int32_t* out, + const int& M, + const int& N, + const int& ld, + const int32_t& scale) { + for (long r = 0; r < M; r += 1) { + _int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); + } +} + +template +void _int_sum_a_contiguous_kernel( + const scalar_t* in, + int32_t* out, + const int& M, + const int& N, + const int& ld, + const int32_t& scale) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + // initialization with 0 + int32_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + _store(out + i, vec_zero); + } + for (long i = vec_size * (M / vec_size); i < M; i++) { + out[i] = zero; + } + // sum + for (long j = 0; j < N; j++) { + const scalar_t* tmp_in = in + j * ld; + for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + i); + auto tmp1 = at::vec::Vectorized::loadu(out + i); + auto tmp2 = at::vec::convert(tmp0); + auto tmp3 = tmp1 + tmp2; + _store(out + i, tmp3); + } + for (long i = vec_size * (M / vec_size); i < M; i++) { + // for (long i = 0; i < M; i++) { + auto tmp0 = tmp_in[i]; + auto tmp1 = out[i]; + auto tmp2 = static_cast(tmp0); + auto tmp3 = tmp1 + tmp2; + out[i] = tmp3; + } + } + // scale + for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(out + i); + auto tmp1 = tmp0 * vec_scale; + _store(out + i, tmp1); + } + for (long i = vec_size * (M / vec_size); i < M; i++) { + auto tmp0 = out[i]; + auto tmp1 = tmp0 * scale; + out[i] = tmp1; + } +} + +void do_convert_u8_s8( + unsigned char* src, + signed char* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_128 = at::vec::Vectorized(128); + for (int64_t r = 0; r < in_rows; r++) { + const unsigned char* tmp_src = src + r * ldi; + signed char* tmp_dst = dst + r * ldo; + for (int64_t c = 0; c < vec_size * (in_cols / vec_size); c += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_src + c, vec_size); + auto tmp1 = at::vec::convert(tmp0); + auto tmp2 = tmp1 - vec_128; + auto tmp3 = at::vec::convert(tmp2); + _store(tmp_dst + c, tmp3, vec_size); + } + for (int64_t c = vec_size * (in_cols / vec_size); c < in_cols; c++) { + // for (int64_t c = 0; c < in_cols; c++) { + auto tmp0 = tmp_src[c]; + auto tmp1 = (int16_t) tmp0; + auto tmp2 = tmp1 - 128; + auto tmp3 = (signed char) tmp2; + tmp_dst[c] = tmp3; + } + } +} + +template +void do_transpose( + scalar_t* src, + scalar_t* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + for (int64_t r=0; r +void do_copy( + scalar_t* src, + scalar_t* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + for (int64_t r=0; r +void pad_remain_row_col( + scalar_t* value_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi, + scalar_t pad_val=0) { + auto psize = pcols - cols; + if (psize == 0 && prows == rows) { + return; + } + const int32_t vec_size = at::vec::Vectorized::size(); + auto pad = at::vec::Vectorized(pad_val); + if 
(psize > 0) { + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < psize - (psize % vec_size); j += vec_size) { + pad.store(value_ptr + i * ldi + cols + j); + } + if (j < psize) { + pad.store(value_ptr + i * ldi + cols + j, psize - j); + } + } + } + + for (int i = rows; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + pad.store(value_ptr + i * ldi + j); + } + if (j < pcols) { + pad.store(value_ptr + i * ldi + j, pcols - j); + } + } +} + +template +void copy_value_with_pad( + scalar_t* value_ptr, + scalar_t* dst_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi, + scalar_t pad_val=0) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto pad = at::vec::Vectorized(pad_val); + int i = 0; + for (; i < rows; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(dst_ptr + i * pcols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(dst_ptr + i * pcols + j, cols - j); + } + + // col padding + auto psize = pcols - cols; + if (psize > 0) { + int pj = 0; + for (; pj < psize - (psize % vec_size); pj += vec_size) { + pad.store(dst_ptr + i * pcols + cols + pj); + } + if (pj < psize) { + pad.store(dst_ptr + i * pcols + cols + pj, psize - pj); + } + } + } + + // row padding + for (; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + pad.store(dst_ptr + i * pcols + j); + } + if (j < pcols) { + pad.store(dst_ptr + i * pcols + j, pcols - j); + } + + } + +} + +// thread_local std::unordered_map< +// BrgemmKey, +// std::shared_ptr> cache_brgemm_kernels; + +// thread_local std::unordered_map< +// PackBKey, +// std::shared_ptr> cache_packb_kernels; + +// std::shared_ptr create_or_get_microkernel( +// int64_t M, +// int64_t N, +// int64_t K, +// int64_t batch_size, +// int lda, +// int ldb, +// int ldc, +// dt dt_a, +// dt dt_b, +// dt dt_c) { +// BrgemmKey key_brgemm(M, N, K, batch_size, lda, ldb, ldc, dt_a, dt_b, dt_c); +// auto search = cache_brgemm_kernels.find(key_brgemm); +// if (search != cache_brgemm_kernels.end()) { +// return search->second; +// } else { +// cache_brgemm_kernels.insert( +// {key_brgemm, +// std::make_shared( +// M, N, K, batch_size, lda, ldb, ldc, dt_a, dt_b, dt_c)}); +// return cache_brgemm_kernels[key_brgemm]; +// } +// } + +// std::shared_ptr create_or_get_packb_microkernel( +// int64_t K, +// int64_t N, +// int ld_in, +// int ld_out, +// dt dt_in, +// dt dt_out, +// bool do_trans) { +// PackBKey key_packb(K, N, ld_in, ld_out, dt_in, dt_out); +// auto search = cache_packb_kernels.find(key_packb); +// if (search != cache_packb_kernels.end()) { +// return search->second; +// } else { +// cache_packb_kernels.insert( +// {key_packb, +// std::make_shared( +// K, N, +// do_trans ? 
dnnl::ukernel::pack_type::trans : dnnl::ukernel::pack_type::no_trans, +// ld_in, ld_out, dt_in, dt_out)}); +// return cache_packb_kernels[key_packb]; +// } +// } + +// UINT8 - u8u8s32 +template +typename std::enable_if_t, void> +sdpa_int8_kernel_impl( + const at::Tensor& output, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + double dropout_p, + bool is_causal, + at::Tensor& attention_mask, + double scale, + int32_t q_zp, + float q_scale, + int32_t k_zp, + float k_scale, + int32_t v_zp, + float v_scale, + int32_t a_zp, + float a_scale, + int32_t o_zp, + float o_scale) { + // using dt = dnnl::memory::data_type; + // using namespace dnnl; + // using namespace dnnl::ukernel; + // auto starts = duration_cast(system_clock::now().time_since_epoch()).count(); + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + at::Tensor query = q.transpose(1, 2); + at::Tensor key = k.transpose(1, 2); + at::Tensor value = v.transpose(1, 2); + + const auto accumulate_dtype = at::kFloat; // at::toOpMathType(dtype); + + using accum_t = float; // at::opmath_type; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = + sdp::calculate_scale(query, scale).as_float_unchecked(); + // if (attention_mask.defined() && attention_mask.scalar_type() != ScalarType::Float) { + // attention_mask = attention_mask.to(at::kFloat); + // } + int block_64 = 64; + // Sizes + TORCH_CHECK( + (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); + TORCH_CHECK( + kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); + + int64_t batchSize = query.size(0); + int64_t qSize = query.size(1); + int64_t kvSize = value.size(1); + int64_t num_head = query.size(2); + int64_t headSize = query.size(3); + + + bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); + if (has_attn_mask) { + reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); + } + + // Strides + int64_t qStrideB = query.stride(0); + int64_t qStrideM = query.stride(1); + int64_t qStrideH = query.stride(2); + int64_t kStrideB = key.stride(0); + int64_t kStrideN = key.stride(1); + int64_t kStrideH = key.stride(2); + int64_t vStrideB = value.stride(0); + int64_t vStrideN = value.stride(1); + int64_t vStrideH = value.stride(2); + int64_t oStrideB = output.stride(0); + int64_t oStrideM = output.stride(1); + int64_t oStrideH = output.stride(2); + int64_t mStrideB = + (attention_mask.defined() && attention_mask.size(0) > 1) + ? attention_mask.stride(0) + : 0; + int64_t mStrideH = + (attention_mask.defined() && attention_mask.size(1) > 1) + ? attention_mask.stride(1) + : 0; + int64_t mStrideM = + (attention_mask.defined() && attention_mask.size(2) > 1) + ? attention_mask.stride(2) + : 0; + int64_t mStrideN = + (attention_mask.defined() && attention_mask.size(3) > 1) + ? attention_mask.stride(3) + : 0; + + int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; + int64_t kvSplitSize = kv_split_size > kvSize ? 
kvSize : kv_split_size; + int64_t qSlice = (qSize - 1) / qSplitSize + 1; + int64_t qTail = (qSize - 1) % qSplitSize + 1; + int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + int64_t num_thread = at::get_num_threads(); + + int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; + // one of 16, 32, 48, 64 + auto select_tail_tail_block_size = [](int64_t size) -> int64_t { + if (size == 0) { + return 0; + } else if (size <= 16) { + return 16; + } else if (size <= 32) { + return 32; + } else if (size <= 48) { + return 48; + } else { + return 64; + } + }; + int64_t kv_tail_tail_block_size = select_tail_tail_block_size(kvTail % block_64); + int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; + int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; + int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; + + bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; + int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; + // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int av_gemm_K = kvSplitSize + av_gemm_K_padding; + bool av_gemm_K_tail_mul4 = kvTail % 4 == 0; + int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; + int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; + + + // dt u8_dt = dt::u8; + // dt s8_dt = dt::s8; + // // dt f32_dt = dt::f32; + // dt s32_dt = dt::s32; + auto u8_dt = at::ScalarType::Byte; + auto s8_dt = at::ScalarType::Int; + auto f32_dt = at::ScalarType::Float; + + // Data ptrs + scalar_t* q_data = query.data_ptr(); + scalar_t* k_data = key.data_ptr(); + scalar_t* v_data = value.data_ptr(); + mask_t* mask_data = attention_mask.defined() + ? attention_mask.data_ptr() + : nullptr; + scalar_t* out_data = output.data_ptr(); + + // Create tpp kernels for Query @ Key + bool headSize_mul4 = headSize % 4 == 0; + // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int qk_gemm_K_padding = headSize_mul4 ? 0 : 4 - headSize % 4; + int qk_gemm_K = headSize + qk_gemm_K_padding; + + // auto && qk_gemm = create_or_get_microkernel( + // qSplitSize, block_64, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // block_64, //ldb + // rndkvSplitSize, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // (*qk_gemm).finalize(); + // (*qk_gemm).generate(); + // size_t qk_scratchpad_size = (*qk_gemm).get_scratchpad_size(); + + // auto && qk_gemm_ktail = create_or_get_microkernel( + // qSplitSize, block_64, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // block_64, //ldb + // rndkvTail, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // // size_t qk_ktail_scratchpad_size = (*qk_gemm_ktail).get_scratchpad_size(); + // (*qk_gemm_ktail).finalize(); + // (*qk_gemm_ktail).generate(); + + // std::shared_ptr qk_gemm_ktail_tail; + // if (kvTail % block_64 != 0) { + // qk_gemm_ktail_tail = create_or_get_microkernel( + // qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // kv_tail_tail_block_size, //ldb + // rndkvTail, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // (*qk_gemm_ktail_tail).finalize(); + // (*qk_gemm_ktail_tail).generate(); + // } + + // auto && qk_gemm_qtail = create_or_get_microkernel( + // qTail, block_64, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K,//headSize_mul4 ? 
qStrideM : qk_gemm_K, // lda + // block_64, //ldb + // rndkvSplitSize, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // // size_t qk_qtail_scratchpad_size = (*qk_gemm_qtail).get_scratchpad_size(); + // (*qk_gemm_qtail).finalize(); + // (*qk_gemm_qtail).generate(); + // auto && qk_gemm_qktail = create_or_get_microkernel( + // qTail, block_64, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // block_64, //ldb + // rndkvTail, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // // size_t qk_qktail_scratchpad_size = (*qk_gemm_qktail).get_scratchpad_size(); + // (*qk_gemm_qktail).finalize(); + // (*qk_gemm_qktail).generate(); + + // std::shared_ptr qk_gemm_qktail_tail; + // if (kvTail % block_64 != 0) { + // qk_gemm_qktail_tail = create_or_get_microkernel( + // qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + // 1, //batch_size + // qk_gemm_K, // lda + // kv_tail_tail_block_size, //ldb + // rndkvTail, //ldc + // u8_dt, //a dtype + // u8_dt, //b dtype + // s32_dt //c dtype + // ); + // (*qk_gemm_qktail_tail).finalize(); + // (*qk_gemm_qktail_tail).generate(); + // } + + // std::vector> A_B_offsets(1); + std::vector> A_B_offsets(1); + A_B_offsets[0] = std::make_pair(0, 0); + + // std::vector> A_B_offsets_batch(kvSlice); + std::vector> A_B_offsets_batch(kvSlice); + for (auto s=0; s(); + + int64_t kv_sum_size_per_BH = + /* key_sum */ kvSize + + /* value_sum */ headSize; + + at::Tensor kv_sum_buf = at::empty( + {batchSize, num_head, kv_sum_size_per_BH}, + query.options().dtype(at::kInt)); + int32_t* k_sum_buf_data = kv_sum_buf.data_ptr(); + int32_t* v_sum_buf_data = k_sum_buf_data + batchSize * num_head * kvSize; + + int64_t kv_reorder_size_per_BH = + /* key_t_reorder */ qk_gemm_K * rndkvSize + + /* value_t_reorder */ kvSlice * av_gemm_K * rndHeadSize; + + at::Tensor kv_reorder_buf = at::empty( + {batchSize, num_head, kv_reorder_size_per_BH}, + query.options()); + scalar_t* kv_reorder_buf_data = kv_reorder_buf.data_ptr(); + scalar_t* key_reorder_ptr = kv_reorder_buf_data; + scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; + +// // Create transforms for Key +// auto && brgemm_k_xform = create_or_get_packb_microkernel( +// qk_gemm_K, // K +// block_64, // N +// block_64, // kStrideN, // block_64, // ld_in +// block_64, // ld_out +// u8_dt, // dt_in +// u8_dt, // dt_out +// false // true +// ); +// (*brgemm_k_xform).generate(); +// auto && brgemm_k_xform_tail = create_or_get_packb_microkernel( +// qk_gemm_K, +// block_64, +// block_64, // kStrideN, // block_64, +// block_64, +// u8_dt, +// u8_dt, +// false // true +// ); +// (*brgemm_k_xform_tail).generate(); +// std::shared_ptr brgemm_k_xform_tail_tail; +// if (kvTail % block_64 != 0) { +// brgemm_k_xform_tail_tail = create_or_get_packb_microkernel( +// qk_gemm_K, +// kv_tail_tail_block_size, +// kv_tail_tail_block_size, // kStrideN, // kv_tail_tail_block_size, +// kv_tail_tail_block_size, +// u8_dt, +// u8_dt, +// false // true +// ); +// (*brgemm_k_xform_tail_tail).generate(); +// } + +// // Create transforms for Value +// auto && brgemm_v_xform = create_or_get_packb_microkernel( +// av_gemm_K, +// block_64, +// vStrideN, // block_64, +// block_64, +// u8_dt, +// u8_dt, +// false +// ); +// (*brgemm_v_xform).generate(); +// auto && brgemm_v_xform_tail = create_or_get_packb_microkernel( +// av_gemm_K_tail, +// block_64, +// vStrideN, // block_64, +// block_64, +// u8_dt, +// u8_dt, +// false +// ); +// 
(*brgemm_v_xform_tail).generate(); + + // sum k + if (q_zp != 0) { + at::parallel_for( + 0, batchSize * num_head * kvSize, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, k, kvSize); + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + int32_t* k_sum_ptr = k_sum_buf_data + + i * num_head * kvSize + + j * kvSize + k; + _int_sum_b_contiguous_kernel_helper( + k_data + i * kStrideB + j * kStrideH + k * kStrideN, + k_sum_ptr, + headSize, q_zp); + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, k, kvSize); + } + }); + } + + // sum v + if (a_zp != 0) { + at::parallel_for( + 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head); + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + int32_t* v_sum_ptr = v_sum_buf_data + + i * num_head * headSize + + j * headSize; + _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + v_sum_ptr, + headSize, kvSize, vStrideN, a_zp); + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); + } + + at::parallel_for( + 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head); + int ompIdx = at::get_thread_num(); + scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; + int32_t offset = 0; + accum_t* qk_data = reinterpret_cast(total_buf_ptr); + offset += kvSlice * qSplitSize * rndkvSplitSize * 4; + accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * av_gemm_K * 4; + scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * qSplitSize * av_gemm_K; + int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndkvSplitSize * 4; + int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndHeadSize * 4; + accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * qk_gemm_K; + // scalar_t* scratchpad_gemm = reinterpret_cast(total_buf_ptr + offset); + // offset += scratchpad_size; + + scalar_t* key_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qk_gemm_K * rndkvSize; + scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + + uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; + + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + + // pack + for (int64_t n = 0; n < kvSize; n += kvSplitSize) { + // long ss, ee; + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { + bool tail = kvSplitSize - b < block_64; + do_transpose( + // do_copy( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? 
kvSplitSize - b : block_64, + headSize, + kStrideN, + block_64); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_64, + qk_gemm_K, + block_64, + block_64 + ); + } + // Pack + // (*brgemm_k_xform).execute( + // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + // B_blocked_xform_u8, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K + // ); + at::native::cpublas::pack( + qk_gemm_K, // K + block_64, // N + block_64, // ld_in + block_64, // ld_out + u8_dt, // dt_in + u8_dt, // dt_out + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + // (*brgemm_v_xform).execute( + // v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + // // B_blocked_xform_u8, + // value_reorder_ptr + n * rndHeadSize + + // av_gemm_K * b); + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); + } + } else { + // tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < rndkvTail) { + bool tail = kvTail - b < block_size; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? kvTail - b : block_size, + headSize, + kStrideN, + block_size); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_size, + qk_gemm_K, + block_size, + block_size + ); + } + // Pack + if (block_size == block_64) { + // (*brgemm_k_xform_tail).execute( + // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + // B_blocked_xform_u8, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K + // ); + at::native::cpublas::pack( + qk_gemm_K, + block_64, + block_64, // kStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } else { + // (*brgemm_k_xform_tail_tail).execute( + // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + // B_blocked_xform_u8, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K + // ); + at::native::cpublas::pack( + qk_gemm_K, + kv_tail_tail_block_size, + kv_tail_tail_block_size, // kStrideN, // kv_tail_tail_block_size, + kv_tail_tail_block_size, + u8_dt, + u8_dt, + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + // split headSize to block_64, block_64, block_64 ... + // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] 
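+            // Editorial note: the loop below packs the V panels for the kv tail with
+            // at::native::cpublas::pack into value_reorder_ptr, so that the later
+            // attention @ V brgemm can consume contiguous packed blocks.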
+ for (int64_t b = 0; b < headSize; b += block_64) { + // (*brgemm_v_xform).execute( + // v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + // // B_blocked_xform_u8, + // value_reorder_ptr + n * rndHeadSize + + // av_gemm_K * b); + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); + } + } + } + + // sdpa core + int32_t* k_sum_ptr = k_sum_buf_data + i * num_head * kvSize + j * kvSize; + int32_t* v_sum_ptr = v_sum_buf_data + i * num_head * headSize + j * headSize; + for (int64_t k = 0; k < qSlice; k++) { + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize sum and max + fill_stub( + sfm_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + a_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); + int64_t num_keys = + is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + qk_gemm_K, + qStrideM); + + if (k_zp == 0) { + _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); + } else { + fill_stub( + q_sum_ptr, static_cast(0), qSplitSize); + } + const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { + int64_t n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + // Calculate sums for dequant compensation item + if (qBlockSize == qSplitSize) { + // q main + if (n + kvSplitSize < kvSize) { + // k main + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm).set_hw_context(); + // (*qk_gemm).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc, + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } + } else { + auto block_size = kvTail >= block_64 ? 
block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_ktail).set_hw_context(); + // (*qk_gemm_ktail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + block_64, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } else { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_ktail_tail).set_hw_context(); + // (*qk_gemm_ktail_tail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + kv_tail_tail_block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } else { + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_qtail).set_hw_context(); + // (*qk_gemm_qtail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qTail, block_64, qk_gemm_K, + 1, //batch_size + qk_gemm_K,//headSize_mul4 ? qStrideM : qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } + } else { + // k tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_qktail).set_hw_context(); + // (*qk_gemm_qktail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qTail, block_64, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + block_64, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } else { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*qk_gemm_qktail_tail).set_hw_context(); + // (*qk_gemm_qktail_tail).execute( + // query_t_padding_ptr, + // key_reorder_ptr + n * qk_gemm_K + + // b * qk_gemm_K, + // A_B_offsets, + // qk_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + 1, //batch_size + qk_gemm_K, // lda + kv_tail_tail_block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b, + A_B_offsets); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } + + // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 + int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? 
rndkvSplitSize : rndkvTail; + accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; + if (has_attn_mask) { + mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); + _dequant_mask_max_fusion_kernel( + qk_s32_data, //in + mask_data_offset, //mask_ptr + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + mStrideM, //ldm + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } else { + _dequant_max_fusion_kernel( + qk_s32_data, //in + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } + } + + // sub max, exp, sum reduce, div sum for softmax + // and quant + // and sum for attention + if (v_zp == 0) { + _sub_exp_sum_div_quant_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlices + qSplitSize * rndkvSplitSize, //ldi + qSplitSize * av_gemm_K, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr //sfm_sum_ptr + ); + } else { + _sub_exp_sum_div_quant_sum_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlice + qSplitSize * rndkvSplitSize, //ldi + qSplitSize * av_gemm_K, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + v_zp, // zp_b=beta2 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr, //sfm_sum_ptr + a_sum_ptr //a_sum_ptr + ); + } + + // Calculate Softmax(q @ k.T) @ v + for (int64_t b = 0; b < headSize; b += block_64) { + // dnnl::ukernel::brgemm::release_hw_context(); + // (*av_gemm_batch).set_hw_context(); + // (*av_gemm_batch).execute( + // qk_reduced_data, + // value_reorder_ptr + b * av_gemm_K, + // A_B_offsets_batch, + // dst_s32_data + b, + // scratchpad_gemm); + at::native::cpublas::brgemm( + qSplitSize, block_64, av_gemm_K, + kvSlice, //batch_size + av_gemm_K, // lda + rndHeadSize, //block_64, //ldb + rndHeadSize, //ldc + false, + qk_reduced_data, + value_reorder_ptr + b * av_gemm_K, + dst_s32_data + b, + A_B_offsets_batch); + } + + // After the last gemm, + // do dequant compensation, quant and convert from s32 to int8 + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); + // Once all computations are done, need to release HW context. + // brgemm::release_hw_context(); + at::native::cpublas::brgemm_release(); +} + +#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Bool, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Float, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Double, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::BFloat16, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Half, mask_t, __VA_ARGS__)) + +void sdpa_int8_kernel( + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + at::Tensor& attn_mask, + double scale, + long q_zp, + double q_scale, + long k_zp, + double k_scale, + long v_zp, + double v_scale, + long a_zp, + double a_scale, + long o_zp, + double o_scale) { + int64_t batchSize = query.size(0); + int64_t num_head = query.size(1); + int64_t q_seq_len = query.size(2); + + TORCH_CHECK(query.scalar_type() == c10::kByte); + if (!attn_mask.defined()) { + if (q_seq_len >= 768) { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_seq_len >= 192) { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + } else { + AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { + if (q_seq_len >= 768) { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_seq_len >= 192) { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + }); + } +} + +// at::Tensor sdpa_int8_math_impl( +// const at::Tensor& query_, +// const at::Tensor& key, +// const at::Tensor& value, +// double dropout_p, +// bool is_causal, +// at::Tensor& attn_mask_, +// double scale, +// int32_t q_zp, +// float q_scale, +// int32_t k_zp, +// float k_scale, +// int32_t v_zp, +// float v_scale, +// int32_t a_zp, +// float a_scale, +// int32_t o_zp, +// float o_scale) { +// // dequant q/k/v +// auto q = (query_.to(at::kFloat) - q_zp) * q_scale; +// auto k = (key.to(at::kFloat) - k_zp) * k_scale; +// auto v = (value.to(at::kFloat) - v_zp) * v_scale; +// auto attn_mask = attn_mask_; +// if (attn_mask.defined()) { +// *attn_mask = (*attn_mask).to(at::kFloat); +// } +// // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math +// bool is_negative_scaling = scale.defined() && scale < 0.0; +// const auto scaling_factor = sdp::calculate_scale(q, is_negative_scaling ? std::abs(scale) : scale).sqrt(); +// q = q * (is_negative_scaling ? 
c10::SymFloat(0.0) - scaling_factor: scaling_factor); +// auto attn = at::matmul(q, k.transpose(-2, -1) * scaling_factor); +// if (attn_mask.defined()) { +// if (at::areAnyTensorSubclassLike({attn, *attn_mask})) { +// attn = attn.add(*attn_mask); +// } else { +// attn.add_(*attn_mask); +// } +// } +// attn = at::softmax(attn, -1); +// // quant attn +// attn = at::clamp_max( +// at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255 +// ); +// // dequant attn +// attn = (attn - a_zp) * a_scale; +// auto output = at::matmul(attn, v); +// // quant output +// output = at::clamp_max( +// at::clamp_min(at::round(output / o_scale) + o_zp, 0), 255 +// ).to(at::kByte); +// return output; +// } + + +at::Tensor _scaled_dot_product_int8_cpu( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + at::Tensor& attn_mask, + // const std::optional& attn_mask, + double dropout_p, + bool is_causal, + double scale, + // std::optional scale, + int64_t q_zp, + double q_scale, + int64_t k_zp, + double k_scale, + int64_t v_zp, + double v_scale, + int64_t a_zp, + double a_scale, + int64_t o_zp, + double o_scale) { + const auto dtype = query.scalar_type(); + TORCH_CHECK(!query.is_nested() && !key.is_nested() && !value.is_nested(), + "_scaled_dot_product_int8_cpu: Only accept plain inputs"); + TORCH_CHECK(!is_causal, + "_scaled_dot_product_int8_cpu: is_causal not supported."); + TORCH_CHECK(dtype == at::ScalarType::Byte, + "_scaled_dot_product_int8_cpu: Expected data type be U8, but got ", dtype, " instead."); + TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "_scaled_dot_product_int8_cpu: Accept only 4 dims inputs shape of {B, H, T, K}"); + TORCH_CHECK(dropout_p == 0.0, + "_scaled_dot_product_int8_cpu: Currently do not support dropout > 0"); + TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "_scaled_dot_product_int8_cpu: Q/K/V should have the same head size"); + TORCH_CHECK(!attn_mask.defined() || + attn_mask.scalar_type() == at::kFloat || + attn_mask.scalar_type() == at::kBFloat16, + "_scaled_dot_product_int8_cpu: Expected attention mask be float or bf16"); + TORCH_CHECK(!attn_mask.defined() || + (attn_mask.dim() == 2 || attn_mask.dim() == 4), + "_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); + + // fallback math path + // at::Tensor output = sdpa_int8_math_impl(query, key, value, + // dropout_p, is_causal, attn_mask, scale, + // q_zp, q_scale, + // k_zp, k_scale, + // v_zp, v_scale, + // a_zp, a_scale, + // o_zp, o_scale); + + // fused sdpa int8 impl + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + sdpa_int8_kernel(output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + + return output.transpose(1, 2); +} + + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::scaled_dot_product_int8", &_scaled_dot_product_int8_cpu); +} + +// } // at::native +} // namespace torchao diff --git a/torchao/csrc/cpu/toy.cpp b/torchao/csrc/cpu/toy.cpp new file mode 100644 index 0000000000..a835aae661 --- /dev/null +++ b/torchao/csrc/cpu/toy.cpp @@ -0,0 +1,20 @@ +#include +#include +#include + +namespace torchao { + +torch::Tensor toy_op2_cpu( + torch::Tensor _in_feats) +{ + std::cout<<"---- run into cpu 2 ----"< Tensor", tags=[torch._C.Tag.needs_fixed_stride_order], ) - +lib.define( + "scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, float 
dropout_p=0.0, bool is_causal=False, float scale=1.0, int q_zp=0, float q_scale=1.0, int k_zp=0, float k_scale=1.0, int v_zp=0, float v_scale=1.0, int a_zp=0, float a_scale=1.0, int o_zp=0, float o_scale=1.0) -> Tensor" +) def register_custom_op(name): def decorator(func): @@ -159,6 +161,57 @@ def _( return _in_feats.new_empty((BS, OC)) +def scaled_dot_product_int8( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 1.0, + q_zp: int = 0, + q_scale: float = 1.0, + k_zp: int = 0, + k_scale: float = 1.0, + v_zp: int = 0, + v_scale: float = 1.0, + a_zp: int = 0, + a_scale: float = 1.0, + o_zp: int = 0, + o_scale: float = 1.0, +) -> Tensor: + return torch.ops.torchao.scaled_dot_product_int8.default(query, key, value, + attn_mask, dropout_p, is_causal, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale) + + +@register_custom_op("torchao::scaled_dot_product_int8") +def _( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 1.0, + q_zp: int = 0, + q_scale: float = 1.0, + k_zp: int = 0, + k_scale: float = 1.0, + v_zp: int = 0, + v_scale: float = 1.0, + a_zp: int = 0, + a_scale: float = 1.0, + o_zp: int = 0, + o_scale: float = 1.0, +) -> Tensor: + return query + + def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor: """ Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`. diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 67a70c5a35..a53a87a919 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -94,6 +94,9 @@ smooth_fq_linear_to_inference, swap_linear_with_smooth_fq_linear, ) +from .sfdp_int8_fx_pass import ( + _sfdp_init_int8, +) from .subclass import * # noqa: F403 from .transform_module import register_quantize_module_handler from .unified import Quantizer, TwoStepQuantizer @@ -192,4 +195,5 @@ "TensorCoreTiledLayout", "CutlassInt4PackedLayout", "Float8MMConfig", + "_sfdp_init_int8", ] diff --git a/torchao/quantization/sfdp_int8_fx_pass.py b/torchao/quantization/sfdp_int8_fx_pass.py new file mode 100644 index 0000000000..672db14f6b --- /dev/null +++ b/torchao/quantization/sfdp_int8_fx_pass.py @@ -0,0 +1,733 @@ +import functools +from typing import Callable + +import torch +from torch._inductor import config +from torch._inductor.pattern_matcher import ( + filter_nodes, + fwd_only, + register_replacement, + gen_register_replacement, + PatternMatcherPass, +) +from torch._dynamo.utils import counters +from torch._inductor.fx_passes.fuse_attention import ( + partialize_and_update_signature +) +from torchao.ops import scaled_dot_product_int8 + +__all__ = [ + # "_sfdp_pattern_int8", + # "_sfdp_replacement_int8", + # "_gen_sfdp_patterns_int8", + "_sfdp_init_int8", +] + +aten = torch.ops.aten +# scaled_dot_product_int8 = torch.ops.torchao.scaled_dot_product_int8 +patterns = PatternMatcherPass() + +# def _sfdp_pattern_int8(query, key, value, inv_scale): +# return ( +# torch.matmul(query, key.transpose(-2, -1)) +# .div(inv_scale) +# .softmax(dim=-1) +# .matmul(value) +# ) + +# def _sfdp_replacement_int8(query, key, value, inv_scale): +# print("*** enter _sfdp_replacement in torchao ***") +# counters["inductor"]["fuse_attention_int8"] += 1 +# return torch.nn.functional.scaled_dot_product_attention( +# query, +# 
key, +# value, +# attn_mask=None, +# dropout_p=0.0, +# is_causal=False, +# scale=1.0 / inv_scale, +# ) + +def _sfdp_pattern_int8_1( + query, + key, + value, + attn_mask, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + # int8-mix-fp32 QUANTIZED SDPA with mask + q = query.permute([0, 2, 1, 3]) + q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + q, float(q_scale), int(q_zp), 0, 255, torch.uint8 + ) + k = key.permute([0, 2, 1, 3]).transpose(-2, -1) + k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + k, float(k_scale), int(k_zp), 0, 255, torch.uint8 + ) + v = value.permute([0, 2, 1, 3]) + v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + v, float(v_scale), int(v_zp), 0, 255, torch.uint8 + ) + a = torch.nn.functional.dropout( + (torch.matmul(q, k).div(inv_scale) + attn_mask).softmax(dim=-1), + dropout, + ) + qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( + a, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + o = a.matmul(v) + o = o.permute(0, 2, 1, 3).contiguous() + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + o, float(o_scale), int(o_zp), 0, 255, torch.uint8 + ) + + +def _sfdp_replacement_int8_1( + query, + key, + value, + attn_mask, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + print("hit _sfdp_replacement_int8_1") + counters["inductor"]["fuse_attention_int8"] += 1 + res = scaled_dot_product_int8( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask, + dropout_p=dropout, + is_causal=False, + scale=1.0 / inv_scale, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale, + ) + return res.permute(0, 2, 1, 3).contiguous() + + +def _sfdp_pattern_int8_2( + query, + key, + value, + attn_mask, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + # int8-mix-reduce QUANTIZED SDPA with mask + q = query.permute([0, 2, 1, 3]) + q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + q, float(q_scale), int(q_zp), 0, 255, torch.uint8 + ).to(torch.float16) + k = key.permute([0, 2, 1, 3]).transpose(-2, -1) + k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + k, float(k_scale), int(k_zp), 0, 255, torch.uint8 + ).to(torch.float16) + v = value.permute([0, 2, 1, 3]) + v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + v, float(v_scale), int(v_zp), 0, 255, torch.uint8 + ).to(torch.float16) + a = torch.nn.functional.dropout( + (torch.matmul(q, k).div(inv_scale) + attn_mask).softmax(dim=-1), + dropout, + ) + qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( + a, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ).to(torch.float16) + o = a.matmul(v) + o = o.permute(0, 2, 1, 3).contiguous() + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + o, float(o_scale), int(o_zp), 0, 255, torch.uint8 + ) + + +def _sfdp_replacement_int8_2( + query, + key, + value, + attn_mask, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + 
v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + print("hit _sfdp_replacement_int8_2") + counters["inductor"]["fuse_attention_int8"] += 1 + res = scaled_dot_product_int8( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask, + dropout_p=dropout, + is_causal=False, + scale=1.0 / inv_scale, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale, + ) + return res.permute(0, 2, 1, 3).contiguous() + + +def _sfdp_pattern_int8_3( + query, + key, + value, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + # int8-mix-fp32 QUANTIZED SDPA without mask + q = query.permute([0, 2, 1, 3]) + q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + q, float(q_scale), int(q_zp), 0, 255, torch.uint8 + ) + k = key.permute([0, 2, 1, 3]).transpose(-2, -1) + k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + k, float(k_scale), int(k_zp), 0, 255, torch.uint8 + ) + v = value.permute([0, 2, 1, 3]) + v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + v, float(v_scale), int(v_zp), 0, 255, torch.uint8 + ) + a = torch.nn.functional.dropout( + torch.matmul(q, k).div(inv_scale).softmax(dim=-1), + dropout, + ) + qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( + a, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + o = a.matmul(v) + o = o.permute(0, 2, 1, 3).contiguous() + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + o, float(o_scale), int(o_zp), 0, 255, torch.uint8 + ) + + +def _sfdp_replacement_int8_3( + query, + key, + value, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + print("hit _sfdp_replacement_int8_3") + counters["inductor"]["fuse_attention_int8"] += 1 + res = scaled_dot_product_int8( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + dropout_p=dropout, + is_causal=False, + scale=1.0 / inv_scale, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale, + ) + return res.permute(0, 2, 1, 3).contiguous() + + +def _sfdp_pattern_int8_4( + query, + key, + value, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + # int8-mix-reduce QUANTIZED SDPA without mask + q = query.permute([0, 2, 1, 3]) + q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + q, float(q_scale), int(q_zp), 0, 255, torch.uint8 + ).to(torch.float16) + k = key.permute([0, 2, 1, 3]).transpose(-2, -1) + k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + k, float(k_scale), int(k_zp), 0, 255, torch.uint8 + ).to(torch.float16) + v = value.permute([0, 2, 1, 3]) + v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + v, float(v_scale), int(v_zp), 0, 255, torch.uint8 + ).to(torch.float16) + a = torch.nn.functional.dropout( + torch.matmul(q, k).div(inv_scale).softmax(dim=-1), + dropout, + ) + qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( + a, float(a_scale), int(a_zp), 0, 255, torch.uint8 + ) + a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + qa, float(a_scale), 
int(a_zp), 0, 255, torch.uint8 + ).to(torch.float16) + o = a.matmul(v) + o = o.permute(0, 2, 1, 3).contiguous() + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + o, float(o_scale), int(o_zp), 0, 255, torch.uint8 + ) + + +def _sfdp_replacement_int8_4( + query, + key, + value, + inv_scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + dropout, +): + print("hit _sfdp_replacement_int8_4") + counters["inductor"]["fuse_attention_int8"] += 1 + res = scaled_dot_product_int8( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + dropout_p=dropout, + is_causal=False, + scale=1.0 / inv_scale, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale, + ) + return res.permute(0, 2, 1, 3).contiguous() + + +def _sfdp_params_check_int8(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + if not (query.dtype == key.dtype == value.dtype) or not ( + query.device == key.device == value.device + ): + return False + add_nodes = filter_nodes(match.nodes, aten.add.Tensor) + # Has attn_mask add. + add_mask_node = [n for n in add_nodes if n.prev.target == torch.ops.aten.div.Tensor] + if len(add_mask_node) > 0: + attn_mask_node = add_mask_node[0].args[1] + # attn_mask_node may be a float/int number. + if not hasattr(attn_mask_node, "meta"): + return False + attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr] + # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool + # attn_mask.dtype == torch.float for models like albert. + if ( + not isinstance(attn_mask, torch.Tensor) + or not ( + attn_mask.dtype == query.dtype + or attn_mask.dtype == torch.bool + or attn_mask.dtype == torch.float + ) + or query.device != attn_mask.device + ): + return False + return True + + +def _sfdp_extra_check_int8(scale_factor_op=None, disable_cuda=False): + def fn(match): + if ( + disable_cuda + and "query" in match.kwargs + and "cuda" in str(match.kwargs["query"].meta["val"].device) + ): + return False + if scale_factor_op is not None: + scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0] + # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns. + scale_factor = scale_factor_node.args[1] + # make sure the scale_factor a float/int. SymInt? + if not isinstance(scale_factor, (float, int)): + return False + return _sfdp_params_check_int8(match) + + return fn + + +def _gen_sfdp_patterns_int8(): + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device) # attn_mask + c_inp = functools.partial(torch.tensor, 2.0, device=device) # inv_scale + zp_inp = functools.partial(torch.tensor, 127, device=device) # quant_zero_point + scale_inp = functools.partial(torch.tensor, 0.018, device=device) # quant_scale + + # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change. + # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated. + # here we need to trace with input of batch_size=1 to generate a pattern graph without clone. 
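+    # A minimal illustration of the point above (not one of the registered
+    # patterns; shapes chosen only for this comment): matmul on 4D inputs folds
+    # the leading batch dims, and reshaping a permuted tensor copies only when
+    # that fold changes the memory layout.
+    #   x = torch.empty(2, 4, 8, 16).permute([0, 2, 1, 3])
+    #   x.reshape(2 * 8, 4, 16)  # layout changes -> clone node in the traced graph
+    #   y = torch.empty(1, 4, 8, 16).permute([0, 2, 1, 3])
+    #   y.reshape(1 * 8, 4, 16)  # size-1 batch dim folds as a view -> no clone node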
+ g_bs1_inp = functools.partial( + torch.empty, (1, 4, 8, 16), device=device, requires_grad=True + ) + m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) + for dtype in [torch.float, torch.half]: + # g = functools.partial(g_inp, dtype=dtype) + # c = functools.partial(c_inp, dtype=dtype) + # candidates = [ + # ( + # _sfdp_pattern_int8, + # _sfdp_replacement_int8, + # [g(), g(), g(), c()], + # {}, + # _sfdp_extra_check_int8(aten.div.Tensor), + # ), + # ] + g_u8 = functools.partial(g_inp, dtype=torch.uint8, requires_grad=False) + g_u8_bs1 = functools.partial(g_bs1_inp, dtype=torch.uint8, requires_grad=False) + m = functools.partial(m_inp, dtype=dtype) + m_bs1 = functools.partial(m_bs1_inp, dtype=dtype) + c = functools.partial(c_inp, dtype=dtype) + zp = functools.partial(zp_inp, dtype=torch.int) + scale = functools.partial(scale_inp, dtype=torch.float) + d_u8 = { + "dropout": 0.113377, + "q_zp": 23, + "q_scale": 0.0111541, + "k_zp": 14, + "k_scale": 0.0256212, + "v_zp": 28, + "v_scale": 0.0164518, + "a_zp": 12, + "a_scale": 0.0572114, + "o_zp": 36, + "o_scale": 0.0235489, + } + int8_candidates = [ + ( + _sfdp_pattern_int8_1, + _sfdp_replacement_int8_1, + [ + g_u8(), + g_u8(), + g_u8(), + m(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_1, + _sfdp_replacement_int8_1, + [ + g_u8_bs1(), + g_u8_bs1(), + g_u8_bs1(), + m_bs1(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_2, + _sfdp_replacement_int8_2, + [ + g_u8(), + g_u8(), + g_u8(), + m(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_2, + _sfdp_replacement_int8_2, + [ + g_u8_bs1(), + g_u8_bs1(), + g_u8_bs1(), + m_bs1(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_3, + _sfdp_replacement_int8_3, + [ + g_u8(), + g_u8(), + g_u8(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_3, + _sfdp_replacement_int8_3, + [ + g_u8_bs1(), + g_u8_bs1(), + g_u8_bs1(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_4, + _sfdp_replacement_int8_4, + [ + g_u8(), + g_u8(), + g_u8(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ( + _sfdp_pattern_int8_4, + _sfdp_replacement_int8_4, + [ + g_u8_bs1(), + g_u8_bs1(), + g_u8_bs1(), + c(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + zp(), + scale(), + ], + d_u8, + _sfdp_extra_check_int8(aten.div.Tensor), + ), + ] + for pattern, replacement, args, workaround, extra_check in int8_candidates: + # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern + # gets serialized to a python file and does not require tracing at runtime. 
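+        # Each (pattern, replacement) pair is registered for inference only;
+        # "dropout" stays in the scalar workaround dict purely as a tracing
+        # placeholder and is pinned to 0.0 (then removed) before registration.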
+ assert isinstance(workaround, dict) + name = pattern.__name__ + + if len(workaround) >= 1: + # if "dropout_p" in workaround: + # # functools.partial insufficient because we look at signature downstream + # pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + # replacement = partialize_and_update_signature( + # replacement, dropout_p=0.0 + # ) + # workaround = {} + # else: + # for uint8 pattern with more workarounds other than dropout, + # we need to rename it to avoid influcing other patterns + pattern = partialize_and_update_signature(pattern, dropout=0.0) + replacement = partialize_and_update_signature( + replacement, dropout=0.0 + ) + if "dropout" in workaround: + del workaround["dropout"] + + inference_name = name + "_inference" + yield inference_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + } + + +@functools.lru_cache(None) +def _sfdp_init_int8(): + for key, register_replacement_kwargs in _gen_sfdp_patterns_int8(): + register_replacement(**register_replacement_kwargs) + config.joint_custom_pre_pass = patterns.apply + # print("\n\njoint_custom_pre_pass:", config.joint_custom_pre_pass) From b5985ae215017ca53ded2c788be237110c6590c3 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 3 Dec 2024 02:45:51 -0500 Subject: [PATCH 02/36] update int8 sdpa --- test/test_ops.py | 122 ++++++++++++++++++++++++++++++++++++++ torchao/csrc/cpu/sdpa.cpp | 27 --------- 2 files changed, 122 insertions(+), 27 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a86b71a79e..ec116e6e4e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,6 +7,7 @@ import sys import pytest +import math import torch from torch.testing._internal.common_utils import ( TestCase, @@ -109,6 +110,127 @@ def test_quant_llm_linear_correctness( rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 assert relative_error < rtol + def _scaled_dot_product_int8_op_ref( + self, + q, + k, + v, + attn_mask=None, + dropout_p=0, + is_causal=False, + q_zp=0, + q_scale=1.0, + k_zp=0, + k_scale=1.0, + v_zp=0, + v_scale=1.0, + a_zp=0, + a_scale=1.0, + o_zp=0, + o_scale=1.0): + q = q.to(torch.float) + k = k.to(torch.float) + v = v.to(torch.float) + scale_factor = 1 / math.sqrt(q.size(-1)) + attn = q @ k.transpose(-2, -1) + attn = attn * scale_factor + if attn_mask is not None: + attn = attn + attn_mask + attn_max = attn.max(dim=-1, keepdim=True).values + attn = attn - attn_max + attn = torch.exp(attn) + attn_sum = torch.sum(attn, dim=-1, keepdim=True) + attn = attn / attn_sum + math_ref = attn @ v + return math_ref.to(torch.uint8) + + SDPA_INT8_BATCH_SIZE = [56] + SDPA_INT8_NUM_HEADS = [16] + SDPA_INT8_Q_SEQ_LEN = [188] + SDPA_INT8_KV_SEQ_LEN = [253] + SDPA_INT8_HEAD_DIM = [64] + SDPA_INT8_MASK_DTYPE = [torch.bfloat16] + + SDPA_INT8_TEST_PARAMS = list( + itertools.product( + SDPA_INT8_BATCH_SIZE, + SDPA_INT8_NUM_HEADS, + SDPA_INT8_Q_SEQ_LEN, + SDPA_INT8_KV_SEQ_LEN, + SDPA_INT8_HEAD_DIM, + SDPA_INT8_MASK_DTYPE, + ) + ) + + @parametrize("batch_size", SDPA_INT8_BATCH_SIZE) + @parametrize("n_head", SDPA_INT8_NUM_HEADS) + @parametrize("q_seq_len", SDPA_INT8_Q_SEQ_LEN) + @parametrize("kv_seq_len", SDPA_INT8_KV_SEQ_LEN) + @parametrize("head_dim", SDPA_INT8_HEAD_DIM) + @parametrize("mask_dtype", SDPA_INT8_MASK_DTYPE) + def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype): + device = "cpu" + q_zp = int(127) + 
q_scale = float(1.7907238006591797) + k_zp = int(125) + k_scale = float(1.8039721250534058) + v_zp = int(127) + v_scale = float(1.839004635810852) + a_zp = int(120) + a_scale = float(0.003919653594493866) + o_zp = int(128) + o_scale = float(1.8191684484481812) + q_shape = [batch_size, n_head, q_seq_len, head_dim] + kv_shape = [batch_size, n_head, kv_seq_len, head_dim] + mask_shape = [batch_size, 1, q_seq_len, kv_seq_len] + q = torch.randn(q_shape, dtype=torch.float, device=device) + k = torch.randn(kv_shape, dtype=torch.float, device=device) + v = torch.randn(kv_shape, dtype=torch.float, device=device) + q = q.to(torch.uint8) + k = k.to(torch.uint8) + v = v.to(torch.uint8) + attn_mask = torch.randn(mask_shape, dtype=mask_dtype, device=device) + q2, k2, v2, attn_mask_2 = q.clone(), k.clone(), v.clone(), attn_mask.clone() + + math_ref = self._scaled_dot_product_int8_op_ref( + q2, + k2, + v2, + attn_mask=attn_mask_2, + dropout_p=0.0, + is_causal=False, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale + ) + actual = torch.ops.torchao.scaled_dot_product_int8( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale + ) + + self.assertEqual(actual, math_ref, atol=3.0, rtol=5e-6) + instantiate_parametrized_tests(TestOps) diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp index 44aaef1bcc..229ef3433e 100644 --- a/torchao/csrc/cpu/sdpa.cpp +++ b/torchao/csrc/cpu/sdpa.cpp @@ -1,30 +1,3 @@ -// // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -// #include -// #include - -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// // #include -// // #include -// #include -// #include - -// #ifndef AT_PER_OPERATOR_HEADERS -// #include -// #else -// #include -// #endif - #include #include #include From 57949ccf0783d3cd4a6547e3a6fd5ad039314097 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 3 Dec 2024 02:47:47 -0500 Subject: [PATCH 03/36] update int8 sdpa --- torchao/quantization/sfdp_int8_fx_pass.py | 47 ----------------------- 1 file changed, 47 deletions(-) diff --git a/torchao/quantization/sfdp_int8_fx_pass.py b/torchao/quantization/sfdp_int8_fx_pass.py index 672db14f6b..1dedd60ee5 100644 --- a/torchao/quantization/sfdp_int8_fx_pass.py +++ b/torchao/quantization/sfdp_int8_fx_pass.py @@ -17,37 +17,12 @@ from torchao.ops import scaled_dot_product_int8 __all__ = [ - # "_sfdp_pattern_int8", - # "_sfdp_replacement_int8", - # "_gen_sfdp_patterns_int8", "_sfdp_init_int8", ] aten = torch.ops.aten -# scaled_dot_product_int8 = torch.ops.torchao.scaled_dot_product_int8 patterns = PatternMatcherPass() -# def _sfdp_pattern_int8(query, key, value, inv_scale): -# return ( -# torch.matmul(query, key.transpose(-2, -1)) -# .div(inv_scale) -# .softmax(dim=-1) -# .matmul(value) -# ) - -# def _sfdp_replacement_int8(query, key, value, inv_scale): -# print("*** enter _sfdp_replacement in torchao ***") -# counters["inductor"]["fuse_attention_int8"] += 1 -# return torch.nn.functional.scaled_dot_product_attention( -# query, -# key, -# value, -# attn_mask=None, -# dropout_p=0.0, -# is_causal=False, -# scale=1.0 / inv_scale, -# ) - def _sfdp_pattern_int8_1( query, key, @@ -476,17 +451,6 @@ def _gen_sfdp_patterns_int8(): ) m_bs1_inp 
= functools.partial(torch.empty, (1, 1, 1, 4), device=device) for dtype in [torch.float, torch.half]: - # g = functools.partial(g_inp, dtype=dtype) - # c = functools.partial(c_inp, dtype=dtype) - # candidates = [ - # ( - # _sfdp_pattern_int8, - # _sfdp_replacement_int8, - # [g(), g(), g(), c()], - # {}, - # _sfdp_extra_check_int8(aten.div.Tensor), - # ), - # ] g_u8 = functools.partial(g_inp, dtype=torch.uint8, requires_grad=False) g_u8_bs1 = functools.partial(g_bs1_inp, dtype=torch.uint8, requires_grad=False) m = functools.partial(m_inp, dtype=dtype) @@ -696,16 +660,6 @@ def _gen_sfdp_patterns_int8(): name = pattern.__name__ if len(workaround) >= 1: - # if "dropout_p" in workaround: - # # functools.partial insufficient because we look at signature downstream - # pattern = partialize_and_update_signature(pattern, dropout_p=0.0) - # replacement = partialize_and_update_signature( - # replacement, dropout_p=0.0 - # ) - # workaround = {} - # else: - # for uint8 pattern with more workarounds other than dropout, - # we need to rename it to avoid influcing other patterns pattern = partialize_and_update_signature(pattern, dropout=0.0) replacement = partialize_and_update_signature( replacement, dropout=0.0 @@ -730,4 +684,3 @@ def _sfdp_init_int8(): for key, register_replacement_kwargs in _gen_sfdp_patterns_int8(): register_replacement(**register_replacement_kwargs) config.joint_custom_pre_pass = patterns.apply - # print("\n\njoint_custom_pre_pass:", config.joint_custom_pre_pass) From cf82d1ca1fd9a794d9b3829f27f557e63947a67e Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 3 Dec 2024 03:03:42 -0500 Subject: [PATCH 04/36] update int8 sdpa --- torchao/csrc/cpu/toy.cpp | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 torchao/csrc/cpu/toy.cpp diff --git a/torchao/csrc/cpu/toy.cpp b/torchao/csrc/cpu/toy.cpp deleted file mode 100644 index a835aae661..0000000000 --- a/torchao/csrc/cpu/toy.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include -#include -#include - -namespace torchao { - -torch::Tensor toy_op2_cpu( - torch::Tensor _in_feats) -{ - std::cout<<"---- run into cpu 2 ----"< Date: Tue, 17 Dec 2024 00:11:14 -0500 Subject: [PATCH 05/36] update int8 sdpa cpu --- test/quantization/test_sfdp_int8_fx_pass.py | 10 +- test/test_ops.py | 63 +- torchao/csrc/cpu/sdpa.cpp | 1148 ++++++------------- torchao/ops.py | 4 +- 4 files changed, 414 insertions(+), 811 deletions(-) diff --git a/test/quantization/test_sfdp_int8_fx_pass.py b/test/quantization/test_sfdp_int8_fx_pass.py index a39a98c364..3e17d9ce81 100644 --- a/test/quantization/test_sfdp_int8_fx_pass.py +++ b/test/quantization/test_sfdp_int8_fx_pass.py @@ -16,7 +16,7 @@ from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq -from torch._export import capture_pre_autograd_graph +from torch.export import export_for_training from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( X86InductorQuantizer, @@ -65,7 +65,7 @@ def forward(self, x, mask): if self.has_mask: scores = scores + mask attention = self.softmax(scores) - # attention = self.dropout(attention) + attention = self.dropout(attention) context_layer = torch.matmul(attention, v) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() context_layer = context_layer.view( @@ -75,7 +75,7 @@ def forward(self, x, mask): def _generate_qdq_quantized_model(mod, inputs, quantizer): with torch.no_grad(): - 
export_model = capture_pre_autograd_graph(mod, inputs) + export_model = export_for_training(mod, inputs).module() prepare_model = prepare_pt2e(export_model, quantizer) prepare_model(*inputs) convert_model = convert_pt2e(prepare_model) @@ -173,10 +173,10 @@ def _test_sdpa_rewriter_int8_1_to_4(self): if dtype == torch.bfloat16 else contextlib.nullcontext() ) - inputs = [ + inputs = ( torch.randn((bs, 384, 64 * 16), device=self.device, dtype=dtype), torch.randn((bs, 1, 1, 384), device=self.device) if has_mask else None, - ] + ) with torch.no_grad(), maybe_autocast: _sfdp_init_int8() quantizer = X86InductorQuantizer() diff --git a/test/test_ops.py b/test/test_ops.py index ec116e6e4e..d7094f29f4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -128,39 +128,31 @@ def _scaled_dot_product_int8_op_ref( a_scale=1.0, o_zp=0, o_scale=1.0): - q = q.to(torch.float) - k = k.to(torch.float) - v = v.to(torch.float) + q = (q.to(torch.float) - q_zp) * q_scale + k = (k.to(torch.float) - k_zp) * k_scale + v = (v.to(torch.float) - v_zp) * v_scale scale_factor = 1 / math.sqrt(q.size(-1)) attn = q @ k.transpose(-2, -1) attn = attn * scale_factor if attn_mask is not None: - attn = attn + attn_mask + attn = attn + attn_mask.to(torch.float) attn_max = attn.max(dim=-1, keepdim=True).values attn = attn - attn_max attn = torch.exp(attn) attn_sum = torch.sum(attn, dim=-1, keepdim=True) attn = attn / attn_sum - math_ref = attn @ v - return math_ref.to(torch.uint8) - - SDPA_INT8_BATCH_SIZE = [56] - SDPA_INT8_NUM_HEADS = [16] - SDPA_INT8_Q_SEQ_LEN = [188] - SDPA_INT8_KV_SEQ_LEN = [253] - SDPA_INT8_HEAD_DIM = [64] - SDPA_INT8_MASK_DTYPE = [torch.bfloat16] - - SDPA_INT8_TEST_PARAMS = list( - itertools.product( - SDPA_INT8_BATCH_SIZE, - SDPA_INT8_NUM_HEADS, - SDPA_INT8_Q_SEQ_LEN, - SDPA_INT8_KV_SEQ_LEN, - SDPA_INT8_HEAD_DIM, - SDPA_INT8_MASK_DTYPE, - ) - ) + attn = torch.clamp(torch.round(attn / a_scale) + a_zp, min=0, max=255) + attn = (attn - a_zp) * a_scale + out = attn @ v + out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255) + return out.to(torch.uint8) + + SDPA_INT8_BATCH_SIZE = [56, 120] + SDPA_INT8_NUM_HEADS = [2, 16] + SDPA_INT8_Q_SEQ_LEN = [18, 89] + SDPA_INT8_KV_SEQ_LEN = [100, 253] + SDPA_INT8_HEAD_DIM = [32, 64] + SDPA_INT8_MASK_DTYPE = [None, torch.float32, torch.bfloat16] @parametrize("batch_size", SDPA_INT8_BATCH_SIZE) @parametrize("n_head", SDPA_INT8_NUM_HEADS) @@ -169,6 +161,7 @@ def _scaled_dot_product_int8_op_ref( @parametrize("head_dim", SDPA_INT8_HEAD_DIM) @parametrize("mask_dtype", SDPA_INT8_MASK_DTYPE) def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype): + torch.manual_seed(1234) device = "cpu" q_zp = int(127) q_scale = float(1.7907238006591797) @@ -180,23 +173,23 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ a_scale = float(0.003919653594493866) o_zp = int(128) o_scale = float(1.8191684484481812) - q_shape = [batch_size, n_head, q_seq_len, head_dim] - kv_shape = [batch_size, n_head, kv_seq_len, head_dim] - mask_shape = [batch_size, 1, q_seq_len, kv_seq_len] - q = torch.randn(q_shape, dtype=torch.float, device=device) - k = torch.randn(kv_shape, dtype=torch.float, device=device) - v = torch.randn(kv_shape, dtype=torch.float, device=device) + q_shape = [batch_size, q_seq_len, n_head, head_dim] + kv_shape = [batch_size, kv_seq_len, n_head, head_dim] + mask_shape = [batch_size, 1, 1, kv_seq_len] + q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 + k = 
torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 + v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 q = q.to(torch.uint8) k = k.to(torch.uint8) v = v.to(torch.uint8) - attn_mask = torch.randn(mask_shape, dtype=mask_dtype, device=device) - q2, k2, v2, attn_mask_2 = q.clone(), k.clone(), v.clone(), attn_mask.clone() + attn_mask = torch.randn(mask_shape, dtype=mask_dtype, device=device) if mask_dtype is not None else None + q2, k2, v2, attn_mask_2 = q.clone(), k.clone(), v.clone(), attn_mask.clone() if mask_dtype is not None else None math_ref = self._scaled_dot_product_int8_op_ref( q2, k2, v2, - attn_mask=attn_mask_2, + attn_mask=attn_mask, dropout_p=0.0, is_causal=False, q_zp=q_zp, @@ -214,7 +207,7 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ q, k, v, - attn_mask=attn_mask, + attn_mask=attn_mask_2, dropout_p=0.0, is_causal=False, q_zp=q_zp, @@ -229,7 +222,7 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ o_scale=o_scale ) - self.assertEqual(actual, math_ref, atol=3.0, rtol=5e-6) + self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) instantiate_parametrized_tests(TestOps) diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp index 229ef3433e..3357608db5 100644 --- a/torchao/csrc/cpu/sdpa.cpp +++ b/torchao/csrc/cpu/sdpa.cpp @@ -37,11 +37,17 @@ struct is_reduced_floating_point: template constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; +inline double calculate_scale( + const at::Tensor& query, + double scale) { + return scale == 0.0 ? 1.0 / std::sqrt(query.size(-1)) : scale; +} + // out = val * a + b // is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), // take b as a scalar pointer. template -void _scale_attn_mask_fusion_kernel( +inline void _scale_attn_mask_fusion_kernel( T1* a, T2* b, const int& size, @@ -81,7 +87,7 @@ void _scale_attn_mask_fusion_kernel( // 1) out = exp(a - val) // 2) val = sum(out) template -void _exp_reduce_sum_fusion_kernel( +inline void _exp_reduce_sum_fusion_kernel( T1* a, const int& size, T2* out, @@ -115,7 +121,7 @@ void _exp_reduce_sum_fusion_kernel( // 1) out = a * scale // 2) max = max(out) template -void _mul_reduce_max_fusion_kernel( +inline void _mul_reduce_max_fusion_kernel( const scalar_t* a, const scalar_t& scale, const int& size, @@ -137,30 +143,23 @@ void _mul_reduce_max_fusion_kernel( tmp_max = std::max(tmp_max, tmp1); out[i] = tmp1; } - // max = std::max( - // tmp_max, - // at::vec::vec_reduce_all( - // [](vec::Vectorized& x, at::vec::Vectorized& y) { - // return at::vec::maximum(x, y); - // }, - // vec_tmp_max)); max = std::max(tmp_max, vec_tmp_max.reduce_max()); } template -static scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { +static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { TORCH_CHECK(ptr2 == nullptr); return ptr; } template , int> = 0> -static scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { +static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { return ptr2; } template -void fill_stub(scalar_t* data, scalar_t val, int64_t size) { +inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { using Vec = at::vec::Vectorized; Vec data_vec = Vec(val); int64_t d = 0; @@ -202,26 +201,26 @@ void reshape_attn_mask_to_4d( // TODO: Use at::native::_store instead when it supports Half. 
template -void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { +inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { src.store(dst, size); } template -typename std::enable_if_t, void> +inline typename std::enable_if_t, void> _store(scalar_t* dst, at::vec::Vectorized src) { auto res = at::vec::convert_from_float(src, src); res.store(dst, at::vec::Vectorized::size()); } template -typename std::enable_if_t || std::is_same_v, void> +inline typename std::enable_if_t || std::is_same_v, void> _store(scalar_t* dst, at::vec::Vectorized src) { auto res = at::vec::convert(src); res.store(dst, at::vec::Vectorized::size()); } template -void pad_row_zero( +inline void pad_row_zero( scalar_t* value_ptr, scalar_t* padding_value_ptr, int rows, @@ -258,7 +257,7 @@ void pad_row_zero( } template -void pad_row_128_padding( +inline void pad_row_128_padding( scalar_t* value_ptr, scalar_t* padding_value_ptr, int rows, @@ -298,7 +297,7 @@ void pad_row_128_padding( } template -void pad_col_zero( +inline void pad_col_zero( scalar_t* value_ptr, scalar_t* padding_value_ptr, int rows, @@ -322,7 +321,7 @@ void pad_col_zero( } template -void pad_col_zero_padding( +inline void pad_col_zero_padding( scalar_t* value_ptr, scalar_t* padding_value_ptr, int rows, @@ -352,7 +351,7 @@ void pad_col_zero_padding( 3. max reduce for softmax */ template -void _dequant_mask_max_fusion_kernel( +inline void _dequant_mask_max_fusion_kernel( const int32_t* in, const mask_t* mask_ptr, const int32_t* sum_a_ptr, @@ -414,7 +413,7 @@ void _dequant_mask_max_fusion_kernel( 1. dequant 2. max reduce for softmax */ -void _dequant_max_fusion_kernel( +inline void _dequant_max_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int32_t* sum_b_ptr, @@ -469,7 +468,7 @@ void _dequant_max_fusion_kernel( 3. sum for attention */ template -void _sub_exp_sum_div_quant_sum_fusion_kernel( +inline void _sub_exp_sum_div_quant_sum_fusion_kernel( const float* in, const int64_t& M, const int64_t& N_step, @@ -572,7 +571,7 @@ void _sub_exp_sum_div_quant_sum_fusion_kernel( } template -void _sub_exp_sum_div_quant_fusion_kernel( +inline void _sub_exp_sum_div_quant_fusion_kernel( const float* in, const int64_t& M, const int64_t& N_step, @@ -669,7 +668,7 @@ void _sub_exp_sum_div_quant_fusion_kernel( 2. 
quant */ template -void _dequant_quant_fusion_kernel( +inline void _dequant_quant_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int32_t* sum_b_ptr, @@ -727,7 +726,7 @@ void _dequant_quant_fusion_kernel( } template -void _int_sum_b_contiguous_kernel_helper( +inline void _int_sum_b_contiguous_kernel_helper( const scalar_t* in, int32_t* out, const int& N, @@ -742,14 +741,13 @@ void _int_sum_b_contiguous_kernel_helper( } tmp_sum += vec_tmp_sum.reduce_add(); for (long i = vec_size * (N / vec_size); i < N; i++) { - // for (long i = 0; i < N; i++) { tmp_sum += static_cast(in[i]); } out[0] = tmp_sum * scale; } template -void _int_sum_b_contiguous_kernel( +inline void _int_sum_b_contiguous_kernel( const scalar_t* in, int32_t* out, const int& M, @@ -762,7 +760,7 @@ void _int_sum_b_contiguous_kernel( } template -void _int_sum_a_contiguous_kernel( +inline void _int_sum_a_contiguous_kernel( const scalar_t* in, int32_t* out, const int& M, @@ -791,7 +789,6 @@ void _int_sum_a_contiguous_kernel( _store(out + i, tmp3); } for (long i = vec_size * (M / vec_size); i < M; i++) { - // for (long i = 0; i < M; i++) { auto tmp0 = tmp_in[i]; auto tmp1 = out[i]; auto tmp2 = static_cast(tmp0); @@ -812,7 +809,7 @@ void _int_sum_a_contiguous_kernel( } } -void do_convert_u8_s8( +inline void do_convert_u8_s8( unsigned char* src, signed char* dst, int64_t in_rows, @@ -832,7 +829,6 @@ void do_convert_u8_s8( _store(tmp_dst + c, tmp3, vec_size); } for (int64_t c = vec_size * (in_cols / vec_size); c < in_cols; c++) { - // for (int64_t c = 0; c < in_cols; c++) { auto tmp0 = tmp_src[c]; auto tmp1 = (int16_t) tmp0; auto tmp2 = tmp1 - 128; @@ -843,7 +839,7 @@ void do_convert_u8_s8( } template -void do_transpose( +inline void do_transpose( scalar_t* src, scalar_t* dst, int64_t in_rows, @@ -858,7 +854,7 @@ void do_transpose( } template -void do_copy( +inline void do_copy( scalar_t* src, scalar_t* dst, int64_t in_rows, @@ -873,7 +869,7 @@ void do_copy( } template -void pad_remain_row_col( +inline void pad_remain_row_col( scalar_t* value_ptr, int rows, int cols, @@ -911,7 +907,7 @@ void pad_remain_row_col( } template -void copy_value_with_pad( +inline void copy_value_with_pad( scalar_t* value_ptr, scalar_t* dst_ptr, int rows, @@ -964,64 +960,9 @@ void copy_value_with_pad( } -// thread_local std::unordered_map< -// BrgemmKey, -// std::shared_ptr> cache_brgemm_kernels; - -// thread_local std::unordered_map< -// PackBKey, -// std::shared_ptr> cache_packb_kernels; - -// std::shared_ptr create_or_get_microkernel( -// int64_t M, -// int64_t N, -// int64_t K, -// int64_t batch_size, -// int lda, -// int ldb, -// int ldc, -// dt dt_a, -// dt dt_b, -// dt dt_c) { -// BrgemmKey key_brgemm(M, N, K, batch_size, lda, ldb, ldc, dt_a, dt_b, dt_c); -// auto search = cache_brgemm_kernels.find(key_brgemm); -// if (search != cache_brgemm_kernels.end()) { -// return search->second; -// } else { -// cache_brgemm_kernels.insert( -// {key_brgemm, -// std::make_shared( -// M, N, K, batch_size, lda, ldb, ldc, dt_a, dt_b, dt_c)}); -// return cache_brgemm_kernels[key_brgemm]; -// } -// } - -// std::shared_ptr create_or_get_packb_microkernel( -// int64_t K, -// int64_t N, -// int ld_in, -// int ld_out, -// dt dt_in, -// dt dt_out, -// bool do_trans) { -// PackBKey key_packb(K, N, ld_in, ld_out, dt_in, dt_out); -// auto search = cache_packb_kernels.find(key_packb); -// if (search != cache_packb_kernels.end()) { -// return search->second; -// } else { -// cache_packb_kernels.insert( -// {key_packb, -// std::make_shared( -// K, N, -// do_trans ? 
dnnl::ukernel::pack_type::trans : dnnl::ukernel::pack_type::no_trans, -// ld_in, ld_out, dt_in, dt_out)}); -// return cache_packb_kernels[key_packb]; -// } -// } - // UINT8 - u8u8s32 template -typename std::enable_if_t, void> +inline typename std::enable_if_t, void> sdpa_int8_kernel_impl( const at::Tensor& output, const at::Tensor& q, @@ -1041,10 +982,6 @@ sdpa_int8_kernel_impl( float a_scale, int32_t o_zp, float o_scale) { - // using dt = dnnl::memory::data_type; - // using namespace dnnl; - // using namespace dnnl::ukernel; - // auto starts = duration_cast(system_clock::now().time_since_epoch()).count(); // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) @@ -1055,15 +992,11 @@ sdpa_int8_kernel_impl( at::Tensor key = k.transpose(1, 2); at::Tensor value = v.transpose(1, 2); - const auto accumulate_dtype = at::kFloat; // at::toOpMathType(dtype); + const auto accumulate_dtype = at::kFloat; - using accum_t = float; // at::opmath_type; + using accum_t = float; using Vec = at::vec::Vectorized; - accum_t scaling_factor = - sdp::calculate_scale(query, scale).as_float_unchecked(); - // if (attention_mask.defined() && attention_mask.scalar_type() != ScalarType::Float) { - // attention_mask = attention_mask.to(at::kFloat); - // } + accum_t scaling_factor = calculate_scale(query, scale); int block_64 = 64; // Sizes TORCH_CHECK( @@ -1150,11 +1083,6 @@ sdpa_int8_kernel_impl( int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; - - // dt u8_dt = dt::u8; - // dt s8_dt = dt::s8; - // // dt f32_dt = dt::f32; - // dt s32_dt = dt::s32; auto u8_dt = at::ScalarType::Byte; auto s8_dt = at::ScalarType::Int; auto f32_dt = at::ScalarType::Float; @@ -1174,119 +1102,14 @@ sdpa_int8_kernel_impl( int qk_gemm_K_padding = headSize_mul4 ? 0 : 4 - headSize % 4; int qk_gemm_K = headSize + qk_gemm_K_padding; - // auto && qk_gemm = create_or_get_microkernel( - // qSplitSize, block_64, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // block_64, //ldb - // rndkvSplitSize, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // (*qk_gemm).finalize(); - // (*qk_gemm).generate(); - // size_t qk_scratchpad_size = (*qk_gemm).get_scratchpad_size(); - - // auto && qk_gemm_ktail = create_or_get_microkernel( - // qSplitSize, block_64, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // block_64, //ldb - // rndkvTail, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // // size_t qk_ktail_scratchpad_size = (*qk_gemm_ktail).get_scratchpad_size(); - // (*qk_gemm_ktail).finalize(); - // (*qk_gemm_ktail).generate(); - - // std::shared_ptr qk_gemm_ktail_tail; - // if (kvTail % block_64 != 0) { - // qk_gemm_ktail_tail = create_or_get_microkernel( - // qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // kv_tail_tail_block_size, //ldb - // rndkvTail, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // (*qk_gemm_ktail_tail).finalize(); - // (*qk_gemm_ktail_tail).generate(); - // } - - // auto && qk_gemm_qtail = create_or_get_microkernel( - // qTail, block_64, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K,//headSize_mul4 ? 
qStrideM : qk_gemm_K, // lda - // block_64, //ldb - // rndkvSplitSize, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // // size_t qk_qtail_scratchpad_size = (*qk_gemm_qtail).get_scratchpad_size(); - // (*qk_gemm_qtail).finalize(); - // (*qk_gemm_qtail).generate(); - // auto && qk_gemm_qktail = create_or_get_microkernel( - // qTail, block_64, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // block_64, //ldb - // rndkvTail, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // // size_t qk_qktail_scratchpad_size = (*qk_gemm_qktail).get_scratchpad_size(); - // (*qk_gemm_qktail).finalize(); - // (*qk_gemm_qktail).generate(); - - // std::shared_ptr qk_gemm_qktail_tail; - // if (kvTail % block_64 != 0) { - // qk_gemm_qktail_tail = create_or_get_microkernel( - // qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - // 1, //batch_size - // qk_gemm_K, // lda - // kv_tail_tail_block_size, //ldb - // rndkvTail, //ldc - // u8_dt, //a dtype - // u8_dt, //b dtype - // s32_dt //c dtype - // ); - // (*qk_gemm_qktail_tail).finalize(); - // (*qk_gemm_qktail_tail).generate(); - // } - - // std::vector> A_B_offsets(1); std::vector> A_B_offsets(1); A_B_offsets[0] = std::make_pair(0, 0); - // std::vector> A_B_offsets_batch(kvSlice); std::vector> A_B_offsets_batch(kvSlice); for (auto s=0; s(); int64_t kv_sum_size_per_BH = @@ -1313,9 +1133,8 @@ sdpa_int8_kernel_impl( at::Tensor kv_sum_buf = at::empty( {batchSize, num_head, kv_sum_size_per_BH}, - query.options().dtype(at::kInt)); - int32_t* k_sum_buf_data = kv_sum_buf.data_ptr(); - int32_t* v_sum_buf_data = k_sum_buf_data + batchSize * num_head * kvSize; + query.options().dtype(at::kInt)).zero_(); + int32_t* kv_sum_buf_data = kv_sum_buf.data_ptr(); int64_t kv_reorder_size_per_BH = /* key_t_reorder */ qk_gemm_K * rndkvSize + @@ -1323,183 +1142,74 @@ sdpa_int8_kernel_impl( at::Tensor kv_reorder_buf = at::empty( {batchSize, num_head, kv_reorder_size_per_BH}, - query.options()); + query.options()).zero_(); scalar_t* kv_reorder_buf_data = kv_reorder_buf.data_ptr(); scalar_t* key_reorder_ptr = kv_reorder_buf_data; scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; -// // Create transforms for Key -// auto && brgemm_k_xform = create_or_get_packb_microkernel( -// qk_gemm_K, // K -// block_64, // N -// block_64, // kStrideN, // block_64, // ld_in -// block_64, // ld_out -// u8_dt, // dt_in -// u8_dt, // dt_out -// false // true -// ); -// (*brgemm_k_xform).generate(); -// auto && brgemm_k_xform_tail = create_or_get_packb_microkernel( -// qk_gemm_K, -// block_64, -// block_64, // kStrideN, // block_64, -// block_64, -// u8_dt, -// u8_dt, -// false // true -// ); -// (*brgemm_k_xform_tail).generate(); -// std::shared_ptr brgemm_k_xform_tail_tail; -// if (kvTail % block_64 != 0) { -// brgemm_k_xform_tail_tail = create_or_get_packb_microkernel( -// qk_gemm_K, -// kv_tail_tail_block_size, -// kv_tail_tail_block_size, // kStrideN, // kv_tail_tail_block_size, -// kv_tail_tail_block_size, -// u8_dt, -// u8_dt, -// false // true -// ); -// (*brgemm_k_xform_tail_tail).generate(); -// } - -// // Create transforms for Value -// auto && brgemm_v_xform = create_or_get_packb_microkernel( -// av_gemm_K, -// block_64, -// vStrideN, // block_64, -// block_64, -// u8_dt, -// u8_dt, -// false -// ); -// (*brgemm_v_xform).generate(); -// auto && brgemm_v_xform_tail = create_or_get_packb_microkernel( -// av_gemm_K_tail, -// block_64, -// vStrideN, // block_64, -// block_64, 
-// u8_dt, -// u8_dt, -// false -// ); -// (*brgemm_v_xform_tail).generate(); - - // sum k - if (q_zp != 0) { - at::parallel_for( - 0, batchSize * num_head * kvSize, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0, k = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head, k, kvSize); - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - int32_t* k_sum_ptr = k_sum_buf_data - + i * num_head * kvSize - + j * kvSize + k; - _int_sum_b_contiguous_kernel_helper( - k_data + i * kStrideB + j * kStrideH + k * kStrideN, - k_sum_ptr, - headSize, q_zp); - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head, k, kvSize); - } - }); - } - - // sum v - if (a_zp != 0) { - at::parallel_for( - 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head); - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - int32_t* v_sum_ptr = v_sum_buf_data - + i * num_head * headSize - + j * headSize; - _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, - v_sum_ptr, - headSize, kvSize, vStrideN, a_zp); - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head); - } - }); - } - + // sum k and v at::parallel_for( 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { int64_t i = 0, j = 0; at::native::data_index_init( begin, i, batchSize, j, num_head); - int ompIdx = at::get_thread_num(); - scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; - int32_t offset = 0; - accum_t* qk_data = reinterpret_cast(total_buf_ptr); - offset += kvSlice * qSplitSize * rndkvSplitSize * 4; - accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * av_gemm_K * 4; - scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * qSplitSize * av_gemm_K; - int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndkvSplitSize * 4; - int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndHeadSize * 4; - accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * qk_gemm_K; - // scalar_t* scratchpad_gemm = reinterpret_cast(total_buf_ptr + offset); - // offset += scratchpad_size; - - scalar_t* key_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qk_gemm_K * rndkvSize; - scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); - - uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; - for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable + int32_t* kv_sum_ptr = kv_sum_buf_data + + i * num_head * kv_sum_size_per_BH + + j * kv_sum_size_per_BH; + int32_t* k_sum_ptr = kv_sum_ptr; + int32_t* v_sum_ptr = kv_sum_ptr + kvSize; + if (q_zp == 0) { + fill_stub(k_sum_ptr, static_cast(0), kvSize); + } else { + _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + k_sum_ptr, + kvSize, headSize, kStrideN, q_zp); + } + if (a_zp 
== 0) { + fill_stub(v_sum_ptr, static_cast(0), headSize); + } else { + _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + v_sum_ptr, + headSize, kvSize, vStrideN, a_zp); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); - // pack - for (int64_t n = 0; n < kvSize; n += kvSplitSize) { - // long ss, ee; - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { - bool tail = kvSplitSize - b < block_64; - do_transpose( - // do_copy( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - tail ? kvSplitSize - b : block_64, - headSize, - kStrideN, - block_64); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_64, - qk_gemm_K, - block_64, - block_64 - ); - } - // Pack - // (*brgemm_k_xform).execute( - // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - // B_blocked_xform_u8, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K - // ); - at::native::cpublas::pack( + // packing + at::parallel_for( + 0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, l = 0, n = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, l, kvSlice); + uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + n = l * kvSplitSize; + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { + bool tail = kvSplitSize - b < block_64; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? kvSplitSize - b : block_64, + headSize, + kStrideN, + block_64); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_64, + qk_gemm_K, + block_64, + block_64 + ); + } + at::native::cpublas::pack( qk_gemm_K, // K block_64, // N block_64, // ld_in @@ -1507,18 +1217,14 @@ sdpa_int8_kernel_impl( u8_dt, // dt_in u8_dt, // dt_out B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - // (*brgemm_v_xform).execute( - // v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - // // B_blocked_xform_u8, - // value_reorder_ptr + n * rndHeadSize + - // av_gemm_K * b); - at::native::cpublas::pack( + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + at::native::cpublas::pack( av_gemm_K, block_64, vStrideN, // block_64, @@ -1526,80 +1232,67 @@ sdpa_int8_kernel_impl( u8_dt, u8_dt, v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); - } - } else { - // tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < rndkvTail) { - bool tail = kvTail - b < block_size; - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - tail ? 
kvTail - b : block_size, - headSize, - kStrideN, - block_size); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_size, - qk_gemm_K, - block_size, - block_size - ); - } - // Pack - if (block_size == block_64) { - // (*brgemm_k_xform_tail).execute( - // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - // B_blocked_xform_u8, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K - // ); - at::native::cpublas::pack( + value_reorder_ptr + + i * num_head * kvSlice * av_gemm_K * rndHeadSize + + j * kvSlice * av_gemm_K * rndHeadSize + n * rndHeadSize + + av_gemm_K * b); + } + } else { + // tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < rndkvTail) { + bool tail = kvTail - b < block_size; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? kvTail - b : block_size, + headSize, + kStrideN, + block_size); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_size, + qk_gemm_K, + block_size, + block_size + ); + } + // Pack + if (block_size == block_64) { + at::native::cpublas::pack( qk_gemm_K, block_64, - block_64, // kStrideN, // block_64, + block_64, block_64, u8_dt, u8_dt, B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } else { - // (*brgemm_k_xform_tail_tail).execute( - // // k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - // B_blocked_xform_u8, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K - // ); - at::native::cpublas::pack( + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K); + } else { + at::native::cpublas::pack( qk_gemm_K, kv_tail_tail_block_size, - kv_tail_tail_block_size, // kStrideN, // kv_tail_tail_block_size, + kv_tail_tail_block_size, kv_tail_tail_block_size, u8_dt, u8_dt, B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - // split headSize to block_64, block_64, block_64 ... - // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] - for (int64_t b = 0; b < headSize; b += block_64) { - // (*brgemm_v_xform).execute( - // v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - // // B_blocked_xform_u8, - // value_reorder_ptr + n * rndHeadSize + - // av_gemm_K * b); - at::native::cpublas::pack( + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + // split headSize to block_64, block_64, block_64 ... + // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] 
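+                  // Note: each pack() below copies one block_64-wide column panel of V
+                  // (source leading dimension vStrideN) into a contiguous av_gemm_K x block_64
+                  // buffer, in the packed layout the later u8u8s32 brgemm calls consume.
+                  // K is rounded up to av_gemm_K, a multiple of 4, since the int8 kernels
+                  // accumulate 4 consecutive K elements per 32-bit lane (VNNI-style packing,
+                  // where supported).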
+ for (int64_t b = 0; b < headSize; b += block_64) { + at::native::cpublas::pack( av_gemm_K, block_64, vStrideN, // block_64, @@ -1607,63 +1300,93 @@ sdpa_int8_kernel_impl( u8_dt, u8_dt, v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); - } - } + value_reorder_ptr + + i * num_head * kvSlice * av_gemm_K * rndHeadSize + + j * kvSlice * av_gemm_K * rndHeadSize + n * rndHeadSize + + av_gemm_K * b); } + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); + } + }); + + at::parallel_for( + 0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, k, qSlice); + int ompIdx = at::get_thread_num(); + scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; + int32_t offset = 0; + accum_t* qk_data = reinterpret_cast(total_buf_ptr); + offset += kvSlice * qSplitSize * rndkvSplitSize * 4; + accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * av_gemm_K * 4; + scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * qSplitSize * av_gemm_K; + int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndkvSplitSize * 4; + int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndHeadSize * 4; + accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + + int32_t* kv_sum_ptr = kv_sum_buf_data + + i * num_head * kv_sum_size_per_BH + + j * kv_sum_size_per_BH; + int32_t* k_sum_ptr = kv_sum_ptr; + int32_t* v_sum_ptr = kv_sum_ptr + kvSize; // sdpa core - int32_t* k_sum_ptr = k_sum_buf_data + i * num_head * kvSize + j * kvSize; - int32_t* v_sum_ptr = v_sum_buf_data + i * num_head * headSize + j * headSize; - for (int64_t k = 0; k < qSlice; k++) { - int64_t m = k * qSplitSize; - int64_t qBlockSize = std::min(qSplitSize, qSize - m); - // Initialize sum and max - fill_stub( - sfm_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - a_sum_ptr, static_cast(0), qSplitSize); + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize sum and max + fill_stub( + sfm_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + a_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); + int64_t num_keys = + is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + qk_gemm_K, + qStrideM); + + if (k_zp != 0) { + _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); + } else { fill_stub( - sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); - int64_t num_keys = - is_causal ? 
std::min(m + qBlockSize, kvSize) : kvSize; - copy_value_with_pad( - q_data + i * qStrideB + j * qStrideH + m * qStrideM, - query_t_padding_ptr, - qBlockSize, - headSize, - qBlockSize, - qk_gemm_K, - qStrideM); - - if (k_zp == 0) { - _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, - q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); - } else { - fill_stub( - q_sum_ptr, static_cast(0), qSplitSize); - } - const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; - for (int64_t l = 0; l < rkvSlice; l++) { - int64_t n = l * kvSplitSize; - int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - // Calculate sums for dequant compensation item - if (qBlockSize == qSplitSize) { - // q main - if (n + kvSplitSize < kvSize) { - // k main - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm).set_hw_context(); - // (*qk_gemm).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + q_sum_ptr, static_cast(0), qSplitSize); + } + const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { + int64_t n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + // Calculate sums for dequant compensation item + if (qBlockSize == qSplitSize) { + // q main + if (n + kvSplitSize < kvSize) { + // k main + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( qSplitSize, block_64, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1671,26 +1394,18 @@ sdpa_int8_kernel_impl( rndkvSplitSize, //ldc, false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } - } else { - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - if (block_size == block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_ktail).set_hw_context(); - // (*qk_gemm_ktail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } + } else { + auto block_size = kvTail >= block_64 ? 
block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + at::native::cpublas::brgemm( qSplitSize, block_64, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1698,21 +1413,13 @@ sdpa_int8_kernel_impl( rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } else { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_ktail_tail).set_hw_context(); - // (*qk_gemm_ktail_tail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } else { + at::native::cpublas::brgemm( qSplitSize, kv_tail_tail_block_size, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1720,56 +1427,40 @@ sdpa_int8_kernel_impl( rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } - } else { - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_qtail).set_hw_context(); - // (*qk_gemm_qtail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } + } else { + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( qTail, block_64, qk_gemm_K, 1, //batch_size - qk_gemm_K,//headSize_mul4 ? qStrideM : qk_gemm_K, // lda + qk_gemm_K,// lda block_64, //ldb rndkvSplitSize, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } - } else { - // k tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - if (block_size == block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_qktail).set_hw_context(); - // (*qk_gemm_qktail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } + } else { + // k tail + auto block_size = kvTail >= block_64 ? 
block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + at::native::cpublas::brgemm( qTail, block_64, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1777,21 +1468,13 @@ sdpa_int8_kernel_impl( rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } else { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*qk_gemm_qktail_tail).set_hw_context(); - // (*qk_gemm_qktail_tail).execute( - // query_t_padding_ptr, - // key_reorder_ptr + n * qk_gemm_K + - // b * qk_gemm_K, - // A_B_offsets, - // qk_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } else { + at::native::cpublas::brgemm( qSplitSize, kv_tail_tail_block_size, qk_gemm_K, 1, //batch_size qk_gemm_K, // lda @@ -1799,108 +1482,99 @@ sdpa_int8_kernel_impl( rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, + key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K + + b * qk_gemm_K, qk_s32_data + b, A_B_offsets); - } - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } } - - // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 - int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; - accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; - if (has_attn_mask) { - mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); - _dequant_mask_max_fusion_kernel( - qk_s32_data, //in - mask_data_offset, //mask_ptr - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvBlockSize, //ldi - mStrideM, //ldm - rndkvSplitSize,//kvBlockSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } else { - _dequant_max_fusion_kernel( - qk_s32_data, //in - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvBlockSize, //ldi - rndkvSplitSize,//kvBlockSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } } - // sub max, exp, sum reduce, div sum for softmax - // and quant - // and sum for attention - if (v_zp == 0) { - _sub_exp_sum_div_quant_fusion_kernel( - qk_data, //in + // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 + int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; + accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; + if (has_attn_mask) { + mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); + _dequant_mask_max_fusion_kernel( + qk_s32_data, //in + mask_data_offset, //mask_ptr + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlices - qSplitSize * rndkvSplitSize, //ldi - qSplitSize * av_gemm_K, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr //sfm_sum_ptr + kvBlockSize, //N + rndkvBlockSize, //ldi + mStrideM, //ldm + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr ); } else { - _sub_exp_sum_div_quant_sum_fusion_kernel( - qk_data, //in + _dequant_max_fusion_kernel( + qk_s32_data, //in + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlice - qSplitSize * rndkvSplitSize, //ldi - qSplitSize * av_gemm_K, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - v_zp, // zp_b=beta2 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr, //sfm_sum_ptr - a_sum_ptr //a_sum_ptr - ); + kvBlockSize, //N + rndkvBlockSize, //ldi + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); } - - // Calculate Softmax(q @ k.T) @ v - for (int64_t b = 0; b < headSize; b += block_64) { - // dnnl::ukernel::brgemm::release_hw_context(); - // (*av_gemm_batch).set_hw_context(); - // (*av_gemm_batch).execute( - // qk_reduced_data, - // value_reorder_ptr + b * av_gemm_K, - // A_B_offsets_batch, - // dst_s32_data + b, - // scratchpad_gemm); - at::native::cpublas::brgemm( + } + // sub max, exp, sum reduce, div sum for softmax + // and quant + // and sum for attention + if (v_zp == 0) { + _sub_exp_sum_div_quant_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlices + qSplitSize * rndkvSplitSize, //ldi + qSplitSize * av_gemm_K, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr //sfm_sum_ptr + ); + } else { + _sub_exp_sum_div_quant_sum_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlice + qSplitSize * rndkvSplitSize, //ldi + qSplitSize * av_gemm_K, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + v_zp, // zp_b=beta2 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr, //sfm_sum_ptr + a_sum_ptr //a_sum_ptr + ); + } + // Calculate Softmax(q @ k.T) @ v + for (int64_t b = 0; b < headSize; b += block_64) { + at::native::cpublas::brgemm( qSplitSize, block_64, av_gemm_K, kvSlice, //batch_size av_gemm_K, // lda @@ -1908,33 +1582,33 @@ sdpa_int8_kernel_impl( rndHeadSize, //ldc false, qk_reduced_data, - value_reorder_ptr + b * av_gemm_K, + value_reorder_ptr + + i * num_head * kvSlice * av_gemm_K * rndHeadSize + + j * kvSlice * av_gemm_K * rndHeadSize + b * av_gemm_K, dst_s32_data + b, A_B_offsets_batch); - } - - // 
After the last gemm, - // do dequant compensation, quant and convert from s32 to int8 - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, //sum_a_ptr - v_sum_ptr, //sum_b_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); } + + // After the last gemm, + // do dequant compensation, quant and convert from s32 to int8 + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head); + at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); } }); // Once all computations are done, need to release HW context. - // brgemm::release_hw_context(); at::native::cpublas::brgemm_release(); } @@ -2040,60 +1714,6 @@ void sdpa_int8_kernel( } } -// at::Tensor sdpa_int8_math_impl( -// const at::Tensor& query_, -// const at::Tensor& key, -// const at::Tensor& value, -// double dropout_p, -// bool is_causal, -// at::Tensor& attn_mask_, -// double scale, -// int32_t q_zp, -// float q_scale, -// int32_t k_zp, -// float k_scale, -// int32_t v_zp, -// float v_scale, -// int32_t a_zp, -// float a_scale, -// int32_t o_zp, -// float o_scale) { -// // dequant q/k/v -// auto q = (query_.to(at::kFloat) - q_zp) * q_scale; -// auto k = (key.to(at::kFloat) - k_zp) * k_scale; -// auto v = (value.to(at::kFloat) - v_zp) * v_scale; -// auto attn_mask = attn_mask_; -// if (attn_mask.defined()) { -// *attn_mask = (*attn_mask).to(at::kFloat); -// } -// // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math -// bool is_negative_scaling = scale.defined() && scale < 0.0; -// const auto scaling_factor = sdp::calculate_scale(q, is_negative_scaling ? std::abs(scale) : scale).sqrt(); -// q = q * (is_negative_scaling ? 
c10::SymFloat(0.0) - scaling_factor: scaling_factor); -// auto attn = at::matmul(q, k.transpose(-2, -1) * scaling_factor); -// if (attn_mask.defined()) { -// if (at::areAnyTensorSubclassLike({attn, *attn_mask})) { -// attn = attn.add(*attn_mask); -// } else { -// attn.add_(*attn_mask); -// } -// } -// attn = at::softmax(attn, -1); -// // quant attn -// attn = at::clamp_max( -// at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255 -// ); -// // dequant attn -// attn = (attn - a_zp) * a_scale; -// auto output = at::matmul(attn, v); -// // quant output -// output = at::clamp_max( -// at::clamp_min(at::round(output / o_scale) + o_zp, 0), 255 -// ).to(at::kByte); -// return output; -// } - - at::Tensor _scaled_dot_product_int8_cpu( const at::Tensor& query, const at::Tensor& key, @@ -2135,16 +1755,6 @@ at::Tensor _scaled_dot_product_int8_cpu( (attn_mask.dim() == 2 || attn_mask.dim() == 4), "_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); - // fallback math path - // at::Tensor output = sdpa_int8_math_impl(query, key, value, - // dropout_p, is_causal, attn_mask, scale, - // q_zp, q_scale, - // k_zp, k_scale, - // v_zp, v_scale, - // a_zp, a_scale, - // o_zp, o_scale); - - // fused sdpa int8 impl at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); sdpa_int8_kernel(output, query, key, value, dropout_p, is_causal, attn_mask, scale, diff --git a/torchao/ops.py b/torchao/ops.py index 9dcbdd7886..732e0083a5 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -168,7 +168,7 @@ def scaled_dot_product_int8( attn_mask: Tensor = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: float = 1.0, + scale: float = 0.0, q_zp: int = 0, q_scale: float = 1.0, k_zp: int = 0, @@ -197,7 +197,7 @@ def _( attn_mask: Tensor = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: float = 1.0, + scale: float = 0.0, q_zp: int = 0, q_scale: float = 1.0, k_zp: int = 0, From 52ddb9bcbc45521d47404e8862f10331796a8ef8 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 7 Jan 2025 01:08:46 -0500 Subject: [PATCH 06/36] update int8 sdpa cpu --- setup.py | 9 + test/quantization/test_sfdp_int8_fx_pass.py | 2 +- torchao/csrc/cpu/sdpa.cpp | 838 +++++++++++++++++--- torchao/quantization/sfdp_int8_fx_pass.py | 42 +- 4 files changed, 763 insertions(+), 128 deletions(-) diff --git a/setup.py b/setup.py index cdbe34ee84..396f869e93 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ def read_version(file_path="version.txt"): version_suffix = f"+git{get_git_commit_id()}" use_cpp = os.getenv("USE_CPP") +use_cpp_avx512 = os.getenv('USE_AVX512', 1) import platform @@ -291,6 +292,14 @@ def get_extensions(): ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"] ) + if use_cpp_avx512: + extra_compile_args["cxx"].extend([ + "-DCPU_CAPABILITY_AVX512", + "-march=native", + "-mfma", + "-fopenmp", + ]) + if debug_mode: extra_compile_args["cxx"].append("-g") if "nvcc" in extra_compile_args: diff --git a/test/quantization/test_sfdp_int8_fx_pass.py b/test/quantization/test_sfdp_int8_fx_pass.py index 3e17d9ce81..9c5f362ff8 100644 --- a/test/quantization/test_sfdp_int8_fx_pass.py +++ b/test/quantization/test_sfdp_int8_fx_pass.py @@ -160,7 +160,7 @@ def _check_common( def _test_sdpa_rewriter_int8_1_to_4(self): # pattern is different for bs=1 for dtype, has_mask, bs in itertools.product( - [torch.float32], [True, False], [56, 1] + [torch.float32, torch.bfloat16], [True, False], [56, 1] ): mod = SelfAttnLikeModule( input_dim=64 * 16, diff --git a/torchao/csrc/cpu/sdpa.cpp 
b/torchao/csrc/cpu/sdpa.cpp index 3357608db5..7ef2e3e471 100644 --- a/torchao/csrc/cpu/sdpa.cpp +++ b/torchao/csrc/cpu/sdpa.cpp @@ -963,7 +963,7 @@ inline void copy_value_with_pad( // UINT8 - u8u8s32 template inline typename std::enable_if_t, void> -sdpa_int8_kernel_impl( +sdpa_int8_kernel_large_impl( const at::Tensor& output, const at::Tensor& q, const at::Tensor& k, @@ -1102,18 +1102,601 @@ sdpa_int8_kernel_impl( int qk_gemm_K_padding = headSize_mul4 ? 0 : 4 - headSize % 4; int qk_gemm_K = headSize + qk_gemm_K_padding; - std::vector> A_B_offsets(1); - A_B_offsets[0] = std::make_pair(0, 0); + int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; + int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; - std::vector> A_B_offsets_batch(kvSlice); - for (auto s=0; s(); + + at::parallel_for( + 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head); + int ompIdx = at::get_thread_num(); + scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; + int32_t offset = 0; + accum_t* qk_data = reinterpret_cast(total_buf_ptr); + offset += kvSlice * qSplitSize * rndkvSplitSize * 4; + accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * av_gemm_K * 4; + scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * qk_reduce_strideL; + int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndkvSplitSize * 4; + int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndHeadSize * 4; + accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * qk_gemm_K; + + int32_t* k_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += kvSize * 4; + int32_t* v_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += headSize * 4; + scalar_t* key_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qk_gemm_K * rndkvSize; + scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + + uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; + + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + + // sum k and v + if (q_zp == 0) { + fill_stub(k_sum_ptr, static_cast(0), kvSize); + } else { + _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + k_sum_ptr, + kvSize, headSize, kStrideN, q_zp); + } + if (a_zp == 0) { + fill_stub(v_sum_ptr, static_cast(0), headSize); + } else { + _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + v_sum_ptr, + headSize, kvSize, vStrideN, a_zp); + } + + // pack + for (int64_t n = 0; n < kvSize; n += kvSplitSize) { + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { + bool tail = kvSplitSize - b < block_64; + do_transpose( + // do_copy( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? 
kvSplitSize - b : block_64, + headSize, + kStrideN, + block_64); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_64, + qk_gemm_K, + block_64, + block_64 + ); + } + // Pack + at::native::cpublas::pack( + qk_gemm_K, // K + block_64, // N + block_64, // ld_in + block_64, // ld_out + u8_dt, // dt_in + u8_dt, // dt_out + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); + } + } else { + // tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < rndkvTail) { + bool tail = kvTail - b < block_size; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + tail ? kvTail - b : block_size, + headSize, + kStrideN, + block_size); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_size, + qk_gemm_K, + block_size, + block_size + ); + } + // Pack + if (block_size == block_64) { + at::native::cpublas::pack( + qk_gemm_K, + block_64, + block_64, + block_64, + u8_dt, + u8_dt, + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } else { + at::native::cpublas::pack( + qk_gemm_K, + kv_tail_tail_block_size, + kv_tail_tail_block_size, + kv_tail_tail_block_size, + u8_dt, + u8_dt, + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + // split headSize to block_64, block_64, block_64 ... + // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] + for (int64_t b = 0; b < headSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); + } + } + } + + // sdpa core + for (int64_t k = 0; k < qSlice; k++) { + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize sum and max + fill_stub( + sfm_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + a_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); + int64_t num_keys = + is_causal ? 
std::min(m + qBlockSize, kvSize) : kvSize; + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + qk_gemm_K, + qStrideM); + + if (k_zp != 0) { + _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); + } else { + fill_stub( + q_sum_ptr, static_cast(0), qSplitSize); + } + const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { + int64_t n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + // Calculate sums for dequant compensation item + if (qBlockSize == qSplitSize) { + // q main + if (n + kvSplitSize < kvSize) { + // k main + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc, + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + } + } else { + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + qk_gemm_K, // lda + block_64, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + } else { + at::native::cpublas::brgemm( + qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + qk_gemm_K, // lda + kv_tail_tail_block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } else { + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( + qTail, block_64, qk_gemm_K, + qk_gemm_K,// lda + block_64, //ldb + rndkvSplitSize, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + } + } else { + // k tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + if (block_size == block_64) { + at::native::cpublas::brgemm( + qTail, block_64, qk_gemm_K, + qk_gemm_K, // lda + block_64, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + } else { + at::native::cpublas::brgemm( + qSplitSize, kv_tail_tail_block_size, qk_gemm_K, + qk_gemm_K, // lda + kv_tail_tail_block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + } + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } + + // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 + int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; + accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; + if (has_attn_mask) { + mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); + _dequant_mask_max_fusion_kernel( + qk_s32_data, //in + mask_data_offset, //mask_ptr + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + mStrideM, //ldm + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } else { + _dequant_max_fusion_kernel( + qk_s32_data, //in + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } + } + // sub max, exp, sum reduce, div sum for softmax + // and quant + // and sum for attention + if (v_zp == 0) { + _sub_exp_sum_div_quant_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlices + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr //sfm_sum_ptr + ); + } else { + _sub_exp_sum_div_quant_sum_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlice + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + v_zp, // zp_b=beta2 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr, //sfm_sum_ptr + a_sum_ptr //a_sum_ptr + ); + } + // Calculate Softmax(q @ k.T) @ v + for (int64_t b = 0; b < headSize; b += block_64) { + auto value_reorder_b = value_reorder_ptr + b * av_gemm_K; + auto dst_s32_b = dst_s32_data + b; + for (int64_t s = 0; s < kvSlice; s++) { + at::native::cpublas::brgemm( + qSplitSize, block_64, av_gemm_K, + av_gemm_K, // lda + rndHeadSize, //block_64, //ldb + rndHeadSize, //ldc + s != 0, + qk_reduced_data + s * qk_reduce_strideL, + value_reorder_b + s * v_reorder_strideL, + dst_s32_b); + } + } + + // After the last gemm, + // do dequant compensation, quant and convert from s32 to int8 + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); + // Once all computations are done, need to release HW context. 
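+  // brgemm_release() is expected to reset the per-thread brgemm hardware state
+  // (on AMX-capable CPUs, the tile configuration programmed by the brgemm calls above),
+  // mirroring the commented-out dnnl brgemm::release_hw_context() in the earlier ukernel path.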
+ at::native::cpublas::brgemm_release(); +} + +// UINT8 - u8u8s32 +template +inline typename std::enable_if_t, void> +sdpa_int8_kernel_small_impl( + const at::Tensor& output, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + double dropout_p, + bool is_causal, + at::Tensor& attention_mask, + double scale, + int32_t q_zp, + float q_scale, + int32_t k_zp, + float k_scale, + int32_t v_zp, + float v_scale, + int32_t a_zp, + float a_scale, + int32_t o_zp, + float o_scale) { + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + at::Tensor query = q.transpose(1, 2); + at::Tensor key = k.transpose(1, 2); + at::Tensor value = v.transpose(1, 2); + + const auto accumulate_dtype = at::kFloat; + + using accum_t = float; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = calculate_scale(query, scale); + int block_64 = 64; + // Sizes + TORCH_CHECK( + (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); + TORCH_CHECK( + kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); + + int64_t batchSize = query.size(0); + int64_t qSize = query.size(1); + int64_t kvSize = value.size(1); + int64_t num_head = query.size(2); + int64_t headSize = query.size(3); + + + bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); + if (has_attn_mask) { + reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); } + // Strides + int64_t qStrideB = query.stride(0); + int64_t qStrideM = query.stride(1); + int64_t qStrideH = query.stride(2); + int64_t kStrideB = key.stride(0); + int64_t kStrideN = key.stride(1); + int64_t kStrideH = key.stride(2); + int64_t vStrideB = value.stride(0); + int64_t vStrideN = value.stride(1); + int64_t vStrideH = value.stride(2); + int64_t oStrideB = output.stride(0); + int64_t oStrideM = output.stride(1); + int64_t oStrideH = output.stride(2); + int64_t mStrideB = + (attention_mask.defined() && attention_mask.size(0) > 1) + ? attention_mask.stride(0) + : 0; + int64_t mStrideH = + (attention_mask.defined() && attention_mask.size(1) > 1) + ? attention_mask.stride(1) + : 0; + int64_t mStrideM = + (attention_mask.defined() && attention_mask.size(2) > 1) + ? attention_mask.stride(2) + : 0; + int64_t mStrideN = + (attention_mask.defined() && attention_mask.size(3) > 1) + ? attention_mask.stride(3) + : 0; + + int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; + int64_t kvSplitSize = kv_split_size > kvSize ? 
kvSize : kv_split_size; + int64_t qSlice = (qSize - 1) / qSplitSize + 1; + int64_t qTail = (qSize - 1) % qSplitSize + 1; + int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + int64_t num_thread = at::get_num_threads(); + + int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; + // one of 16, 32, 48, 64 + auto select_tail_tail_block_size = [](int64_t size) -> int64_t { + if (size == 0) { + return 0; + } else if (size <= 16) { + return 16; + } else if (size <= 32) { + return 32; + } else if (size <= 48) { + return 48; + } else { + return 64; + } + }; + int64_t kv_tail_tail_block_size = select_tail_tail_block_size(kvTail % block_64); + int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; + int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; + int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; + + bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; + int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; + // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int av_gemm_K = kvSplitSize + av_gemm_K_padding; + bool av_gemm_K_tail_mul4 = kvTail % 4 == 0; + int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; + int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; + + auto u8_dt = at::ScalarType::Byte; + auto s8_dt = at::ScalarType::Int; + auto f32_dt = at::ScalarType::Float; + + // Data ptrs + scalar_t* q_data = query.data_ptr(); + scalar_t* k_data = key.data_ptr(); + scalar_t* v_data = value.data_ptr(); + mask_t* mask_data = attention_mask.defined() + ? attention_mask.data_ptr() + : nullptr; + scalar_t* out_data = output.data_ptr(); + + // Create tpp kernels for Query @ Key + bool headSize_mul4 = headSize % 4 == 0; + // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int qk_gemm_K_padding = headSize_mul4 ? 
0 : 4 - headSize % 4; + int qk_gemm_K = headSize + qk_gemm_K_padding; + + int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; + int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; + int64_t total_size_uint8_per_thread = /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + /* qk_local */ kvSlice * av_gemm_K * 4 + - /* qk_reduce */ kvSlice * qSplitSize * av_gemm_K + + /* qk_reduce */ kvSlice * qk_reduce_strideL + /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + /* dst_s32 */ qSplitSize * rndHeadSize * 4 + /* softmax_sum */ qSplitSize * 4 + @@ -1138,7 +1721,7 @@ sdpa_int8_kernel_impl( int64_t kv_reorder_size_per_BH = /* key_t_reorder */ qk_gemm_K * rndkvSize + - /* value_t_reorder */ kvSlice * av_gemm_K * rndHeadSize; + /* value_t_reorder */ kvSlice * v_reorder_strideL; at::Tensor kv_reorder_buf = at::empty( {batchSize, num_head, kv_reorder_size_per_BH}, @@ -1189,6 +1772,11 @@ sdpa_int8_kernel_impl( for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable n = l * kvSplitSize; + auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K; + auto v_reorder = value_reorder_ptr + + i * num_head * kvSlice * v_reorder_strideL + + j * kvSlice * v_reorder_strideL + n * rndHeadSize; if (n + kvSplitSize < kvSize) { for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { bool tail = kvSplitSize - b < block_64; @@ -1217,9 +1805,7 @@ sdpa_int8_kernel_impl( u8_dt, // dt_in u8_dt, // dt_out B_blocked_xform_u8, - key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K + - b * qk_gemm_K); + k_reorder + b * qk_gemm_K); } // split headSize to block_64, block_64, block_64 ... // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] @@ -1232,10 +1818,7 @@ sdpa_int8_kernel_impl( u8_dt, u8_dt, v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + - i * num_head * kvSlice * av_gemm_K * rndHeadSize + - j * kvSlice * av_gemm_K * rndHeadSize + n * rndHeadSize + - av_gemm_K * b); + v_reorder + av_gemm_K * b); } } else { // tail @@ -1270,9 +1853,7 @@ sdpa_int8_kernel_impl( u8_dt, u8_dt, B_blocked_xform_u8, - key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K + - b * qk_gemm_K); + k_reorder + b * qk_gemm_K); } else { at::native::cpublas::pack( qk_gemm_K, @@ -1282,9 +1863,7 @@ sdpa_int8_kernel_impl( u8_dt, u8_dt, B_blocked_xform_u8, - key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K + - b * qk_gemm_K); + k_reorder + b * qk_gemm_K); } b += block_size; block_size = (kvTail - b) >= block_64 ? 
block_64 : kv_tail_tail_block_size; @@ -1300,10 +1879,7 @@ sdpa_int8_kernel_impl( u8_dt, u8_dt, v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + - i * num_head * kvSlice * av_gemm_K * rndHeadSize + - j * kvSlice * av_gemm_K * rndHeadSize + n * rndHeadSize + - av_gemm_K * b); + v_reorder + av_gemm_K * b); } } // Move to the next query @@ -1324,7 +1900,7 @@ sdpa_int8_kernel_impl( accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); offset += kvSlice * av_gemm_K * 4; scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * qSplitSize * av_gemm_K; + offset += kvSlice * qk_reduce_strideL; int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); offset += qSplitSize * rndkvSplitSize * 4; int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); @@ -1380,6 +1956,8 @@ sdpa_int8_kernel_impl( for (int64_t l = 0; l < rkvSlice; l++) { int64_t n = l * kvSplitSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K; // Calculate sums for dequant compensation item if (qBlockSize == qSplitSize) { // q main @@ -1388,17 +1966,13 @@ sdpa_int8_kernel_impl( for (int64_t b = 0; b < kvSplitSize; b += block_64) { at::native::cpublas::brgemm( qSplitSize, block_64, qk_gemm_K, - 1, //batch_size qk_gemm_K, // lda block_64, //ldb rndkvSplitSize, //ldc, false, query_t_padding_ptr, - key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b, - A_B_offsets); + k_reorder + b * qk_gemm_K, + qk_s32_data + b); } } else { auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; @@ -1407,31 +1981,23 @@ sdpa_int8_kernel_impl( if (block_size == block_64) { at::native::cpublas::brgemm( qSplitSize, block_64, qk_gemm_K, - 1, //batch_size qk_gemm_K, // lda block_64, //ldb rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b, - A_B_offsets); + k_reorder + b * qk_gemm_K, + qk_s32_data + b); } else { at::native::cpublas::brgemm( qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - 1, //batch_size qk_gemm_K, // lda kv_tail_tail_block_size, //ldb rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b, - A_B_offsets); + k_reorder + b * qk_gemm_K, + qk_s32_data + b); } b += block_size; block_size = (kvTail - b) >= block_64 ? 
block_64 : kv_tail_tail_block_size; @@ -1442,17 +2008,13 @@ sdpa_int8_kernel_impl( for (int64_t b = 0; b < kvSplitSize; b += block_64) { at::native::cpublas::brgemm( qTail, block_64, qk_gemm_K, - 1, //batch_size qk_gemm_K,// lda block_64, //ldb rndkvSplitSize, //ldc false, query_t_padding_ptr, - key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b, - A_B_offsets); + k_reorder + b * qk_gemm_K, + qk_s32_data + b); } } else { // k tail @@ -1462,34 +2024,26 @@ sdpa_int8_kernel_impl( if (block_size == block_64) { at::native::cpublas::brgemm( qTail, block_64, qk_gemm_K, - 1, //batch_size qk_gemm_K, // lda block_64, //ldb rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b, - A_B_offsets); + k_reorder + b * qk_gemm_K, + qk_s32_data + b); } else { at::native::cpublas::brgemm( qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - 1, //batch_size qk_gemm_K, // lda kv_tail_tail_block_size, //ldb rndkvTail, //ldc false, query_t_padding_ptr, - key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b, - A_B_offsets); + k_reorder + b * qk_gemm_K, + qk_s32_data + b); } - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } } } @@ -1527,7 +2081,7 @@ sdpa_int8_kernel_impl( q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha qk_block_data, //out sfm_max_ptr // sfm_max_ptr - ); + ); } } // sub max, exp, sum reduce, div sum for softmax @@ -1540,7 +2094,7 @@ sdpa_int8_kernel_impl( kvSplitSize, //N_step rkvSlice, //NSlices qSplitSize * rndkvSplitSize, //ldi - qSplitSize * av_gemm_K, //ldo + qk_reduce_strideL, //ldo kvSize, //kvSize rndkvSplitSize, //rndkvSplitSize av_gemm_K, //av_gemm_K @@ -1558,7 +2112,7 @@ sdpa_int8_kernel_impl( kvSplitSize, //N_step rkvSlice, //NSlice qSplitSize * rndkvSplitSize, //ldi - qSplitSize * av_gemm_K, //ldo + qk_reduce_strideL, //ldo kvSize, //kvSize rndkvSplitSize, //rndkvSplitSize av_gemm_K, //av_gemm_K @@ -1573,20 +2127,23 @@ sdpa_int8_kernel_impl( ); } // Calculate Softmax(q @ k.T) @ v + auto v_reorder = value_reorder_ptr + + i * num_head * kvSlice * v_reorder_strideL + + j * kvSlice * v_reorder_strideL; for (int64_t b = 0; b < headSize; b += block_64) { - at::native::cpublas::brgemm( + auto value_reorder_b = v_reorder + b * av_gemm_K; + auto dst_s32_b = dst_s32_data + b; + for (int64_t s = 0; s < kvSlice; s++) { + at::native::cpublas::brgemm( qSplitSize, block_64, av_gemm_K, - kvSlice, //batch_size av_gemm_K, // lda rndHeadSize, //block_64, //ldb rndHeadSize, //ldc - false, - qk_reduced_data, - value_reorder_ptr + - i * num_head * kvSlice * av_gemm_K * rndHeadSize + - j * kvSlice * av_gemm_K * rndHeadSize + b * av_gemm_K, - dst_s32_data + b, - A_B_offsets_batch); + s != 0, + qk_reduced_data + s * qk_reduce_strideL, + value_reorder_b + s * v_reorder_strideL, + dst_s32_b); + } } // After the last gemm, @@ -1651,39 +2208,73 @@ void sdpa_int8_kernel( int64_t q_seq_len = query.size(2); TORCH_CHECK(query.scalar_type() == c10::kByte); - if (!attn_mask.defined()) { - if (q_seq_len >= 768) { - sdpa_int8_kernel_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - 
a_zp, a_scale, - o_zp, o_scale); - } else if (q_seq_len >= 192) { - sdpa_int8_kernel_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + bool use_one_parallel_loop = batchSize * num_head % 40 == 0; + if (use_one_parallel_loop) { + if (!attn_mask.defined()) { + if (q_seq_len >= 768) { + sdpa_int8_kernel_large_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_seq_len >= 192) { + sdpa_int8_kernel_large_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_large_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } } else { - sdpa_int8_kernel_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { + if (q_seq_len >= 768) { + sdpa_int8_kernel_large_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_seq_len >= 192) { + sdpa_int8_kernel_large_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_large_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + }); } } else { - AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { + if (!attn_mask.defined()) { if (q_seq_len >= 768) { - sdpa_int8_kernel_impl( + sdpa_int8_kernel_small_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -1692,7 +2283,7 @@ void sdpa_int8_kernel( a_zp, a_scale, o_zp, o_scale); } else if (q_seq_len >= 192) { - sdpa_int8_kernel_impl( + sdpa_int8_kernel_small_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -1701,7 +2292,7 @@ void sdpa_int8_kernel( a_zp, a_scale, o_zp, o_scale); } else { - sdpa_int8_kernel_impl( + sdpa_int8_kernel_small_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -1710,7 +2301,38 @@ void sdpa_int8_kernel( a_zp, a_scale, o_zp, o_scale); } - }); + } else { + AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { + if (q_seq_len >= 768) { + sdpa_int8_kernel_small_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_seq_len >= 192) { + sdpa_int8_kernel_small_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_small_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + }); + } } } diff --git a/torchao/quantization/sfdp_int8_fx_pass.py 
b/torchao/quantization/sfdp_int8_fx_pass.py index 1dedd60ee5..9ff7dc7e8c 100644 --- a/torchao/quantization/sfdp_int8_fx_pass.py +++ b/torchao/quantization/sfdp_int8_fx_pass.py @@ -134,16 +134,19 @@ def _sfdp_pattern_int8_2( # int8-mix-reduce QUANTIZED SDPA with mask q = query.permute([0, 2, 1, 3]) q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - q, float(q_scale), int(q_zp), 0, 255, torch.uint8 - ).to(torch.float16) + q, float(q_scale), int(q_zp), 0, 255, + # torch.uint8).to(torch.bfloat16) + torch.uint8, out_dtype=torch.bfloat16).to(torch.bfloat16) k = key.permute([0, 2, 1, 3]).transpose(-2, -1) k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - k, float(k_scale), int(k_zp), 0, 255, torch.uint8 - ).to(torch.float16) + k, float(k_scale), int(k_zp), 0, 255, + # torch.uint8).to(torch.bfloat16) + torch.uint8, out_dtype=torch.bfloat16).to(torch.bfloat16) v = value.permute([0, 2, 1, 3]) v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - v, float(v_scale), int(v_zp), 0, 255, torch.uint8 - ).to(torch.float16) + v, float(v_scale), int(v_zp), 0, 255, + # torch.uint8).to(torch.bfloat16) + torch.uint8, out_dtype=torch.bfloat16).to(torch.bfloat16) a = torch.nn.functional.dropout( (torch.matmul(q, k).div(inv_scale) + attn_mask).softmax(dim=-1), dropout, @@ -152,8 +155,9 @@ def _sfdp_pattern_int8_2( a, float(a_scale), int(a_zp), 0, 255, torch.uint8 ) a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 - ).to(torch.float16) + qa, float(a_scale), int(a_zp), 0, 255, + torch.uint8).to(torch.bfloat16) + # torch.uint8, out_dtype=torch.bfloat16) o = a.matmul(v) o = o.permute(0, 2, 1, 3).contiguous() return torch.ops.quantized_decomposed.quantize_per_tensor.default( @@ -310,16 +314,16 @@ def _sfdp_pattern_int8_4( # int8-mix-reduce QUANTIZED SDPA without mask q = query.permute([0, 2, 1, 3]) q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - q, float(q_scale), int(q_zp), 0, 255, torch.uint8 - ).to(torch.float16) + q, float(q_scale), int(q_zp), 0, 255, + torch.uint8, out_dtype=torch.bfloat16) k = key.permute([0, 2, 1, 3]).transpose(-2, -1) k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - k, float(k_scale), int(k_zp), 0, 255, torch.uint8 - ).to(torch.float16) + k, float(k_scale), int(k_zp), 0, 255, + torch.uint8, out_dtype=torch.bfloat16) v = value.permute([0, 2, 1, 3]) v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - v, float(v_scale), int(v_zp), 0, 255, torch.uint8 - ).to(torch.float16) + v, float(v_scale), int(v_zp), 0, 255, + torch.uint8, out_dtype=torch.bfloat16) a = torch.nn.functional.dropout( torch.matmul(q, k).div(inv_scale).softmax(dim=-1), dropout, @@ -328,8 +332,8 @@ def _sfdp_pattern_int8_4( a, float(a_scale), int(a_zp), 0, 255, torch.uint8 ) a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 - ).to(torch.float16) + qa, float(a_scale), int(a_zp), 0, 255, + torch.uint8, out_dtype=torch.bfloat16) o = a.matmul(v) o = o.permute(0, 2, 1, 3).contiguous() return torch.ops.quantized_decomposed.quantize_per_tensor.default( @@ -450,11 +454,11 @@ def _gen_sfdp_patterns_int8(): torch.empty, (1, 4, 8, 16), device=device, requires_grad=True ) m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) - for dtype in [torch.float, torch.half]: + for dtype in [torch.float, torch.bfloat16]: g_u8 = functools.partial(g_inp, dtype=torch.uint8, requires_grad=False) 
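Note on the dequantize_per_tensor / quantize_per_tensor nodes captured above: they follow the usual affine scheme, with out_dtype controlling the dequantized dtype (bfloat16 for the reduced-precision patterns). A minimal sketch of that arithmetic, assuming the standard quantized_decomposed per-tensor semantics; the helper names below are illustrative, not torchao APIs:

import torch

def dequantize_per_tensor(x_u8, scale, zero_point, out_dtype=torch.float32):
    # Affine dequantization: subtract the zero point, then apply the scale.
    return (x_u8.to(torch.float32) - zero_point).mul(scale).to(out_dtype)

def quantize_per_tensor(x, scale, zero_point):
    # Affine quantization back to uint8 with the [0, 255] clamp used in the patterns.
    return torch.clamp(torch.round(x.to(torch.float32) / scale) + zero_point, 0, 255).to(torch.uint8)

q_u8 = torch.randint(0, 256, (1, 16, 384, 64), dtype=torch.uint8)
q_bf16 = dequantize_per_tensor(q_u8, scale=0.02, zero_point=127, out_dtype=torch.bfloat16)
q_back = quantize_per_tensor(q_bf16, scale=0.02, zero_point=127)
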
g_u8_bs1 = functools.partial(g_bs1_inp, dtype=torch.uint8, requires_grad=False) - m = functools.partial(m_inp, dtype=dtype) - m_bs1 = functools.partial(m_bs1_inp, dtype=dtype) + m = functools.partial(m_inp, dtype=torch.float) + m_bs1 = functools.partial(m_bs1_inp, dtype=torch.float) c = functools.partial(c_inp, dtype=dtype) zp = functools.partial(zp_inp, dtype=torch.int) scale = functools.partial(scale_inp, dtype=torch.float) From 65ae166cc9393ad40306f6310480d649efa79243 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Wed, 8 Jan 2025 01:16:38 -0500 Subject: [PATCH 07/36] update int8 sdpa cpu --- setup.py | 4 +- torchao/csrc/cpu/sdpa.cpp | 84 ++++++++++++++++++++++++++++++++++----- 2 files changed, 76 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index 396f869e93..0b758b5e81 100644 --- a/setup.py +++ b/setup.py @@ -45,8 +45,8 @@ def read_version(file_path="version.txt"): if version_suffix is None: version_suffix = f"+git{get_git_commit_id()}" -use_cpp = os.getenv("USE_CPP") -use_cpp_avx512 = os.getenv('USE_AVX512', 1) +use_cpp = os.getenv('USE_CPP') +use_cpp_avx512 = os.getenv('USE_AVX512', '1') == '1' import platform diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp index 7ef2e3e471..c2db4fa389 100644 --- a/torchao/csrc/cpu/sdpa.cpp +++ b/torchao/csrc/cpu/sdpa.cpp @@ -43,6 +43,7 @@ inline double calculate_scale( return scale == 0.0 ? 1.0 / std::sqrt(query.size(-1)) : scale; } +#ifdef CPU_CAPABILITY_AVX512 // out = val * a + b // is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), // take b as a scalar pointer. @@ -2184,7 +2185,7 @@ sdpa_int8_kernel_small_impl( AT_PRIVATE_CASE_TYPE_USING_HINT( \ at::ScalarType::Half, mask_t, __VA_ARGS__)) -void sdpa_int8_kernel( +void sdpa_int8_fused_kernel( const at::Tensor& output, const at::Tensor& query, const at::Tensor& key, @@ -2335,6 +2336,50 @@ void sdpa_int8_kernel( } } } +#endif // CPU_CAPABILITY_AVX512 + +at::Tensor sdpa_int8_math_kernel( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + at::Tensor& attn_mask, + double scale, + int32_t q_zp, + float q_scale, + int32_t k_zp, + float k_scale, + int32_t v_zp, + float v_scale, + int32_t a_zp, + float a_scale, + int32_t o_zp, + float o_scale) { + // dequant q/k/v + auto q = (query.to(at::kFloat) - q_zp) * q_scale; + auto k = (key.to(at::kFloat) - k_zp) * k_scale; + auto v = (value.to(at::kFloat) - v_zp) * v_scale; + const auto scaling_factor = calculate_scale(q, scale); + auto attn = at::matmul(q, k.transpose(-2, -1)) * scaling_factor; + if (attn_mask.defined() && attn_mask.numel()) { + attn = attn.add(attn_mask.to(at::kFloat)); + } + attn = at::softmax(attn, -1); + // quant attn + attn = at::clamp_max( + at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255 + ); + // dequant attn + attn = (attn - a_zp) * a_scale; + auto output = at::matmul(attn, v); + // quant output + output = at::clamp_max( + at::clamp_min(at::round(output / o_scale) + o_zp, 0), 255 + ).to(at::kByte); + return output; +} + at::Tensor _scaled_dot_product_int8_cpu( const at::Tensor& query, @@ -2377,16 +2422,35 @@ at::Tensor _scaled_dot_product_int8_cpu( (attn_mask.dim() == 2 || attn_mask.dim() == 4), "_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); - at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); - sdpa_int8_kernel(output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + 
if (!at::native::cpublas::could_pack(dtype)) { + return sdpa_int8_math_kernel(query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } - return output.transpose(1, 2); + #ifdef CPU_CAPABILITY_AVX512 + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + sdpa_int8_fused_kernel(output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + return output.transpose(1, 2); + #else + return sdpa_int8_math_kernel(query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + #endif // CPU_CAPABILITY_AVX512 } From e7c5a2247f51418041daf240e56c1e5978e789c1 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Thu, 13 Feb 2025 02:54:05 -0500 Subject: [PATCH 08/36] update int8 sdpa cpu --- test/test_ops.py | 44 +++++++++++++++++++++++++++++++++++++++ torchao/csrc/cpu/sdpa.cpp | 34 ++++++++++++++---------------- torchao/ops.py | 27 +++++++++++++++++++++++- 3 files changed, 86 insertions(+), 19 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index d7094f29f4..c7222c662c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -154,6 +154,21 @@ def _scaled_dot_product_int8_op_ref( SDPA_INT8_HEAD_DIM = [32, 64] SDPA_INT8_MASK_DTYPE = [None, torch.float32, torch.bfloat16] + # VIT + # SDPA_INT8_BATCH_SIZE = [224] + # SDPA_INT8_NUM_HEADS = [12] + # SDPA_INT8_Q_SEQ_LEN = [197] + # SDPA_INT8_KV_SEQ_LEN = [197] + # SDPA_INT8_HEAD_DIM = [64] + # SDPA_INT8_MASK_DTYPE = [torch.bfloat16] + # BERTLARGE + # SDPA_INT8_BATCH_SIZE = [120] + # SDPA_INT8_NUM_HEADS = [16] + # SDPA_INT8_Q_SEQ_LEN = [384] + # SDPA_INT8_KV_SEQ_LEN = [384] + # SDPA_INT8_HEAD_DIM = [64] + # SDPA_INT8_MASK_DTYPE = [torch.bfloat16] + @parametrize("batch_size", SDPA_INT8_BATCH_SIZE) @parametrize("n_head", SDPA_INT8_NUM_HEADS) @parametrize("q_seq_len", SDPA_INT8_Q_SEQ_LEN) @@ -176,6 +191,7 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ q_shape = [batch_size, q_seq_len, n_head, head_dim] kv_shape = [batch_size, kv_seq_len, n_head, head_dim] mask_shape = [batch_size, 1, 1, kv_seq_len] + print(f"q_shape: {q_shape}") q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 @@ -223,6 +239,34 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ ) self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) + + iter_n = 20 + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule(wait=2, warmup=iter_n, active=20), + ) as prof: + for _ in range(iter_n + 22): + r = torch.ops.torchao.scaled_dot_product_int8( + q, + k, + v, + attn_mask=attn_mask_2, + dropout_p=0.0, + is_causal=False, + q_zp=q_zp, + q_scale=q_scale, + k_zp=k_zp, + k_scale=k_scale, + v_zp=v_zp, + v_scale=v_scale, + a_zp=a_zp, + a_scale=a_scale, + o_zp=o_zp, + o_scale=o_scale + ) + prof.step() + print(prof.key_averages().table(sort_by="self_cpu_time_total")) + instantiate_parametrized_tests(TestOps) diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp index c2db4fa389..20938b64ec 100644 --- a/torchao/csrc/cpu/sdpa.cpp +++ b/torchao/csrc/cpu/sdpa.cpp @@ -961,10 +961,10 @@ 
inline void copy_value_with_pad( } -// UINT8 - u8u8s32 +// UINT8 - one parallel loop with u8u8s32 GEMM template inline typename std::enable_if_t, void> -sdpa_int8_kernel_large_impl( +sdpa_int8_kernel_one_loop_impl( const at::Tensor& output, const at::Tensor& q, const at::Tensor& k, @@ -1549,10 +1549,10 @@ sdpa_int8_kernel_large_impl( at::native::cpublas::brgemm_release(); } -// UINT8 - u8u8s32 +// UINT8 - several parallel loops with u8u8s32 GEMM template inline typename std::enable_if_t, void> -sdpa_int8_kernel_small_impl( +sdpa_int8_kernel_several_loops_impl( const at::Tensor& output, const at::Tensor& q, const at::Tensor& k, @@ -2213,7 +2213,7 @@ void sdpa_int8_fused_kernel( if (use_one_parallel_loop) { if (!attn_mask.defined()) { if (q_seq_len >= 768) { - sdpa_int8_kernel_large_impl( + sdpa_int8_kernel_one_loop_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2222,7 +2222,7 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else if (q_seq_len >= 192) { - sdpa_int8_kernel_large_impl( + sdpa_int8_kernel_one_loop_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2231,7 +2231,7 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else { - sdpa_int8_kernel_large_impl( + sdpa_int8_kernel_one_loop_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2243,7 +2243,7 @@ void sdpa_int8_fused_kernel( } else { AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { if (q_seq_len >= 768) { - sdpa_int8_kernel_large_impl( + sdpa_int8_kernel_one_loop_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2252,7 +2252,7 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else if (q_seq_len >= 192) { - sdpa_int8_kernel_large_impl( + sdpa_int8_kernel_one_loop_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2261,7 +2261,7 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else { - sdpa_int8_kernel_large_impl( + sdpa_int8_kernel_one_loop_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2275,7 +2275,7 @@ void sdpa_int8_fused_kernel( } else { if (!attn_mask.defined()) { if (q_seq_len >= 768) { - sdpa_int8_kernel_small_impl( + sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2284,7 +2284,7 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else if (q_seq_len >= 192) { - sdpa_int8_kernel_small_impl( + sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2293,7 +2293,7 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else { - sdpa_int8_kernel_small_impl( + sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2305,7 +2305,7 @@ void sdpa_int8_fused_kernel( } else { AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { if (q_seq_len >= 768) { - sdpa_int8_kernel_small_impl( + sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2314,7 +2314,7 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else if (q_seq_len >= 192) { - sdpa_int8_kernel_small_impl( + sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ 
-2323,7 +2323,7 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else { - sdpa_int8_kernel_small_impl( + sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2386,11 +2386,9 @@ at::Tensor _scaled_dot_product_int8_cpu( const at::Tensor& key, const at::Tensor& value, at::Tensor& attn_mask, - // const std::optional& attn_mask, double dropout_p, bool is_causal, double scale, - // std::optional scale, int64_t q_zp, double q_scale, int64_t k_zp, diff --git a/torchao/ops.py b/torchao/ops.py index 732e0083a5..d1849b3034 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -56,7 +56,7 @@ tags=[torch._C.Tag.needs_fixed_stride_order], ) lib.define( - "scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=1.0, int q_zp=0, float q_scale=1.0, int k_zp=0, float k_scale=1.0, int v_zp=0, float v_scale=1.0, int a_zp=0, float a_scale=1.0, int o_zp=0, float o_scale=1.0) -> Tensor" + "scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, int q_zp=0, float q_scale=1.0, int k_zp=0, float k_scale=1.0, int v_zp=0, float v_scale=1.0, int a_zp=0, float a_scale=1.0, int o_zp=0, float o_scale=1.0) -> Tensor" ) def register_custom_op(name): @@ -180,6 +180,31 @@ def scaled_dot_product_int8( o_zp: int = 0, o_scale: float = 1.0, ) -> Tensor: + """ + Quantized SDPA with uint8 inputs and outputs. + + Arguments + query: input query tensor, + key: input key tensor, + value: input value tensor, + attn_mask: attention mask tensor, + dropout_p: dropout probability, + is_causal: causal flag, + scale: scaling factor applied prior to softmax, + q_zp: zero point for query from linear quantization, + q_scale: scale for query of linear quantization, + k_zp: zero point of key of linear quantization, + k_scale: scale for key of linear quantization, + v_zp: zero point of value from linear quantization, + v_scale: zero point for value from linear quantization, + a_zp: zero point for attention from softmax quantization, + a_scale: scale for attention from softmax quantization, + o_zp: zero point for output from linear quantization, + o_scale: scale for output from linear quantization, + + Returns + output of quantized SDPA + """ return torch.ops.torchao.scaled_dot_product_int8.default(query, key, value, attn_mask, dropout_p, is_causal, scale, q_zp, q_scale, From 18b3ae9360d1628718c797f6c555cd4e560292b0 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Fri, 14 Feb 2025 03:13:04 -0500 Subject: [PATCH 09/36] add heuristic strategy selection --- test/quantization/test_sfdp_int8_fx_pass.py | 43 ++++++++++++++++----- torchao/csrc/cpu/sdpa.cpp | 34 ++++++++++------ torchao/ops.py | 6 +-- 3 files changed, 60 insertions(+), 23 deletions(-) diff --git a/test/quantization/test_sfdp_int8_fx_pass.py b/test/quantization/test_sfdp_int8_fx_pass.py index 9c5f362ff8..fec93f01fd 100644 --- a/test/quantization/test_sfdp_int8_fx_pass.py +++ b/test/quantization/test_sfdp_int8_fx_pass.py @@ -126,8 +126,9 @@ def _check_common( counters.clear() torch.manual_seed(1234) + compiled_model = torch.compile(dot_prod_attention, fullgraph=True) result2, source_code = run_and_get_code( - torch.compile(dot_prod_attention, fullgraph=True), + compiled_model, *(args2 + dropout_arg), ) source_code = "\n".join(source_code) @@ -155,27 +156,51 @@ def _check_common( ): self.assertEqual(arg1.grad, arg2.grad, atol=atol, 
rtol=rtol) + iter_n = 20 + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule(wait=2, warmup=iter_n, active=20), + ) as prof: + for _ in range(iter_n + 22): + r = compiled_model(*(args2 + dropout_arg)) + prof.step() + print(prof.key_averages().table(sort_by="self_cpu_time_total")) + @skipIfRocm @config.patch({"freezing": True}) def _test_sdpa_rewriter_int8_1_to_4(self): # pattern is different for bs=1 - for dtype, has_mask, bs in itertools.product( - [torch.float32, torch.bfloat16], [True, False], [56, 1] - ): + # for dtype, has_mask, bs in itertools.product( + # [torch.float32, torch.bfloat16], [True, False], [56, 1] + # ): + dtype = torch.bfloat16 + has_mask = True + is_bs_1 = 1 + if is_bs_1: + candidates = [[1, 384, 16, 64], [1, 197, 12, 64]] + else: + candidates = [[120, 384, 16, 64], [224, 197, 12, 64]] + for bs, seqlen, numhead, headsize in candidates: mod = SelfAttnLikeModule( - input_dim=64 * 16, + input_dim=headsize * numhead, has_mask=has_mask, - num_attention_heads=16, - attention_head_size=64, + num_attention_heads=numhead, + attention_head_size=headsize, ).eval() maybe_autocast = ( torch.cpu.amp.autocast() if dtype == torch.bfloat16 else contextlib.nullcontext() ) + print("\nTEST shape", bs, numhead, seqlen, headsize) inputs = ( - torch.randn((bs, 384, 64 * 16), device=self.device, dtype=dtype), - torch.randn((bs, 1, 1, 384), device=self.device) if has_mask else None, + torch.randn( + (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype + ) + * 10, + torch.randn((bs, 1, 1, seqlen), device=self.device) * 10 + if has_mask + else None, ) with torch.no_grad(), maybe_autocast: _sfdp_init_int8() diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp index 20938b64ec..45b646a94b 100644 --- a/torchao/csrc/cpu/sdpa.cpp +++ b/torchao/csrc/cpu/sdpa.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -2204,15 +2205,26 @@ void sdpa_int8_fused_kernel( double a_scale, long o_zp, double o_scale) { + TORCH_CHECK(query.scalar_type() == c10::kByte); int64_t batchSize = query.size(0); int64_t num_head = query.size(1); int64_t q_seq_len = query.size(2); - - TORCH_CHECK(query.scalar_type() == c10::kByte); - bool use_one_parallel_loop = batchSize * num_head % 40 == 0; + int64_t kv_seq_len = key.size(2); + int64_t q_split_size = 32; + if (q_seq_len >= 768) { + q_split_size = 256; + } else if (q_seq_len >= 192) { + q_split_size = 64; + } + // Heuristic to decide whether to use one parallel loop or not + uint32_t l2_cache_size = at::cpu::L2_cache_size(); + int64_t num_thread = at::get_num_threads(); + int64_t attn_size = q_split_size * kv_seq_len * sizeof(int32_t) * num_thread; + bool use_one_parallel_loop = (batchSize * num_head > num_thread) && + (attn_size > l2_cache_size); if (use_one_parallel_loop) { if (!attn_mask.defined()) { - if (q_seq_len >= 768) { + if (q_split_size == 256) { sdpa_int8_kernel_one_loop_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -2221,7 +2233,7 @@ void sdpa_int8_fused_kernel( v_zp, v_scale, a_zp, a_scale, o_zp, o_scale); - } else if (q_seq_len >= 192) { + } else if (q_split_size == 64) { sdpa_int8_kernel_one_loop_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -2242,7 +2254,7 @@ void sdpa_int8_fused_kernel( } } else { AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { - if (q_seq_len >= 768) { + if (q_split_size == 256) { sdpa_int8_kernel_one_loop_impl( output, query, key, value, 
dropout_p, is_causal, attn_mask, scale, @@ -2251,7 +2263,7 @@ void sdpa_int8_fused_kernel( v_zp, v_scale, a_zp, a_scale, o_zp, o_scale); - } else if (q_seq_len >= 192) { + } else if (q_split_size == 64) { sdpa_int8_kernel_one_loop_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -2274,7 +2286,7 @@ void sdpa_int8_fused_kernel( } } else { if (!attn_mask.defined()) { - if (q_seq_len >= 768) { + if (q_split_size == 256) { sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -2283,7 +2295,7 @@ void sdpa_int8_fused_kernel( v_zp, v_scale, a_zp, a_scale, o_zp, o_scale); - } else if (q_seq_len >= 192) { + } else if (q_split_size == 64) { sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -2304,7 +2316,7 @@ void sdpa_int8_fused_kernel( } } else { AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { - if (q_seq_len >= 768) { + if (q_split_size == 256) { sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -2313,7 +2325,7 @@ void sdpa_int8_fused_kernel( v_zp, v_scale, a_zp, a_scale, o_zp, o_scale); - } else if (q_seq_len >= 192) { + } else if (q_split_size == 64) { sdpa_int8_kernel_several_loops_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, diff --git a/torchao/ops.py b/torchao/ops.py index d1849b3034..40bbc39ba3 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -192,9 +192,9 @@ def scaled_dot_product_int8( is_causal: causal flag, scale: scaling factor applied prior to softmax, q_zp: zero point for query from linear quantization, - q_scale: scale for query of linear quantization, - k_zp: zero point of key of linear quantization, - k_scale: scale for key of linear quantization, + q_scale: scale for query from linear quantization, + k_zp: zero point of key from linear quantization, + k_scale: scale for key from linear quantization, v_zp: zero point of value from linear quantization, v_scale: zero point for value from linear quantization, a_zp: zero point for attention from softmax quantization, From dd7179860d87673ad2acf21c79636e9d3b270b00 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 25 Feb 2025 20:21:24 -0500 Subject: [PATCH 10/36] update pattern match --- test/quantization/test_sfdp_int8_fx_pass.py | 50 +- torchao/csrc/cpu/sdpa.cpp | 242 ++-- torchao/quantization/__init__.py | 2 +- torchao/quantization/sfdp_int8_fx_pass.py | 1163 ++++++++----------- 4 files changed, 608 insertions(+), 849 deletions(-) diff --git a/test/quantization/test_sfdp_int8_fx_pass.py b/test/quantization/test_sfdp_int8_fx_pass.py index fec93f01fd..45556bc9a1 100644 --- a/test/quantization/test_sfdp_int8_fx_pass.py +++ b/test/quantization/test_sfdp_int8_fx_pass.py @@ -21,7 +21,7 @@ from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( X86InductorQuantizer, ) -from torchao.quantization.sfdp_int8_fx_pass import _sfdp_init_int8 +from torchao.quantization.sfdp_int8_fx_pass import _sfdp_int8_init class SelfAttnLikeModule(torch.nn.Module): def __init__( @@ -62,7 +62,7 @@ def forward(self, x, mask): k = self.transpose_for_scores(k) v = self.transpose_for_scores(v) scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) - if self.has_mask: + if self.has_mask and mask.dtype != scores.dtype: scores = scores + mask attention = self.softmax(scores) attention = self.dropout(attention) @@ -73,15 +73,6 @@ def forward(self, x, mask): ) return self.dense(context_layer) -def 
_generate_qdq_quantized_model(mod, inputs, quantizer): - with torch.no_grad(): - export_model = export_for_training(mod, inputs).module() - prepare_model = prepare_pt2e(export_model, quantizer) - prepare_model(*inputs) - convert_model = convert_pt2e(prepare_model) - torch.ao.quantization.move_exported_model_to_eval(convert_model) - return convert_model - class TestSDPAPatternRewriterTemplate(TestCase): def _clone_inputs(self, inputs): def clone(x): @@ -133,7 +124,7 @@ def _check_common( ) source_code = "\n".join(source_code) if has_fuse_pattern: - self.assertGreaterEqual(counters["inductor"]["fuse_attention_int8"], 1) + self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1) if contains: # many of the patterns get re-expanded in dispatcher self.assertIn( @@ -170,17 +161,19 @@ def _check_common( @config.patch({"freezing": True}) def _test_sdpa_rewriter_int8_1_to_4(self): # pattern is different for bs=1 - # for dtype, has_mask, bs in itertools.product( - # [torch.float32, torch.bfloat16], [True, False], [56, 1] - # ): - dtype = torch.bfloat16 - has_mask = True - is_bs_1 = 1 - if is_bs_1: - candidates = [[1, 384, 16, 64], [1, 197, 12, 64]] - else: - candidates = [[120, 384, 16, 64], [224, 197, 12, 64]] - for bs, seqlen, numhead, headsize in candidates: + for dtype, has_mask, bs in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [56, 1] + ): + seqlen, numhead, headsize = 197, 16, 64 + # dtype = torch.bfloat16 + # has_mask = True + # is_bs_1 = 0 + # if is_bs_1: + # candidates = [[1, 384, 16, 64], [1, 197, 12, 64]] + # else: + # candidates = [[120, 384, 16, 64], [224, 197, 12, 64]] + # candidates = [[120, 384, 16, 64]] + # for bs, seqlen, numhead, headsize in candidates: mod = SelfAttnLikeModule( input_dim=headsize * numhead, has_mask=has_mask, @@ -203,13 +196,20 @@ def _test_sdpa_rewriter_int8_1_to_4(self): else None, ) with torch.no_grad(), maybe_autocast: - _sfdp_init_int8() + _sfdp_int8_init() quantizer = X86InductorQuantizer() quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) quantizer.set_function_type_qconfig( torch.matmul, quantizer.get_global_quantization_config() ) - convert_model = _generate_qdq_quantized_model(mod, inputs, quantizer) + export_model = export_for_training( + mod, + inputs, + ).module() + prepare_model = prepare_pt2e(export_model, quantizer) + prepare_model(*inputs) + convert_model = convert_pt2e(prepare_model) + torch.ao.quantization.move_exported_model_to_eval(convert_model) self._check_common( convert_model, args1=inputs, check_train=False, atol=1.0 ) diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp index 45b646a94b..4f5ca3fcaf 100644 --- a/torchao/csrc/cpu/sdpa.cpp +++ b/torchao/csrc/cpu/sdpa.cpp @@ -967,9 +967,9 @@ template , void> sdpa_int8_kernel_one_loop_impl( const at::Tensor& output, - const at::Tensor& q, - const at::Tensor& k, - const at::Tensor& v, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, double dropout_p, bool is_causal, at::Tensor& attention_mask, @@ -984,15 +984,9 @@ sdpa_int8_kernel_one_loop_impl( float a_scale, int32_t o_zp, float o_scale) { - // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) - // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) - // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - at::Tensor query = q.transpose(1, 2); - at::Tensor key = 
k.transpose(1, 2); - at::Tensor value = v.transpose(1, 2); + // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x KV_seq_len x Num_heads x Dim_per_head) const auto accumulate_dtype = at::kFloat; @@ -1190,12 +1184,10 @@ sdpa_int8_kernel_one_loop_impl( for (int64_t n = 0; n < kvSize; n += kvSplitSize) { if (n + kvSplitSize < kvSize) { for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { - bool tail = kvSplitSize - b < block_64; do_transpose( - // do_copy( k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, B_blocked_xform_u8, - tail ? kvSplitSize - b : block_64, + std::min(int(kvSplitSize - b), block_64), headSize, kStrideN, block_64); @@ -1240,11 +1232,10 @@ sdpa_int8_kernel_one_loop_impl( auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; int64_t b = 0; while (b < rndkvTail) { - bool tail = kvTail - b < block_size; do_transpose( k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, B_blocked_xform_u8, - tail ? kvTail - b : block_size, + std::min(kvTail - b, block_size), headSize, kStrideN, block_size); @@ -1259,29 +1250,16 @@ sdpa_int8_kernel_one_loop_impl( ); } // Pack - if (block_size == block_64) { - at::native::cpublas::pack( - qk_gemm_K, - block_64, - block_64, - block_64, - u8_dt, - u8_dt, - B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } else { - at::native::cpublas::pack( - qk_gemm_K, - kv_tail_tail_block_size, - kv_tail_tail_block_size, - kv_tail_tail_block_size, - u8_dt, - u8_dt, - B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } + at::native::cpublas::pack( + qk_gemm_K, + block_size, + block_size, + block_size, + u8_dt, + u8_dt, + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); b += block_size; block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } @@ -1356,29 +1334,16 @@ sdpa_int8_kernel_one_loop_impl( auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; int64_t b = 0; while (b < kvTail) { - if (block_size == block_64) { - at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - } else { - at::native::cpublas::brgemm( - qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - qk_gemm_K, // lda - kv_tail_tail_block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - } + at::native::cpublas::brgemm( + qSplitSize, block_size, qk_gemm_K, + qk_gemm_K, // lda + block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); b += block_size; block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } @@ -1402,29 +1367,16 @@ sdpa_int8_kernel_one_loop_impl( auto block_size = kvTail >= block_64 ? 
block_64 : kv_tail_tail_block_size; int64_t b = 0; while (b < kvTail) { - if (block_size == block_64) { - at::native::cpublas::brgemm( - qTail, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - } else { - at::native::cpublas::brgemm( - qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - qk_gemm_K, // lda - kv_tail_tail_block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - } + at::native::cpublas::brgemm( + qTail, block_size, qk_gemm_K, + qk_gemm_K, // lda + block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); b += block_size; block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } @@ -1555,9 +1507,9 @@ template , void> sdpa_int8_kernel_several_loops_impl( const at::Tensor& output, - const at::Tensor& q, - const at::Tensor& k, - const at::Tensor& v, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, double dropout_p, bool is_causal, at::Tensor& attention_mask, @@ -1572,15 +1524,9 @@ sdpa_int8_kernel_several_loops_impl( float a_scale, int32_t o_zp, float o_scale) { - // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) - // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) - // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - at::Tensor query = q.transpose(1, 2); - at::Tensor key = k.transpose(1, 2); - at::Tensor value = v.transpose(1, 2); + // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x KV_seq_len x Num_heads x Dim_per_head) const auto accumulate_dtype = at::kFloat; @@ -1781,11 +1727,10 @@ sdpa_int8_kernel_several_loops_impl( j * kvSlice * v_reorder_strideL + n * rndHeadSize; if (n + kvSplitSize < kvSize) { for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { - bool tail = kvSplitSize - b < block_64; do_transpose( k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, B_blocked_xform_u8, - tail ? kvSplitSize - b : block_64, + std::min(int(kvSplitSize - b), block_64), headSize, kStrideN, block_64); @@ -1827,11 +1772,10 @@ sdpa_int8_kernel_several_loops_impl( auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; int64_t b = 0; while (b < rndkvTail) { - bool tail = kvTail - b < block_size; do_transpose( k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, B_blocked_xform_u8, - tail ? 
kvTail - b : block_size, + std::min(kvTail - b, block_size), headSize, kStrideN, block_size); @@ -1846,27 +1790,15 @@ sdpa_int8_kernel_several_loops_impl( ); } // Pack - if (block_size == block_64) { - at::native::cpublas::pack( - qk_gemm_K, - block_64, - block_64, - block_64, - u8_dt, - u8_dt, - B_blocked_xform_u8, - k_reorder + b * qk_gemm_K); - } else { - at::native::cpublas::pack( - qk_gemm_K, - kv_tail_tail_block_size, - kv_tail_tail_block_size, - kv_tail_tail_block_size, - u8_dt, - u8_dt, - B_blocked_xform_u8, - k_reorder + b * qk_gemm_K); - } + at::native::cpublas::pack( + qk_gemm_K, + block_size, + block_size, + block_size, + u8_dt, + u8_dt, + B_blocked_xform_u8, + k_reorder + b * qk_gemm_K); b += block_size; block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } @@ -1980,27 +1912,15 @@ sdpa_int8_kernel_several_loops_impl( auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; int64_t b = 0; while (b < kvTail) { - if (block_size == block_64) { - at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } else { - at::native::cpublas::brgemm( - qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - qk_gemm_K, // lda - kv_tail_tail_block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } + at::native::cpublas::brgemm( + qSplitSize, block_size, qk_gemm_K, + qk_gemm_K, // lda + block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + k_reorder + b * qk_gemm_K, + qk_s32_data + b); b += block_size; block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; } @@ -2023,27 +1943,15 @@ sdpa_int8_kernel_several_loops_impl( auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; int64_t b = 0; while (b < kvTail) { - if (block_size == block_64) { - at::native::cpublas::brgemm( - qTail, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } else { - at::native::cpublas::brgemm( - qSplitSize, kv_tail_tail_block_size, qk_gemm_K, - qk_gemm_K, // lda - kv_tail_tail_block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } + at::native::cpublas::brgemm( + qTail, block_size, qk_gemm_K, + qk_gemm_K, // lda + block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + k_reorder + b * qk_gemm_K, + qk_s32_data + b); b += block_size; block_size = (kvTail - b) >= block_64 ? 
block_64 : kv_tail_tail_block_size; } @@ -2443,7 +2351,7 @@ at::Tensor _scaled_dot_product_int8_cpu( } #ifdef CPU_CAPABILITY_AVX512 - at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + at::Tensor output = at::empty_like(query, query.options()); sdpa_int8_fused_kernel(output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2451,7 +2359,7 @@ at::Tensor _scaled_dot_product_int8_cpu( v_zp, v_scale, a_zp, a_scale, o_zp, o_scale); - return output.transpose(1, 2); + return output; #else return sdpa_int8_math_kernel(query, key, value, dropout_p, is_causal, attn_mask, scale, diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index a53a87a919..9ef71a44fc 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -95,7 +95,7 @@ swap_linear_with_smooth_fq_linear, ) from .sfdp_int8_fx_pass import ( - _sfdp_init_int8, + _sfdp_int8_init, ) from .subclass import * # noqa: F403 from .transform_module import register_quantize_module_handler diff --git a/torchao/quantization/sfdp_int8_fx_pass.py b/torchao/quantization/sfdp_int8_fx_pass.py index 9ff7dc7e8c..9359ae57b5 100644 --- a/torchao/quantization/sfdp_int8_fx_pass.py +++ b/torchao/quantization/sfdp_int8_fx_pass.py @@ -4,687 +4,538 @@ import torch from torch._inductor import config from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, filter_nodes, - fwd_only, - register_replacement, - gen_register_replacement, + KeywordArg, + ListOf, + Match, PatternMatcherPass, ) +from torch._inductor.fx_passes.post_grad import register_lowering_pattern +from torch._inductor.lowering import lowerings as L, make_fallback from torch._dynamo.utils import counters -from torch._inductor.fx_passes.fuse_attention import ( - partialize_and_update_signature -) -from torchao.ops import scaled_dot_product_int8 __all__ = [ - "_sfdp_init_int8", + "_sfdp_int8_post_grad_init", ] +make_fallback(torch.ops.torchao.scaled_dot_product_int8.default) + aten = torch.ops.aten patterns = PatternMatcherPass() -def _sfdp_pattern_int8_1( - query, - key, - value, - attn_mask, - inv_scale, - q_zp, - q_scale, - k_zp, - k_scale, - v_zp, - v_scale, - a_zp, - a_scale, - o_zp, - o_scale, - dropout, -): - # int8-mix-fp32 QUANTIZED SDPA with mask - q = query.permute([0, 2, 1, 3]) - q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - q, float(q_scale), int(q_zp), 0, 255, torch.uint8 - ) - k = key.permute([0, 2, 1, 3]).transpose(-2, -1) - k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - k, float(k_scale), int(k_zp), 0, 255, torch.uint8 - ) - v = value.permute([0, 2, 1, 3]) - v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - v, float(v_scale), int(v_zp), 0, 255, torch.uint8 - ) - a = torch.nn.functional.dropout( - (torch.matmul(q, k).div(inv_scale) + attn_mask).softmax(dim=-1), - dropout, - ) - qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( - a, float(a_scale), int(a_zp), 0, 255, torch.uint8 - ) - a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 - ) - o = a.matmul(v) - o = o.permute(0, 2, 1, 3).contiguous() - return torch.ops.quantized_decomposed.quantize_per_tensor.default( - o, float(o_scale), int(o_zp), 0, 255, torch.uint8 - ) - - -def _sfdp_replacement_int8_1( - query, - key, - value, - attn_mask, - inv_scale, - q_zp, - q_scale, - k_zp, - k_scale, - v_zp, - v_scale, - a_zp, - a_scale, - o_zp, - o_scale, - dropout, -): - print("hit 
_sfdp_replacement_int8_1") - counters["inductor"]["fuse_attention_int8"] += 1 - res = scaled_dot_product_int8( - query.transpose(1, 2), - key.transpose(1, 2), - value.transpose(1, 2), - attn_mask=attn_mask, - dropout_p=dropout, - is_causal=False, - scale=1.0 / inv_scale, - q_zp=q_zp, - q_scale=q_scale, - k_zp=k_zp, - k_scale=k_scale, - v_zp=v_zp, - v_scale=v_scale, - a_zp=a_zp, - a_scale=a_scale, - o_zp=o_zp, - o_scale=o_scale, - ) - return res.permute(0, 2, 1, 3).contiguous() - - -def _sfdp_pattern_int8_2( - query, - key, - value, - attn_mask, - inv_scale, - q_zp, - q_scale, - k_zp, - k_scale, - v_zp, - v_scale, - a_zp, - a_scale, - o_zp, - o_scale, - dropout, -): - # int8-mix-reduce QUANTIZED SDPA with mask - q = query.permute([0, 2, 1, 3]) - q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - q, float(q_scale), int(q_zp), 0, 255, - # torch.uint8).to(torch.bfloat16) - torch.uint8, out_dtype=torch.bfloat16).to(torch.bfloat16) - k = key.permute([0, 2, 1, 3]).transpose(-2, -1) - k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - k, float(k_scale), int(k_zp), 0, 255, - # torch.uint8).to(torch.bfloat16) - torch.uint8, out_dtype=torch.bfloat16).to(torch.bfloat16) - v = value.permute([0, 2, 1, 3]) - v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - v, float(v_scale), int(v_zp), 0, 255, - # torch.uint8).to(torch.bfloat16) - torch.uint8, out_dtype=torch.bfloat16).to(torch.bfloat16) - a = torch.nn.functional.dropout( - (torch.matmul(q, k).div(inv_scale) + attn_mask).softmax(dim=-1), - dropout, - ) - qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( - a, float(a_scale), int(a_zp), 0, 255, torch.uint8 - ) - a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - qa, float(a_scale), int(a_zp), 0, 255, - torch.uint8).to(torch.bfloat16) - # torch.uint8, out_dtype=torch.bfloat16) - o = a.matmul(v) - o = o.permute(0, 2, 1, 3).contiguous() - return torch.ops.quantized_decomposed.quantize_per_tensor.default( - o, float(o_scale), int(o_zp), 0, 255, torch.uint8 - ) - - -def _sfdp_replacement_int8_2( - query, - key, - value, - attn_mask, - inv_scale, - q_zp, - q_scale, - k_zp, - k_scale, - v_zp, - v_scale, - a_zp, - a_scale, - o_zp, - o_scale, - dropout, -): - print("hit _sfdp_replacement_int8_2") - counters["inductor"]["fuse_attention_int8"] += 1 - res = scaled_dot_product_int8( - query.transpose(1, 2), - key.transpose(1, 2), - value.transpose(1, 2), - attn_mask=attn_mask, - dropout_p=dropout, - is_causal=False, - scale=1.0 / inv_scale, - q_zp=q_zp, - q_scale=q_scale, - k_zp=k_zp, - k_scale=k_scale, - v_zp=v_zp, - v_scale=v_scale, - a_zp=a_zp, - a_scale=a_scale, - o_zp=o_zp, - o_scale=o_scale, - ) - return res.permute(0, 2, 1, 3).contiguous() - - -def _sfdp_pattern_int8_3( - query, - key, - value, - inv_scale, - q_zp, - q_scale, - k_zp, - k_scale, - v_zp, - v_scale, - a_zp, - a_scale, - o_zp, - o_scale, - dropout, -): - # int8-mix-fp32 QUANTIZED SDPA without mask - q = query.permute([0, 2, 1, 3]) - q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - q, float(q_scale), int(q_zp), 0, 255, torch.uint8 - ) - k = key.permute([0, 2, 1, 3]).transpose(-2, -1) - k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - k, float(k_scale), int(k_zp), 0, 255, torch.uint8 - ) - v = value.permute([0, 2, 1, 3]) - v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - v, float(v_scale), int(v_zp), 0, 255, torch.uint8 - ) - a = torch.nn.functional.dropout( - torch.matmul(q, k).div(inv_scale).softmax(dim=-1), - 
dropout, - ) - qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( - a, float(a_scale), int(a_zp), 0, 255, torch.uint8 - ) - a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - qa, float(a_scale), int(a_zp), 0, 255, torch.uint8 - ) - o = a.matmul(v) - o = o.permute(0, 2, 1, 3).contiguous() - return torch.ops.quantized_decomposed.quantize_per_tensor.default( - o, float(o_scale), int(o_zp), 0, 255, torch.uint8 - ) - - -def _sfdp_replacement_int8_3( - query, - key, - value, - inv_scale, - q_zp, - q_scale, - k_zp, - k_scale, - v_zp, - v_scale, - a_zp, - a_scale, - o_zp, - o_scale, - dropout, -): - print("hit _sfdp_replacement_int8_3") - counters["inductor"]["fuse_attention_int8"] += 1 - res = scaled_dot_product_int8( - query.transpose(1, 2), - key.transpose(1, 2), - value.transpose(1, 2), - dropout_p=dropout, - is_causal=False, - scale=1.0 / inv_scale, - q_zp=q_zp, - q_scale=q_scale, - k_zp=k_zp, - k_scale=k_scale, - v_zp=v_zp, - v_scale=v_scale, - a_zp=a_zp, - a_scale=a_scale, - o_zp=o_zp, - o_scale=o_scale, - ) - return res.permute(0, 2, 1, 3).contiguous() - - -def _sfdp_pattern_int8_4( - query, - key, - value, - inv_scale, - q_zp, - q_scale, - k_zp, - k_scale, - v_zp, - v_scale, - a_zp, - a_scale, - o_zp, - o_scale, - dropout, -): - # int8-mix-reduce QUANTIZED SDPA without mask - q = query.permute([0, 2, 1, 3]) - q = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - q, float(q_scale), int(q_zp), 0, 255, - torch.uint8, out_dtype=torch.bfloat16) - k = key.permute([0, 2, 1, 3]).transpose(-2, -1) - k = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - k, float(k_scale), int(k_zp), 0, 255, - torch.uint8, out_dtype=torch.bfloat16) - v = value.permute([0, 2, 1, 3]) - v = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - v, float(v_scale), int(v_zp), 0, 255, - torch.uint8, out_dtype=torch.bfloat16) - a = torch.nn.functional.dropout( - torch.matmul(q, k).div(inv_scale).softmax(dim=-1), - dropout, - ) - qa = torch.ops.quantized_decomposed.quantize_per_tensor.default( - a, float(a_scale), int(a_zp), 0, 255, torch.uint8 - ) - a = torch.ops.quantized_decomposed.dequantize_per_tensor.default( - qa, float(a_scale), int(a_zp), 0, 255, - torch.uint8, out_dtype=torch.bfloat16) - o = a.matmul(v) - o = o.permute(0, 2, 1, 3).contiguous() - return torch.ops.quantized_decomposed.quantize_per_tensor.default( - o, float(o_scale), int(o_zp), 0, 255, torch.uint8 - ) - - -def _sfdp_replacement_int8_4( - query, - key, - value, - inv_scale, - q_zp, - q_scale, - k_zp, - k_scale, - v_zp, - v_scale, - a_zp, - a_scale, - o_zp, - o_scale, - dropout, -): - print("hit _sfdp_replacement_int8_4") - counters["inductor"]["fuse_attention_int8"] += 1 - res = scaled_dot_product_int8( - query.transpose(1, 2), - key.transpose(1, 2), - value.transpose(1, 2), - dropout_p=dropout, - is_causal=False, - scale=1.0 / inv_scale, - q_zp=q_zp, - q_scale=q_scale, - k_zp=k_zp, - k_scale=k_scale, - v_zp=v_zp, - v_scale=v_scale, - a_zp=a_zp, - a_scale=a_scale, - o_zp=o_zp, - o_scale=o_scale, - ) - return res.permute(0, 2, 1, 3).contiguous() - - -def _sfdp_params_check_int8(match): - assert all(k in match.kwargs for k in ("query", "key", "value")) - query = match.kwargs["query"].meta["val"] - key = match.kwargs["key"].meta["val"] - value = match.kwargs["value"].meta["val"] - if not (query.dtype == key.dtype == value.dtype) or not ( - query.device == key.device == value.device - ): - return False - add_nodes = filter_nodes(match.nodes, aten.add.Tensor) - # Has attn_mask add. 
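For reference, the fused op that these replacement functions lower to can also be invoked directly, mirroring the call style used in test_ops.py earlier in this series. A minimal usage sketch; shapes and quantization parameters are illustrative placeholders, and it assumes torchao was built with USE_CPP=1 so the CPU extension is available:

import torch
import torchao  # registers torch.ops.torchao.*

B, H, S, D = 2, 16, 384, 64
# uint8 inputs created as (batch, seq, heads, head_dim) and viewed as (batch, heads, seq, head_dim).
q = torch.randint(0, 256, (B, S, H, D), dtype=torch.uint8).transpose(1, 2)
k = torch.randint(0, 256, (B, S, H, D), dtype=torch.uint8).transpose(1, 2)
v = torch.randint(0, 256, (B, S, H, D), dtype=torch.uint8).transpose(1, 2)
mask = torch.randn(B, 1, 1, S, dtype=torch.float32)

out_u8 = torch.ops.torchao.scaled_dot_product_int8(
    q, k, v,
    attn_mask=mask,
    dropout_p=0.0,
    is_causal=False,
    scale=1.0 / (D ** 0.5),
    q_zp=127, q_scale=0.02,
    k_zp=127, k_scale=0.02,
    v_zp=127, v_scale=0.02,
    a_zp=120, a_scale=0.004,
    o_zp=128, o_scale=0.02,
)
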
- add_mask_node = [n for n in add_nodes if n.prev.target == torch.ops.aten.div.Tensor] - if len(add_mask_node) > 0: - attn_mask_node = add_mask_node[0].args[1] - # attn_mask_node may be a float/int number. - if not hasattr(attn_mask_node, "meta"): - return False - attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr] - # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool - # attn_mask.dtype == torch.float for models like albert. - if ( - not isinstance(attn_mask, torch.Tensor) - or not ( - attn_mask.dtype == query.dtype - or attn_mask.dtype == torch.bool - or attn_mask.dtype == torch.float - ) - or query.device != attn_mask.device - ): - return False - return True - - -def _sfdp_extra_check_int8(scale_factor_op=None, disable_cuda=False): +def _is_valid_int8_sdpa_pattern(): def fn(match): - if ( - disable_cuda - and "query" in match.kwargs - and "cuda" in str(match.kwargs["query"].meta["val"].device) - ): - return False - if scale_factor_op is not None: - scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0] - # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns. - scale_factor = scale_factor_node.args[1] - # make sure the scale_factor a float/int. SymInt? - if not isinstance(scale_factor, (float, int)): - return False - return _sfdp_params_check_int8(match) + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + return ( + query.dtype == torch.uint8 + and key.dtype == torch.uint8 + and value.dtype == torch.uint8 + and query.device.type == "cpu" + and key.device == query.device + and value.device == query.device + ) return fn -def _gen_sfdp_patterns_int8(): - if torch.cuda.is_available(): - device = "cuda" +def _register_int8_sdpa_pattern(pattern): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_int8_sdpa_pattern(), + ) + def int8_sdpa(match: Match, *args, **kwargs): + print("\n***hit int8_sdpa_pattern***\n") + query = kwargs["query"] + key = kwargs["key"] + value = kwargs["value"] + inv_scale = kwargs["inv_scale"] + attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None + q_zp = kwargs["q_zp"] + q_scale = kwargs["q_scale"] + k_zp = kwargs["k_zp"] + k_scale = kwargs["k_scale"] + v_zp = kwargs["v_zp"] + v_scale = kwargs["v_scale"] + a_zp = kwargs["a_zp"] + a_scale = kwargs["a_scale"] + o_zp = kwargs["o_zp"] + o_scale = kwargs["o_scale"] + counters["inductor"]["int8_fuse_attention"] += 1 + counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes) + + return L[torch.ops.torchao.scaled_dot_product_int8.default]( + query, + key, + value, + attn_mask, + 0.0, #dropout + False, #is_causal + 1.0 / inv_scale, #scale + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + ) + + return int8_sdpa + + +def _get_int8_sdpa_q_pattern(is_batch_size_1: bool, has_convert: bool): + int8_sdpa_q_basic_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_3 + CallFunction( + aten.permute.default, # permute_3 + KeywordArg("query"), + Arg(), + ), + KeywordArg("q_scale"), + KeywordArg("q_zp"), + Arg(), + Arg(), + Arg(), + ) + if has_convert: + int8_sdpa_q_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_q_basic_pattern, + Arg(), + ) + int8_sdpa_q_basic_pattern = CallFunction( + aten.expand.default, + 
int8_sdpa_q_basic_pattern, + Arg(), + ) + if is_batch_size_1: + # pattern is different for bs=1 + return CallFunction( + aten.reshape.default, # view_9 + int8_sdpa_q_basic_pattern, + Arg(), + ) else: - device = "cpu" - g_inp = functools.partial( - torch.empty, (2, 4, 8, 16), device=device, requires_grad=True - ) - m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device) # attn_mask - c_inp = functools.partial(torch.tensor, 2.0, device=device) # inv_scale - zp_inp = functools.partial(torch.tensor, 127, device=device) # quant_zero_point - scale_inp = functools.partial(torch.tensor, 0.018, device=device) # quant_scale - - # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change. - # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated. - # here we need to trace with input of batch_size=1 to generate a pattern graph without clone. - g_bs1_inp = functools.partial( - torch.empty, (1, 4, 8, 16), device=device, requires_grad=True - ) - m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) - for dtype in [torch.float, torch.bfloat16]: - g_u8 = functools.partial(g_inp, dtype=torch.uint8, requires_grad=False) - g_u8_bs1 = functools.partial(g_bs1_inp, dtype=torch.uint8, requires_grad=False) - m = functools.partial(m_inp, dtype=torch.float) - m_bs1 = functools.partial(m_bs1_inp, dtype=torch.float) - c = functools.partial(c_inp, dtype=dtype) - zp = functools.partial(zp_inp, dtype=torch.int) - scale = functools.partial(scale_inp, dtype=torch.float) - d_u8 = { - "dropout": 0.113377, - "q_zp": 23, - "q_scale": 0.0111541, - "k_zp": 14, - "k_scale": 0.0256212, - "v_zp": 28, - "v_scale": 0.0164518, - "a_zp": 12, - "a_scale": 0.0572114, - "o_zp": 36, - "o_scale": 0.0235489, - } - int8_candidates = [ - ( - _sfdp_pattern_int8_1, - _sfdp_replacement_int8_1, - [ - g_u8(), - g_u8(), - g_u8(), - m(), - c(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - ], - d_u8, - _sfdp_extra_check_int8(aten.div.Tensor), - ), - ( - _sfdp_pattern_int8_1, - _sfdp_replacement_int8_1, - [ - g_u8_bs1(), - g_u8_bs1(), - g_u8_bs1(), - m_bs1(), - c(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - ], - d_u8, - _sfdp_extra_check_int8(aten.div.Tensor), - ), - ( - _sfdp_pattern_int8_2, - _sfdp_replacement_int8_2, - [ - g_u8(), - g_u8(), - g_u8(), - m(), - c(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - ], - d_u8, - _sfdp_extra_check_int8(aten.div.Tensor), + return CallFunction( + aten.reshape.default, # view_9 + CallFunction( + aten.clone.default, # clone + int8_sdpa_q_basic_pattern, + memory_format=Arg(), ), - ( - _sfdp_pattern_int8_2, - _sfdp_replacement_int8_2, - [ - g_u8_bs1(), - g_u8_bs1(), - g_u8_bs1(), - m_bs1(), - c(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - ], - d_u8, - _sfdp_extra_check_int8(aten.div.Tensor), + Arg(), + ) + + +def _get_int8_sdpa_k_pattern(is_batch_size_1: bool, has_convert: bool): + int8_sdpa_k_basic_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_5 + CallFunction( + aten.permute.default, # permute_6 + CallFunction( + aten.permute.default, # permute_4 + KeywordArg("key"), + Arg(), ), - ( - _sfdp_pattern_int8_3, - _sfdp_replacement_int8_3, - [ - g_u8(), - g_u8(), - g_u8(), - c(), - zp(), - scale(), - zp(), - scale(), - zp(), 
- scale(), - zp(), - scale(), - zp(), - scale(), - ], - d_u8, - _sfdp_extra_check_int8(aten.div.Tensor), + Arg(), + ), + KeywordArg("k_scale"), + KeywordArg("k_zp"), + Arg(), + Arg(), + Arg(), + ) + if has_convert: + int8_sdpa_k_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_k_basic_pattern, + Arg(), + ) + int8_sdpa_k_basic_pattern = CallFunction( + aten.expand.default, + int8_sdpa_k_basic_pattern, + Arg(), + ) + if is_batch_size_1: + # pattern is different for bs=1 + return CallFunction( + aten.reshape.default, # view_10 + int8_sdpa_k_basic_pattern, + Arg(), + ) + else: + return CallFunction( + aten.reshape.default, # view_10 + CallFunction( + aten.clone.default, # clone_1 + int8_sdpa_k_basic_pattern, + memory_format=Arg(), ), - ( - _sfdp_pattern_int8_3, - _sfdp_replacement_int8_3, - [ - g_u8_bs1(), - g_u8_bs1(), - g_u8_bs1(), - c(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - ], - d_u8, - _sfdp_extra_check_int8(aten.div.Tensor), + Arg(), + ) + + +def _get_int8_sdpa_v_pattern(is_batch_size_1: bool, has_convert: bool): + int8_sdpa_v_basic_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_4 + CallFunction( + aten.permute.default, # permute_5 + KeywordArg("value"), + Arg(), + ), + KeywordArg("v_scale"), + KeywordArg("v_zp"), + Arg(), + Arg(), + Arg(), + ) + if has_convert: + int8_sdpa_v_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_v_basic_pattern, + Arg(), + ) + int8_sdpa_v_basic_pattern = CallFunction( + aten.expand.default, # expand_3 + int8_sdpa_v_basic_pattern, + Arg(), + ) + if is_batch_size_1: + # pattern is different for bs=1 + return CallFunction( + aten.reshape.default, # view_13 + int8_sdpa_v_basic_pattern, + Arg(), + ) + else: + return CallFunction( + aten.reshape.default, # view_13 + CallFunction( + aten.clone.default, # clone_3 + int8_sdpa_v_basic_pattern, + memory_format=Arg(), ), - ( - _sfdp_pattern_int8_4, - _sfdp_replacement_int8_4, - [ - g_u8(), - g_u8(), - g_u8(), - c(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - ], - d_u8, - _sfdp_extra_check_int8(aten.div.Tensor), + Arg(), + ) + + +def _get_int8_sdpa_score_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_q_pattern = _get_int8_sdpa_q_pattern(is_batch_size_1, has_convert) + int8_sdpa_k_pattern = _get_int8_sdpa_k_pattern(is_batch_size_1, has_convert) + int8_sdpa_score_basic_pattern = CallFunction( + aten.reshape.default, # view_11 + CallFunction( + aten.bmm.default, # bmm + int8_sdpa_q_pattern, # view_9 + int8_sdpa_k_pattern, # view_10 + ), + Arg(), + ) + if is_reduced_type and not has_mask: + int8_sdpa_score_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_score_basic_pattern, + Arg(), + ) + if has_mask: + return CallFunction( + aten.add.Tensor, # add + CallFunction( + aten.div.Tensor, # div + int8_sdpa_score_basic_pattern, + KeywordArg("inv_scale"), ), - ( - _sfdp_pattern_int8_4, - _sfdp_replacement_int8_4, - [ - g_u8_bs1(), - g_u8_bs1(), - g_u8_bs1(), - c(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - zp(), - scale(), - ], - d_u8, - _sfdp_extra_check_int8(aten.div.Tensor), + KeywordArg("attn_mask"), + _users=2, + ) + else: + return CallFunction( + aten.mul.Tensor, # mul_tensor + int8_sdpa_score_basic_pattern, + Arg(), + _users=2, + ) + + 
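For reference, the score pattern assembled above is meant to match the graph that PT2E export produces for the dequantize -> matmul -> divide (-> optional mask add) portion of quantized attention; the bs=1 variants differ only in dropping the clone that Inductor inserts after a layout-changing reshape. A rough eager-mode sketch of that computation follows; the tensor names and the quant_min/quant_max/dtype arguments passed to dequantize_per_tensor are illustrative placeholders, not values taken from the pass:

import torch

def int8_sdpa_score_reference(q_u8, k_u8, q_scale, q_zp, k_scale, k_zp,
                              inv_scale, attn_mask=None):
    # Per-tensor dequantize of the uint8 Q/K activations, as the pattern expects.
    dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
        q_u8, q_scale, q_zp, 0, 255, torch.uint8)
    dk = torch.ops.quantized_decomposed.dequantize_per_tensor(
        k_u8, k_scale, k_zp, 0, 255, torch.uint8)
    # Batched matmul of Q with K^T, scaled by 1/inv_scale, then the optional mask add.
    scores = torch.matmul(dq, dk.transpose(-1, -2)) / inv_scale
    if attn_mask is not None:
        scores = scores + attn_mask
    return scores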
+def _get_int8_sdpa_exp_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_score_pattern = _get_int8_sdpa_score_pattern( + has_mask, is_batch_size_1, is_reduced_type, has_convert + ) + int8_sdpa_exp_basic_pattern = CallFunction( + aten.sub.Tensor, # sub + int8_sdpa_score_pattern, # add + CallFunction( + aten.amax.default, # amax + int8_sdpa_score_pattern, # add + Arg(), + Arg(), + ), + ) + if has_mask: + return CallFunction( + aten.exp.default, # exp + int8_sdpa_exp_basic_pattern, + _users=2, + ) + else: + return CallFunction( + aten.exp.default, # exp + CallFunction( + aten.div.Tensor, # div_tensor + int8_sdpa_exp_basic_pattern, + KeywordArg("inv_scale"), ), - ] - for pattern, replacement, args, workaround, extra_check in int8_candidates: - # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern - # gets serialized to a python file and does not require tracing at runtime. - assert isinstance(workaround, dict) - name = pattern.__name__ - - if len(workaround) >= 1: - pattern = partialize_and_update_signature(pattern, dropout=0.0) - replacement = partialize_and_update_signature( - replacement, dropout=0.0 + _users=2, + ) + + +def _get_int8_sdpa_attn_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_exp_pattern = _get_int8_sdpa_exp_pattern( + has_mask, is_batch_size_1, is_reduced_type, has_convert + ) + int8_sdpa_div_pattern = CallFunction( + aten.div.Tensor, # div_1 + int8_sdpa_exp_pattern, # exp + CallFunction( + aten.sum.dim_IntList, # sum_1 + int8_sdpa_exp_pattern, # exp + Arg(), + Arg(), + ), + ) + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_6 + CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, # quantize_per_tensor_4 + int8_sdpa_div_pattern, + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ) + if is_reduced_type: + if has_mask: + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_softmax_pattern, + Arg(), + ) + else: + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_6 + CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, # quantize_per_tensor_4 + CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_div_pattern, + Arg(), + ), + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), ) - if "dropout" in workaround: - del workaround["dropout"] + if has_convert: + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_softmax_pattern, + Arg(), + ) + return CallFunction( + aten.reshape.default, # view_12 + CallFunction( + aten.expand.default, # expand_2 + int8_sdpa_softmax_pattern, + Arg(), + ), + Arg(), + ) + + +def _get_int8_sdpa_final_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_v_pattern = _get_int8_sdpa_v_pattern(is_batch_size_1, has_convert) + int8_sdpa_attn_pattern = _get_int8_sdpa_attn_pattern( + has_mask, is_batch_size_1, is_reduced_type, has_convert + ) + return CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, # 
quantize_per_tensor_5 + CallFunction( + aten.clone.default, # clone_4 + CallFunction( + aten.permute.default, # permute_7 + CallFunction( + aten.reshape.default, # view_14 + CallFunction( + aten.bmm.default, # bmm_1 + int8_sdpa_attn_pattern, # view_12 + int8_sdpa_v_pattern, # view_13 + ), + Arg(), + ), + Arg(), + ), + memory_format=Arg(), + ), + KeywordArg("o_scale"), + KeywordArg("o_zp"), + Arg(), + Arg(), + Arg(), + ) + + +def _register_int8_sdpa_fp32_lowering(): + # dtype = float32, without attention mask, batch size > 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=False, is_reduced_type=False, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=False, is_reduced_type=False, has_convert=False + ) + ) + + +def _register_int8_sdpa_fp32_mask_lowering(): + # dtype = float32, with attention mask, batch size > 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=False, is_reduced_type=False, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=False, is_reduced_type=False, has_convert=False + ) + ) + - inference_name = name + "_inference" - yield inference_name, { - "search_fn": pattern, - "replace_fn": replacement, - "example_inputs": args, - "trace_fn": fwd_only, - "pass_dicts": patterns, - "extra_check": extra_check, - "scalar_workaround": workaround, - } +def _register_int8_sdpa_fp32_bs1_lowering(): + # dtype = float32, without attention mask, batch size == 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=True, is_reduced_type=False, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=True, is_reduced_type=False, has_convert=False + ) + ) + + +def _register_int8_sdpa_fp32_mask_bs1_lowering(): + # dtype = float32, with attention mask, batch size == 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=True, is_reduced_type=False, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=True, is_reduced_type=False, has_convert=False + ) + ) + + +def _register_int8_sdpa_bf16_lowering(): + # dtype = bfloat16, without attention mask, batch size > 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=False, is_reduced_type=True, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=False, is_reduced_type=True, has_convert=False + ) + ) + + +def _register_int8_sdpa_bf16_mask_lowering(): + # dtype = bfloat16, with attention mask, batch size > 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=False, is_reduced_type=True, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=False, is_reduced_type=True, has_convert=False + ) + ) + + +def _register_int8_sdpa_bf16_bs1_lowering(): + # dtype = bfloat16, without attention mask, batch size == 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=True, is_reduced_type=True, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=True, is_reduced_type=True, 
has_convert=False + ) + ) + + +def _register_int8_sdpa_bf16_mask_bs1_lowering(): + # dtype = bfloat16, with attention mask, batch size == 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=True, is_reduced_type=True, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=True, is_reduced_type=True, has_convert=False + ) + ) +def _register_quantized_sdpa_lowerings(): + _register_int8_sdpa_fp32_lowering() + _register_int8_sdpa_fp32_mask_lowering() + _register_int8_sdpa_fp32_bs1_lowering() + _register_int8_sdpa_fp32_mask_bs1_lowering() + _register_int8_sdpa_bf16_lowering() + _register_int8_sdpa_bf16_mask_lowering() + _register_int8_sdpa_bf16_bs1_lowering() + _register_int8_sdpa_bf16_mask_bs1_lowering() @functools.lru_cache(None) -def _sfdp_init_int8(): - for key, register_replacement_kwargs in _gen_sfdp_patterns_int8(): - register_replacement(**register_replacement_kwargs) - config.joint_custom_pre_pass = patterns.apply +def _sfdp_int8_post_grad_init(): + _register_quantized_sdpa_lowerings() + config.post_grad_custom_pre_pass = patterns.apply From 16f82cf5902eb1f58f427ca4ff6df0267e81e362 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Wed, 26 Feb 2025 00:52:55 -0500 Subject: [PATCH 11/36] update --- test/quantization/test_sfdp_int8_fx_pass.py | 224 -- torchao/csrc/cpu/sdpa.cpp | 2382 ------------------- torchao/quantization/__init__.py | 3 - torchao/quantization/sfdp_int8_fx_pass.py | 541 ----- 4 files changed, 3150 deletions(-) delete mode 100644 test/quantization/test_sfdp_int8_fx_pass.py delete mode 100644 torchao/csrc/cpu/sdpa.cpp delete mode 100644 torchao/quantization/sfdp_int8_fx_pass.py diff --git a/test/quantization/test_sfdp_int8_fx_pass.py b/test/quantization/test_sfdp_int8_fx_pass.py deleted file mode 100644 index 45556bc9a1..0000000000 --- a/test/quantization/test_sfdp_int8_fx_pass.py +++ /dev/null @@ -1,224 +0,0 @@ -import torchao - -import contextlib -import functools -import itertools -import math - -import torch -import torch.utils.checkpoint -from torch._dynamo.debug_utils import aot_graph_input_parser -from torch._dynamo.utils import counters -from torch._inductor import config -from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import run_and_get_code -from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA - -import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq -from torch.export import export_for_training -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( - X86InductorQuantizer, -) -from torchao.quantization.sfdp_int8_fx_pass import _sfdp_int8_init - -class SelfAttnLikeModule(torch.nn.Module): - def __init__( - self, - input_dim, - has_mask, - num_attention_heads=None, - attention_head_size=None, - ) -> None: - super().__init__() - self.input_dim = input_dim - self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False) - self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False) - self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) - self.softmax = torch.nn.Softmax(dim=-1) - assert num_attention_heads is not None - assert attention_head_size is not None - self.num_attention_heads = num_attention_heads - self.attention_head_size = attention_head_size - self.all_head_size = self.num_attention_heads * 
self.attention_head_size - self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size) - self.dropout = torch.nn.Dropout(0) - self.has_mask = has_mask - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(new_x_shape) - return x.permute([0, 2, 1, 3]) - - def forward(self, x, mask): - q = self.q_proj(x) - k = self.k_proj(x) - v = self.v_proj(x) - q = self.transpose_for_scores(q) - k = self.transpose_for_scores(k) - v = self.transpose_for_scores(v) - scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) - if self.has_mask and mask.dtype != scores.dtype: - scores = scores + mask - attention = self.softmax(scores) - attention = self.dropout(attention) - context_layer = torch.matmul(attention, v) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - context_layer = context_layer.view( - context_layer.size()[:-2] + (self.all_head_size,) - ) - return self.dense(context_layer) - -class TestSDPAPatternRewriterTemplate(TestCase): - def _clone_inputs(self, inputs): - def clone(x): - if not isinstance(x, torch.Tensor): - return x - return x.clone() - - return [clone(x) for x in inputs] - - def _check_common( - self, - dot_prod_attention, - args1=None, - contains=True, - atol=1e-5, - has_fuse_pattern=True, - has_dropout=False, - check_train=True, - override_check_equal=False, - dtype=torch.float, - rtol=1.3e-6, - ): - if args1 is None: - tensor_shape = (4, 2, 16, 32) - args1 = [ - torch.randn(tensor_shape, device=self.device, dtype=dtype), - torch.randn(tensor_shape, device=self.device, dtype=dtype), - torch.randn(tensor_shape, device=self.device, dtype=dtype), - ] - else: - args1 = list(args1) - args2 = self._clone_inputs(args1) - - for training in [False, True] if check_train else [False]: - for x in itertools.chain(args1[:], args2[:]): - if isinstance(x, torch.Tensor) and x.is_floating_point(): - x.requires_grad = training - - dropout_arg = [training] if has_dropout else [] - torch.manual_seed(1234) - result1 = dot_prod_attention(*(args1 + dropout_arg)) - - counters.clear() - torch.manual_seed(1234) - compiled_model = torch.compile(dot_prod_attention, fullgraph=True) - result2, source_code = run_and_get_code( - compiled_model, - *(args2 + dropout_arg), - ) - source_code = "\n".join(source_code) - if has_fuse_pattern: - self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1) - if contains: - # many of the patterns get re-expanded in dispatcher - self.assertIn( - "torchao.scaled_dot_product_int8", - source_code, - ) - - # some tests configured with very low dropout where we still want to check equality - if not has_dropout or override_check_equal: - self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6) - - if training: - result1.sum().backward() - result2.sum().backward() - for arg1, arg2 in zip(args1, args2): - if ( - isinstance(arg1, torch.Tensor) - and arg1.is_floating_point() - and (not has_dropout or override_check_equal) - ): - self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) - - iter_n = 20 - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU], - schedule=torch.profiler.schedule(wait=2, warmup=iter_n, active=20), - ) as prof: - for _ in range(iter_n + 22): - r = compiled_model(*(args2 + dropout_arg)) - prof.step() - print(prof.key_averages().table(sort_by="self_cpu_time_total")) - - @skipIfRocm - @config.patch({"freezing": True}) - def _test_sdpa_rewriter_int8_1_to_4(self): - # 
pattern is different for bs=1 - for dtype, has_mask, bs in itertools.product( - [torch.float32, torch.bfloat16], [True, False], [56, 1] - ): - seqlen, numhead, headsize = 197, 16, 64 - # dtype = torch.bfloat16 - # has_mask = True - # is_bs_1 = 0 - # if is_bs_1: - # candidates = [[1, 384, 16, 64], [1, 197, 12, 64]] - # else: - # candidates = [[120, 384, 16, 64], [224, 197, 12, 64]] - # candidates = [[120, 384, 16, 64]] - # for bs, seqlen, numhead, headsize in candidates: - mod = SelfAttnLikeModule( - input_dim=headsize * numhead, - has_mask=has_mask, - num_attention_heads=numhead, - attention_head_size=headsize, - ).eval() - maybe_autocast = ( - torch.cpu.amp.autocast() - if dtype == torch.bfloat16 - else contextlib.nullcontext() - ) - print("\nTEST shape", bs, numhead, seqlen, headsize) - inputs = ( - torch.randn( - (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype - ) - * 10, - torch.randn((bs, 1, 1, seqlen), device=self.device) * 10 - if has_mask - else None, - ) - with torch.no_grad(), maybe_autocast: - _sfdp_int8_init() - quantizer = X86InductorQuantizer() - quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) - quantizer.set_function_type_qconfig( - torch.matmul, quantizer.get_global_quantization_config() - ) - export_model = export_for_training( - mod, - inputs, - ).module() - prepare_model = prepare_pt2e(export_model, quantizer) - prepare_model(*inputs) - convert_model = convert_pt2e(prepare_model) - torch.ao.quantization.move_exported_model_to_eval(convert_model) - self._check_common( - convert_model, args1=inputs, check_train=False, atol=1.0 - ) - -if HAS_CPU: - class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): - device = "cpu" - test_sdpa_rewriter_int8_1_to_4_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_int8_1_to_4 - -if __name__ == "__main__": - if IS_LINUX: - run_tests() diff --git a/torchao/csrc/cpu/sdpa.cpp b/torchao/csrc/cpu/sdpa.cpp deleted file mode 100644 index 4f5ca3fcaf..0000000000 --- a/torchao/csrc/cpu/sdpa.cpp +++ /dev/null @@ -1,2382 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#else -#include -#endif - -#include -#include -#include -#include -#include -#include - -namespace torchao { - -namespace { - -template -struct is_reduced_floating_point: - std::integral_constant || - std::is_same_v> { -}; - -template -constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; - -inline double calculate_scale( - const at::Tensor& query, - double scale) { - return scale == 0.0 ? 1.0 / std::sqrt(query.size(-1)) : scale; -} - -#ifdef CPU_CAPABILITY_AVX512 -// out = val * a + b -// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), -// take b as a scalar pointer. -template -inline void _scale_attn_mask_fusion_kernel( - T1* a, - T2* b, - const int& size, - T1* out, - T1& val) { - const auto vec_size1 = at::vec::Vectorized::size(); - const auto vec_size2 = at::vec::Vectorized::size(); - constexpr int64_t T1_n = - (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v) ? 
2 : 1; - constexpr int64_t T2_n = 1; - auto vec_scale = at::vec::VectorizedN(val); - int64_t i = 0; - for (; i < size - (size % vec_size2); i += vec_size2) { - auto a_n = at::vec::VectorizedN::loadu(a + i); - at::vec::VectorizedN b_n; - if constexpr(is_b_stride_zero) { - b_n = at::vec::VectorizedN((T1)b[0]); - } else { - b_n = at::vec::VectorizedN::loadu(b + i); - } - auto b_n_convert = at::vec::convert(b_n); - auto res = a_n * vec_scale + b_n_convert; - res.store(out + i); - } - for (; i < size; i++) { - auto tmp0 = a[i]; - T1 tmp1; - if constexpr(is_b_stride_zero) { - tmp1 = (T1)b[0]; - } else { - tmp1 = (T1)b[i]; - } - out[i] = tmp0 * val + tmp1; - } -} - -// 1) out = exp(a - val) -// 2) val = sum(out) -template -inline void _exp_reduce_sum_fusion_kernel( - T1* a, - const int& size, - T2* out, - T1& val) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_max = at::vec::Vectorized(val); - T1 tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = tmp0 - vec_max; - auto tmp2 = tmp1.exp_u20(); - vec_tmp_sum += tmp2; - _store(out + i, tmp2); - } - tmp_sum = at::vec::vec_reduce_all( - [](at::vec::Vectorized& x, at::vec::Vectorized& y) { - return x + y; - }, - vec_tmp_sum); - for (long i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = tmp0 - val; - auto tmp2 = exp(tmp1); - tmp_sum += tmp2; - out[i] = tmp2; - } - val = tmp_sum; -} - -// 1) out = a * scale -// 2) max = max(out) -template -inline void _mul_reduce_max_fusion_kernel( - const scalar_t* a, - const scalar_t& scale, - const int& size, - scalar_t* out, - scalar_t& max) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_scale = at::vec::Vectorized(scale); - scalar_t tmp_max = -std::numeric_limits::infinity(); - auto vec_tmp_max = at::vec::Vectorized(tmp_max); - for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = tmp0 * vec_scale; - vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); - _store(out + i, tmp1); - } - for (long i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = tmp0 * scale; - tmp_max = std::max(tmp_max, tmp1); - out[i] = tmp1; - } - max = std::max(tmp_max, vec_tmp_max.reduce_max()); -} - -template -static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { - TORCH_CHECK(ptr2 == nullptr); - return ptr; -} - -template , int> = 0> -static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { - return ptr2; -} - -template -inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { - using Vec = at::vec::Vectorized; - Vec data_vec = Vec(val); - int64_t d = 0; - for (; d < size - (size % Vec::size()); d += Vec::size()) { - data_vec.store(data + d); - } - #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) - # pragma unroll - #endif - for (; d < size; d++) { - data[d] = val; - } -} - -void reshape_attn_mask_to_4d( - at::Tensor& attn_mask, - int64_t batchSize, - int64_t num_head, - int64_t qSize, - int64_t kvSize) { - // Support mask shapes: - // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) - // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) - // Guaranteed in check_attn_mask_shape - int64_t attn_mask_size_0 = 1; - int64_t attn_mask_size_1 = 1; - if (attn_mask.dim() == 4) { - if (attn_mask.size(0) == batchSize) { - attn_mask_size_0 = batchSize; - } - if 
(attn_mask.size(1) == num_head) { - attn_mask_size_1 = num_head; - } - } - attn_mask = attn_mask - .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)}) - .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); -} - -// TODO: Use at::native::_store instead when it supports Half. -template -inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { - src.store(dst, size); -} - -template -inline typename std::enable_if_t, void> -_store(scalar_t* dst, at::vec::Vectorized src) { - auto res = at::vec::convert_from_float(src, src); - res.store(dst, at::vec::Vectorized::size()); -} - -template -inline typename std::enable_if_t || std::is_same_v, void> -_store(scalar_t* dst, at::vec::Vectorized src) { - auto res = at::vec::convert(src); - res.store(dst, at::vec::Vectorized::size()); -} - -template -inline void pad_row_zero( - scalar_t* value_ptr, - scalar_t* padding_value_ptr, - int rows, - int cols, - int ldi) { - auto vec_size = at::vec::Vectorized::size(); - int i = 0; - for (; i < rows - 1; i++) { - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(padding_value_ptr + i * cols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - j); - vec_v.store(padding_value_ptr + i * cols + j, cols - j); - } - } - - // zero padding - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = at::vec::Vectorized(0); - vec_v.store(padding_value_ptr + i * cols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized(0); - vec_v.store(padding_value_ptr + i * cols + j, cols - j); - } -} - -template -inline void pad_row_128_padding( - scalar_t* value_ptr, - scalar_t* padding_value_ptr, - int rows, - int cols, - int ldi, - int padding) { - auto vec_size = at::vec::Vectorized::size(); - int i = 0; - for (; i < rows - padding; i++) { - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(padding_value_ptr + i * cols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - j); - vec_v.store(padding_value_ptr + i * cols + j, cols - j); - } - } - - // 128 padding - for (; i < rows; i++) { - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = at::vec::Vectorized(128); - vec_v.store(padding_value_ptr + i * cols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized(128); - vec_v.store(padding_value_ptr + i * cols + j, cols - j); - } - } -} - -template -inline void pad_col_zero( - scalar_t* value_ptr, - scalar_t* padding_value_ptr, - int rows, - int cols, - int ldi) { - auto vec_size = at::vec::Vectorized::size(); - for (int i = 0; i < rows; i++) { - int j = 0; - for (; j < cols - 1 - ((cols - 1) % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(padding_value_ptr + i * cols + j); - } - if (j < cols - 1) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - 1 - j); - vec_v.store(padding_value_ptr + i * cols + j, cols - 1 - j); - *(padding_value_ptr + i * cols + cols - 1) = scalar_t(0); - } - } -} - -template -inline void pad_col_zero_padding( - scalar_t* value_ptr, - scalar_t* padding_value_ptr, - int rows, - int cols, - int ldi, - int padding) { - auto vec_size = at::vec::Vectorized::size(); - 
for (int i = 0; i < rows; i++) { - int j = 0; - for (; j < cols - padding - ((cols - padding) % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(padding_value_ptr + i * cols + j); - } - if (j < cols - padding) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - padding - j); - vec_v.store(padding_value_ptr + i * cols + j, cols - padding - j); - *(padding_value_ptr + i * cols + cols - padding) = scalar_t(0); - } - } -} - -/* -1. dequant -2. add mask -3. max reduce for softmax -*/ -template -inline void _dequant_mask_max_fusion_kernel( - const int32_t* in, - const mask_t* mask_ptr, - const int32_t* sum_a_ptr, - const int32_t* sum_b_ptr, - const int& M, - const int& N, - const int& ldi, - const int& ldm, // leading dimension mask - const int& ldo, - const int32_t& beta, // zp_a*zp_b*k - const float& alpha, // scale_a*scale_b*scale_sdpa - float* out, - float* sfm_max_ptr) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto vec_beta = at::vec::Vectorized(beta); - auto vec_alpha = at::vec::Vectorized(alpha); - for (long row = 0; row < M; row += 1) { - auto sum_a = sum_a_ptr[row]; - auto vec_sum_a = at::vec::Vectorized(sum_a); - const int32_t* tmp_in = in + row * ldi; - float* tmp_out = out + row * ldo; - const mask_t* mask_data_ptr = mask_ptr + row * ldm; - float tmp_max = -std::numeric_limits::infinity(); - auto vec_tmp_max = at::vec::Vectorized(tmp_max); - for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { - auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_sum_b; - auto tmp2 = tmp1 - vec_sum_a; - auto tmp3 = tmp2 + vec_beta; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col); - auto tmp7 = at::vec::convert(tmp6); - auto tmp8 = tmp5 + tmp7; - vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp8); - _store(tmp_out + col, tmp8); - } - tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); - for (long col = vec_size * (N / vec_size); col < N; col++) { - auto sum_b = sum_b_ptr[col]; - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sum_b; - auto tmp2 = tmp1 - sum_a; - auto tmp3 = tmp2 + beta; - auto tmp4 = (float) tmp3; - auto tmp5 = tmp4 * alpha; - auto tmp6 = mask_data_ptr[col]; - auto tmp7 = (float) tmp6; - auto tmp8 = tmp5 + tmp7; - tmp_max = std::max(tmp_max, tmp8); - tmp_out[col] = tmp8; - } - sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); - } -} - -/* -1. dequant -2. 
max reduce for softmax -*/ -inline void _dequant_max_fusion_kernel( - const int32_t* in, - const int32_t* sum_a_ptr, - const int32_t* sum_b_ptr, - const int& M, - const int& N, - const int& ldi, - const int& ldo, - const int32_t& beta, // zp_a*zp_b*k - const float& alpha, // scale_a*scale_b*scale_sdpa - float* out, - float* sfm_max_ptr) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto vec_beta = at::vec::Vectorized(beta); - auto vec_alpha = at::vec::Vectorized(alpha); - for (long row = 0; row < M; row += 1) { - auto sum_a = sum_a_ptr[row]; - auto vec_sum_a = at::vec::Vectorized(sum_a); - const int32_t* tmp_in = in + row * ldi; - float* tmp_out = out + row * ldo; - float tmp_max = -std::numeric_limits::infinity(); - auto vec_tmp_max = at::vec::Vectorized(tmp_max); - for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { - auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_sum_b; - auto tmp2 = tmp1 - vec_sum_a; - auto tmp3 = tmp2 + vec_beta; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp5); - _store(tmp_out + col, tmp5); - } - tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); - for (long col = vec_size * (N / vec_size); col < N; col++) { - auto sum_b = sum_b_ptr[col]; - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sum_b; - auto tmp2 = tmp1 - sum_a; - auto tmp3 = tmp2 + beta; - auto tmp4 = (float) tmp3; - auto tmp5 = tmp4 * alpha; - tmp_max = std::max(tmp_max, tmp5); - tmp_out[col] = tmp5; - } - sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); - } -} - -/* -1. Softmax: sub max, exp, sum reduce, div sum -2. quant -3. sum for attention -*/ -template -inline void _sub_exp_sum_div_quant_sum_fusion_kernel( - const float* in, - const int64_t& M, - const int64_t& N_step, - const int64_t& NSlice, - const int& ldi, - const int& ldo, - const int& kvSize, - const int& rndkvSplitSize, - const int& av_gemm_K, - const int32_t& beta1, // zp_a - const int32_t& beta2, // zp_b - const float& alpha, // scale_a - float* local, - scalar_t* out, - float* sfm_max_ptr, - float* sfm_sum_ptr, - int32_t* sum_a_ptr) { - const int32_t vec_size = at::vec::Vectorized::size(); - float min_val = 0; - float max_val = 255; - auto vec_min_val = at::vec::Vectorized(min_val); - auto vec_max_val = at::vec::Vectorized(max_val); - scalar_t zero = 0; - auto vec_zero = at::vec::Vectorized(zero); - float beta1_float = (float) beta1; - auto vec_beta1 = at::vec::Vectorized(beta1_float); - for (int64_t row = 0; row < M; ++row) { - auto sfm_max = sfm_max_ptr[row]; - auto vec_max = at::vec::Vectorized(sfm_max); - // sub max, exp, sum reduce - const float* qk_block_data = in + row * rndkvSplitSize; - for (int64_t l = 0; l < NSlice; l ++) { - int64_t n = l * N_step; - int64_t kvBlockSize = std::min(N_step, kvSize - n); - const float* tmp_in = qk_block_data + l * ldi; - float tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - float* tmp_out = local + n; - for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_max; - auto tmp2 = tmp1.exp_u20(); - vec_tmp_sum += tmp2; - _store(tmp_out + col, tmp2); - } - tmp_sum += vec_tmp_sum.reduce_add(); - for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sfm_max; - auto tmp2 = exp(tmp1); - tmp_sum += 
tmp2; - tmp_out[col] = tmp2; - } - sfm_sum_ptr[row] += tmp_sum; - } - // div sum, sum for attention - auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; - auto vec_sum_scale = at::vec::Vectorized(sum_scale); - scalar_t* qk_reduced_block_data = out + row * av_gemm_K; - for (int64_t l = 0; l < NSlice; l ++) { - int64_t n = l * N_step; - int64_t kvBlockSize = std::min(N_step, kvSize - n); - int32_t tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - float* tmp_in = local + n; - scalar_t* tmp_out = qk_reduced_block_data + l * ldo; - for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 * vec_sum_scale; - auto tmp2 = tmp1.round(); - auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::maximum(tmp3, vec_min_val); - auto tmp5 = at::vec::minimum(tmp4, vec_max_val); - _store(tmp_out + col, tmp5); - auto tmp6 = at::vec::convert(tmp5); - vec_tmp_sum += tmp6; - } - tmp_sum += vec_tmp_sum.reduce_add(); - for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 * sum_scale; - auto tmp2 = std::nearbyint(tmp1); - auto tmp3 = tmp2 + beta1_float; - auto tmp4 = std::max(tmp3, min_val); - auto tmp5 = std::min(tmp4, max_val); - tmp_out[col] = tmp5; - auto tmp6 = (int32_t) tmp5; - tmp_sum += tmp6; - } - sum_a_ptr[row] += tmp_sum * beta2; - // set zero - for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { - _store(tmp_out + col, vec_zero); - } - for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { - tmp_out[col] = zero; - } - } - } -} - -template -inline void _sub_exp_sum_div_quant_fusion_kernel( - const float* in, - const int64_t& M, - const int64_t& N_step, - const int64_t& NSlice, - const int& ldi, - const int& ldo, - const int& kvSize, - const int& rndkvSplitSize, - const int& av_gemm_K, - const int32_t& beta1, // zp_a - const float& alpha, // scale_a - float* local, - scalar_t* out, - float* sfm_max_ptr, - float* sfm_sum_ptr) { - const int32_t vec_size = at::vec::Vectorized::size(); - float min_val = 0; - float max_val = 255; - auto vec_min_val = at::vec::Vectorized(min_val); - auto vec_max_val = at::vec::Vectorized(max_val); - scalar_t zero = 0; - auto vec_zero = at::vec::Vectorized(zero); - float beta1_float = (float) beta1; - auto vec_beta1 = at::vec::Vectorized(beta1_float); - for (int64_t row = 0; row < M; ++row) { - auto sfm_max = sfm_max_ptr[row]; - auto vec_max = at::vec::Vectorized(sfm_max); - // sub max, exp, sum reduce - const float* qk_block_data = in + row * rndkvSplitSize; - for (int64_t l = 0; l < NSlice; l ++) { - int64_t n = l * N_step; - int64_t kvBlockSize = std::min(N_step, kvSize - n); - const float* tmp_in = qk_block_data + l * ldi; - float tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - float* tmp_out = local + n; - for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_max; - auto tmp2 = tmp1.exp_u20(); - vec_tmp_sum += tmp2; - _store(tmp_out + col, tmp2); - } - tmp_sum += vec_tmp_sum.reduce_add(); - for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sfm_max; - auto tmp2 = exp(tmp1); - tmp_sum += tmp2; - tmp_out[col] = tmp2; - } - sfm_sum_ptr[row] += tmp_sum; - } - // div sum, sum for attention - auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; - 
auto vec_sum_scale = at::vec::Vectorized(sum_scale); - scalar_t* qk_reduced_block_data = out + row * av_gemm_K; - for (int64_t l = 0; l < NSlice; l ++) { - int64_t n = l * N_step; - int64_t kvBlockSize = std::min(N_step, kvSize - n); - float* tmp_in = local + n; - scalar_t* tmp_out = qk_reduced_block_data + l * ldo; - for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 * vec_sum_scale; - auto tmp2 = tmp1.round(); - auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::maximum(tmp3, vec_min_val); - auto tmp5 = at::vec::minimum(tmp4, vec_max_val); - _store(tmp_out + col, tmp5); - } - for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 * sum_scale; - auto tmp2 = std::nearbyint(tmp1); - auto tmp3 = tmp2 + beta1_float; - auto tmp4 = std::max(tmp3, min_val); - auto tmp5 = std::min(tmp4, max_val); - tmp_out[col] = tmp5; - } - // set zero - for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { - _store(tmp_out + col, vec_zero); - } - for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { - tmp_out[col] = zero; - } - } - } -} - -/* -1. dequant -2. quant -*/ -template -inline void _dequant_quant_fusion_kernel( - const int32_t* in, - const int32_t* sum_a_ptr, - const int32_t* sum_b_ptr, - const int& M, - const int& N, - const int& ldi, - const int& ldo, - const int32_t& beta1, // zp_a*zp_b*k - const int32_t& beta2, // zp_c - const float& alpha, // scale_a*scale_b/scale_c - scalar_t* out) { - const int32_t vec_size = at::vec::Vectorized::size(); - float min_val = 0; - float max_val = 255; - auto vec_min_val = at::vec::Vectorized(min_val); - auto vec_max_val = at::vec::Vectorized(max_val); - auto vec_beta1 = at::vec::Vectorized(beta1); - auto vec_alpha = at::vec::Vectorized(alpha); - float beta2_float = (float) beta2; - auto vec_beta2 = at::vec::Vectorized(beta2_float); - for (long row = 0; row < M; row += 1) { - auto sum_a = sum_a_ptr[row]; - auto vec_sum_a = at::vec::Vectorized(sum_a); - const int32_t* tmp_in = in + row * ldi; - scalar_t* tmp_out = out + row * ldo; - for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { - auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); - auto tmp1 = tmp0 - vec_sum_b; - auto tmp2 = tmp1 - vec_sum_a; - auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::convert(tmp3); - auto tmp5 = tmp4 * vec_alpha; - auto tmp6 = tmp5.round(); - auto tmp7 = tmp6 + vec_beta2; - auto tmp8 = at::vec::maximum(tmp7, vec_min_val); - auto tmp9 = at::vec::minimum(tmp8, vec_max_val); - _store(tmp_out + col, tmp9); - } - for (long col = vec_size * (N / vec_size); col < N; col++) { - auto sum_b = sum_b_ptr[col]; - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sum_b; - auto tmp2 = tmp1 - sum_a; - auto tmp3 = tmp2 + beta1; - auto tmp4 = (float) tmp3; - auto tmp5 = tmp4 * alpha; - auto tmp6 = std::nearbyint(tmp5); - auto tmp7 = tmp6 + beta2_float; - auto tmp8 = std::max(tmp7, min_val); - auto tmp9 = std::min(tmp8, max_val); - tmp_out[col] = tmp9; - } - } -} - -template -inline void _int_sum_b_contiguous_kernel_helper( - const scalar_t* in, - int32_t* out, - const int& N, - const int32_t& scale) { - const int32_t vec_size = at::vec::Vectorized::size(); - int32_t tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - for (long i = 0; i < vec_size * (N / vec_size); i += 
vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(in + i); - auto tmp1 = at::vec::convert(tmp0); - vec_tmp_sum = vec_tmp_sum + tmp1; - } - tmp_sum += vec_tmp_sum.reduce_add(); - for (long i = vec_size * (N / vec_size); i < N; i++) { - tmp_sum += static_cast(in[i]); - } - out[0] = tmp_sum * scale; -} - -template -inline void _int_sum_b_contiguous_kernel( - const scalar_t* in, - int32_t* out, - const int& M, - const int& N, - const int& ld, - const int32_t& scale) { - for (long r = 0; r < M; r += 1) { - _int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); - } -} - -template -inline void _int_sum_a_contiguous_kernel( - const scalar_t* in, - int32_t* out, - const int& M, - const int& N, - const int& ld, - const int32_t& scale) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto vec_scale = at::vec::Vectorized(scale); - // initialization with 0 - int32_t zero = 0; - auto vec_zero = at::vec::Vectorized(zero); - for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { - _store(out + i, vec_zero); - } - for (long i = vec_size * (M / vec_size); i < M; i++) { - out[i] = zero; - } - // sum - for (long j = 0; j < N; j++) { - const scalar_t* tmp_in = in + j * ld; - for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + i); - auto tmp1 = at::vec::Vectorized::loadu(out + i); - auto tmp2 = at::vec::convert(tmp0); - auto tmp3 = tmp1 + tmp2; - _store(out + i, tmp3); - } - for (long i = vec_size * (M / vec_size); i < M; i++) { - auto tmp0 = tmp_in[i]; - auto tmp1 = out[i]; - auto tmp2 = static_cast(tmp0); - auto tmp3 = tmp1 + tmp2; - out[i] = tmp3; - } - } - // scale - for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(out + i); - auto tmp1 = tmp0 * vec_scale; - _store(out + i, tmp1); - } - for (long i = vec_size * (M / vec_size); i < M; i++) { - auto tmp0 = out[i]; - auto tmp1 = tmp0 * scale; - out[i] = tmp1; - } -} - -inline void do_convert_u8_s8( - unsigned char* src, - signed char* dst, - int64_t in_rows, - int64_t in_cols, - int64_t ldi, - int64_t ldo) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_128 = at::vec::Vectorized(128); - for (int64_t r = 0; r < in_rows; r++) { - const unsigned char* tmp_src = src + r * ldi; - signed char* tmp_dst = dst + r * ldo; - for (int64_t c = 0; c < vec_size * (in_cols / vec_size); c += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_src + c, vec_size); - auto tmp1 = at::vec::convert(tmp0); - auto tmp2 = tmp1 - vec_128; - auto tmp3 = at::vec::convert(tmp2); - _store(tmp_dst + c, tmp3, vec_size); - } - for (int64_t c = vec_size * (in_cols / vec_size); c < in_cols; c++) { - auto tmp0 = tmp_src[c]; - auto tmp1 = (int16_t) tmp0; - auto tmp2 = tmp1 - 128; - auto tmp3 = (signed char) tmp2; - tmp_dst[c] = tmp3; - } - } -} - -template -inline void do_transpose( - scalar_t* src, - scalar_t* dst, - int64_t in_rows, - int64_t in_cols, - int64_t ldi, - int64_t ldo) { - for (int64_t r=0; r -inline void do_copy( - scalar_t* src, - scalar_t* dst, - int64_t in_rows, - int64_t in_cols, - int64_t ldi, - int64_t ldo) { - for (int64_t r=0; r -inline void pad_remain_row_col( - scalar_t* value_ptr, - int rows, - int cols, - int prows, - int pcols, - int ldi, - scalar_t pad_val=0) { - auto psize = pcols - cols; - if (psize == 0 && prows == rows) { - return; - } - const int32_t vec_size = at::vec::Vectorized::size(); - auto pad = at::vec::Vectorized(pad_val); - if (psize > 0) { - for (int i = 0; i < rows; i++) { - int j 
= 0; - for (; j < psize - (psize % vec_size); j += vec_size) { - pad.store(value_ptr + i * ldi + cols + j); - } - if (j < psize) { - pad.store(value_ptr + i * ldi + cols + j, psize - j); - } - } - } - - for (int i = rows; i < prows; i++) { - int j = 0; - for (; j < pcols - (pcols % vec_size); j += vec_size) { - pad.store(value_ptr + i * ldi + j); - } - if (j < pcols) { - pad.store(value_ptr + i * ldi + j, pcols - j); - } - } -} - -template -inline void copy_value_with_pad( - scalar_t* value_ptr, - scalar_t* dst_ptr, - int rows, - int cols, - int prows, - int pcols, - int ldi, - scalar_t pad_val=0) { - const int32_t vec_size = at::vec::Vectorized::size(); - auto pad = at::vec::Vectorized(pad_val); - int i = 0; - for (; i < rows; i++) { - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(dst_ptr + i * pcols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - j); - vec_v.store(dst_ptr + i * pcols + j, cols - j); - } - - // col padding - auto psize = pcols - cols; - if (psize > 0) { - int pj = 0; - for (; pj < psize - (psize % vec_size); pj += vec_size) { - pad.store(dst_ptr + i * pcols + cols + pj); - } - if (pj < psize) { - pad.store(dst_ptr + i * pcols + cols + pj, psize - pj); - } - } - } - - // row padding - for (; i < prows; i++) { - int j = 0; - for (; j < pcols - (pcols % vec_size); j += vec_size) { - pad.store(dst_ptr + i * pcols + j); - } - if (j < pcols) { - pad.store(dst_ptr + i * pcols + j, pcols - j); - } - - } - -} - -// UINT8 - one parallel loop with u8u8s32 GEMM -template -inline typename std::enable_if_t, void> -sdpa_int8_kernel_one_loop_impl( - const at::Tensor& output, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - double dropout_p, - bool is_causal, - at::Tensor& attention_mask, - double scale, - int32_t q_zp, - float q_scale, - int32_t k_zp, - float k_scale, - int32_t v_zp, - float v_scale, - int32_t a_zp, - float a_scale, - int32_t o_zp, - float o_scale) { - // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) - // Key (Batch x KV_seq_len x Num_heads x Dim_per_head) - // Value (Batch x KV_seq_len x Num_heads x Dim_per_head) - - const auto accumulate_dtype = at::kFloat; - - using accum_t = float; - using Vec = at::vec::Vectorized; - accum_t scaling_factor = calculate_scale(query, scale); - int block_64 = 64; - // Sizes - TORCH_CHECK( - (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), - "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); - TORCH_CHECK( - kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); - - int64_t batchSize = query.size(0); - int64_t qSize = query.size(1); - int64_t kvSize = value.size(1); - int64_t num_head = query.size(2); - int64_t headSize = query.size(3); - - - bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); - if (has_attn_mask) { - reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); - } - - // Strides - int64_t qStrideB = query.stride(0); - int64_t qStrideM = query.stride(1); - int64_t qStrideH = query.stride(2); - int64_t kStrideB = key.stride(0); - int64_t kStrideN = key.stride(1); - int64_t kStrideH = key.stride(2); - int64_t vStrideB = value.stride(0); - int64_t vStrideN = value.stride(1); - int64_t vStrideH = value.stride(2); - int64_t oStrideB = output.stride(0); - int64_t oStrideM = output.stride(1); - int64_t oStrideH = 
output.stride(2); - int64_t mStrideB = - (attention_mask.defined() && attention_mask.size(0) > 1) - ? attention_mask.stride(0) - : 0; - int64_t mStrideH = - (attention_mask.defined() && attention_mask.size(1) > 1) - ? attention_mask.stride(1) - : 0; - int64_t mStrideM = - (attention_mask.defined() && attention_mask.size(2) > 1) - ? attention_mask.stride(2) - : 0; - int64_t mStrideN = - (attention_mask.defined() && attention_mask.size(3) > 1) - ? attention_mask.stride(3) - : 0; - - int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; - int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; - int64_t qSlice = (qSize - 1) / qSplitSize + 1; - int64_t qTail = (qSize - 1) % qSplitSize + 1; - int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; - int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; - int64_t num_thread = at::get_num_threads(); - - int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; - // one of 16, 32, 48, 64 - auto select_tail_tail_block_size = [](int64_t size) -> int64_t { - if (size == 0) { - return 0; - } else if (size <= 16) { - return 16; - } else if (size <= 32) { - return 32; - } else if (size <= 48) { - return 48; - } else { - return 64; - } - }; - int64_t kv_tail_tail_block_size = select_tail_tail_block_size(kvTail % block_64); - int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; - int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; - int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; - - bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; - int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; - // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 - int av_gemm_K = kvSplitSize + av_gemm_K_padding; - bool av_gemm_K_tail_mul4 = kvTail % 4 == 0; - int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; - int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; - - auto u8_dt = at::ScalarType::Byte; - auto s8_dt = at::ScalarType::Int; - auto f32_dt = at::ScalarType::Float; - - // Data ptrs - scalar_t* q_data = query.data_ptr(); - scalar_t* k_data = key.data_ptr(); - scalar_t* v_data = value.data_ptr(); - mask_t* mask_data = attention_mask.defined() - ? attention_mask.data_ptr() - : nullptr; - scalar_t* out_data = output.data_ptr(); - - // Create tpp kernels for Query @ Key - bool headSize_mul4 = headSize % 4 == 0; - // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 - int qk_gemm_K_padding = headSize_mul4 ? 
0 : 4 - headSize % 4; - int qk_gemm_K = headSize + qk_gemm_K_padding; - - int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; - int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; - - int64_t total_size_uint8_per_thread = - /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + - /* qk_local */ kvSlice * av_gemm_K * 4 + - /* qk_reduce */ kvSlice * qk_reduce_strideL + - /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + - /* dst_s32 */ qSplitSize * rndHeadSize * 4 + - /* softmax_sum */ qSplitSize * 4 + - /* query_sum */ qSplitSize * 4 + - /* attention_sum */ qSplitSize * 4 + - /* softmax max */ qSplitSize * 4 + - /* query_padding_data */ qSplitSize * qk_gemm_K + - /* key_sum */ kvSize * 4 + - /* value_sum */ headSize * 4 + - /* key_t_reorder */ qk_gemm_K * rndkvSize + - /* value_t_reorder */ kvSlice * v_reorder_strideL; - - at::Tensor total_buf = at::empty( - {num_thread, total_size_uint8_per_thread}, - query.options()).zero_(); - scalar_t* total_buf_data = total_buf.data_ptr(); - - at::parallel_for( - 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head); - int ompIdx = at::get_thread_num(); - scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; - int32_t offset = 0; - accum_t* qk_data = reinterpret_cast(total_buf_ptr); - offset += kvSlice * qSplitSize * rndkvSplitSize * 4; - accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * av_gemm_K * 4; - scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * qk_reduce_strideL; - int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndkvSplitSize * 4; - int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndHeadSize * 4; - accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * qk_gemm_K; - - int32_t* k_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += kvSize * 4; - int32_t* v_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += headSize * 4; - scalar_t* key_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qk_gemm_K * rndkvSize; - scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); - - uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; - - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - - // sum k and v - if (q_zp == 0) { - fill_stub(k_sum_ptr, static_cast(0), kvSize); - } else { - _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, - k_sum_ptr, - kvSize, headSize, kStrideN, q_zp); - } - if (a_zp == 0) { - fill_stub(v_sum_ptr, static_cast(0), headSize); - } else { - _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, - v_sum_ptr, - headSize, kvSize, vStrideN, a_zp); - } - - // pack - for (int64_t n = 0; n < kvSize; n += kvSplitSize) { - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * 
kStrideN, - B_blocked_xform_u8, - std::min(int(kvSplitSize - b), block_64), - headSize, - kStrideN, - block_64); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_64, - qk_gemm_K, - block_64, - block_64 - ); - } - // Pack - at::native::cpublas::pack( - qk_gemm_K, // K - block_64, // N - block_64, // ld_in - block_64, // ld_out - u8_dt, // dt_in - u8_dt, // dt_out - B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); - } - } else { - // tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < rndkvTail) { - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - std::min(kvTail - b, block_size), - headSize, - kStrideN, - block_size); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_size, - qk_gemm_K, - block_size, - block_size - ); - } - // Pack - at::native::cpublas::pack( - qk_gemm_K, - block_size, - block_size, - block_size, - u8_dt, - u8_dt, - B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - // split headSize to block_64, block_64, block_64 ... - // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] - for (int64_t b = 0; b < headSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); - } - } - } - - // sdpa core - for (int64_t k = 0; k < qSlice; k++) { - int64_t m = k * qSplitSize; - int64_t qBlockSize = std::min(qSplitSize, qSize - m); - // Initialize sum and max - fill_stub( - sfm_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - a_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); - int64_t num_keys = - is_causal ? 
std::min(m + qBlockSize, kvSize) : kvSize; - copy_value_with_pad( - q_data + i * qStrideB + j * qStrideH + m * qStrideM, - query_t_padding_ptr, - qBlockSize, - headSize, - qBlockSize, - qk_gemm_K, - qStrideM); - - if (k_zp != 0) { - _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, - q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); - } else { - fill_stub( - q_sum_ptr, static_cast(0), qSplitSize); - } - const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; - for (int64_t l = 0; l < rkvSlice; l++) { - int64_t n = l * kvSplitSize; - int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - // Calculate sums for dequant compensation item - if (qBlockSize == qSplitSize) { - // q main - if (n + kvSplitSize < kvSize) { - // k main - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvSplitSize, //ldc, - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - } - } else { - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - at::native::cpublas::brgemm( - qSplitSize, block_size, qk_gemm_K, - qk_gemm_K, // lda - block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - } - } else { - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - at::native::cpublas::brgemm( - qTail, block_64, qk_gemm_K, - qk_gemm_K,// lda - block_64, //ldb - rndkvSplitSize, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - } - } else { - // k tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - at::native::cpublas::brgemm( - qTail, block_size, qk_gemm_K, - qk_gemm_K, // lda - block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - } - } - - // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 - int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; - accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; - if (has_attn_mask) { - mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); - _dequant_mask_max_fusion_kernel( - qk_s32_data, //in - mask_data_offset, //mask_ptr - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvBlockSize, //ldi - mStrideM, //ldm - rndkvSplitSize,//kvBlockSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } else { - _dequant_max_fusion_kernel( - qk_s32_data, //in - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvBlockSize, //ldi - rndkvSplitSize,//kvBlockSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } - } - // sub max, exp, sum reduce, div sum for softmax - // and quant - // and sum for attention - if (v_zp == 0) { - _sub_exp_sum_div_quant_fusion_kernel( - qk_data, //in - qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlices - qSplitSize * rndkvSplitSize, //ldi - qk_reduce_strideL, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr //sfm_sum_ptr - ); - } else { - _sub_exp_sum_div_quant_sum_fusion_kernel( - qk_data, //in - qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlice - qSplitSize * rndkvSplitSize, //ldi - qk_reduce_strideL, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - v_zp, // zp_b=beta2 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr, //sfm_sum_ptr - a_sum_ptr //a_sum_ptr - ); - } - // Calculate Softmax(q @ k.T) @ v - for (int64_t b = 0; b < headSize; b += block_64) { - auto value_reorder_b = value_reorder_ptr + b * av_gemm_K; - auto dst_s32_b = dst_s32_data + b; - for (int64_t s = 0; s < kvSlice; s++) { - at::native::cpublas::brgemm( - qSplitSize, block_64, av_gemm_K, - av_gemm_K, // lda - rndHeadSize, //block_64, //ldb - rndHeadSize, //ldc - s != 0, - qk_reduced_data + s * qk_reduce_strideL, - value_reorder_b + s * v_reorder_strideL, - dst_s32_b); - } - } - - // After the last gemm, - // do dequant compensation, quant and convert from s32 to int8 - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, //sum_a_ptr - v_sum_ptr, //sum_b_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); - } - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head); - } - }); - // Once all computations are done, need to release HW context. 
- at::native::cpublas::brgemm_release(); -} - -// UINT8 - several parallel loops with u8u8s32 GEMM -template -inline typename std::enable_if_t, void> -sdpa_int8_kernel_several_loops_impl( - const at::Tensor& output, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - double dropout_p, - bool is_causal, - at::Tensor& attention_mask, - double scale, - int32_t q_zp, - float q_scale, - int32_t k_zp, - float k_scale, - int32_t v_zp, - float v_scale, - int32_t a_zp, - float a_scale, - int32_t o_zp, - float o_scale) { - // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) - // Key (Batch x KV_seq_len x Num_heads x Dim_per_head) - // Value (Batch x KV_seq_len x Num_heads x Dim_per_head) - - const auto accumulate_dtype = at::kFloat; - - using accum_t = float; - using Vec = at::vec::Vectorized; - accum_t scaling_factor = calculate_scale(query, scale); - int block_64 = 64; - // Sizes - TORCH_CHECK( - (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), - "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); - TORCH_CHECK( - kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); - - int64_t batchSize = query.size(0); - int64_t qSize = query.size(1); - int64_t kvSize = value.size(1); - int64_t num_head = query.size(2); - int64_t headSize = query.size(3); - - - bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); - if (has_attn_mask) { - reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); - } - - // Strides - int64_t qStrideB = query.stride(0); - int64_t qStrideM = query.stride(1); - int64_t qStrideH = query.stride(2); - int64_t kStrideB = key.stride(0); - int64_t kStrideN = key.stride(1); - int64_t kStrideH = key.stride(2); - int64_t vStrideB = value.stride(0); - int64_t vStrideN = value.stride(1); - int64_t vStrideH = value.stride(2); - int64_t oStrideB = output.stride(0); - int64_t oStrideM = output.stride(1); - int64_t oStrideH = output.stride(2); - int64_t mStrideB = - (attention_mask.defined() && attention_mask.size(0) > 1) - ? attention_mask.stride(0) - : 0; - int64_t mStrideH = - (attention_mask.defined() && attention_mask.size(1) > 1) - ? attention_mask.stride(1) - : 0; - int64_t mStrideM = - (attention_mask.defined() && attention_mask.size(2) > 1) - ? attention_mask.stride(2) - : 0; - int64_t mStrideN = - (attention_mask.defined() && attention_mask.size(3) > 1) - ? attention_mask.stride(3) - : 0; - - int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; - int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; - int64_t qSlice = (qSize - 1) / qSplitSize + 1; - int64_t qTail = (qSize - 1) % qSplitSize + 1; - int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; - int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; - int64_t num_thread = at::get_num_threads(); - - int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; - // one of 16, 32, 48, 64 - auto select_tail_tail_block_size = [](int64_t size) -> int64_t { - if (size == 0) { - return 0; - } else if (size <= 16) { - return 16; - } else if (size <= 32) { - return 32; - } else if (size <= 48) { - return 48; - } else { - return 64; - } - }; - int64_t kv_tail_tail_block_size = select_tail_tail_block_size(kvTail % block_64); - int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; - int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; - int64_t rndkvSize = kv_split_size > kvSize ? 
rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; - - bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; - int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; - // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 - int av_gemm_K = kvSplitSize + av_gemm_K_padding; - bool av_gemm_K_tail_mul4 = kvTail % 4 == 0; - int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; - int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; - - auto u8_dt = at::ScalarType::Byte; - auto s8_dt = at::ScalarType::Int; - auto f32_dt = at::ScalarType::Float; - - // Data ptrs - scalar_t* q_data = query.data_ptr(); - scalar_t* k_data = key.data_ptr(); - scalar_t* v_data = value.data_ptr(); - mask_t* mask_data = attention_mask.defined() - ? attention_mask.data_ptr() - : nullptr; - scalar_t* out_data = output.data_ptr(); - - // Create tpp kernels for Query @ Key - bool headSize_mul4 = headSize % 4 == 0; - // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 - int qk_gemm_K_padding = headSize_mul4 ? 0 : 4 - headSize % 4; - int qk_gemm_K = headSize + qk_gemm_K_padding; - - int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; - int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; - - int64_t total_size_uint8_per_thread = - /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + - /* qk_local */ kvSlice * av_gemm_K * 4 + - /* qk_reduce */ kvSlice * qk_reduce_strideL + - /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + - /* dst_s32 */ qSplitSize * rndHeadSize * 4 + - /* softmax_sum */ qSplitSize * 4 + - /* query_sum */ qSplitSize * 4 + - /* attention_sum */ qSplitSize * 4 + - /* softmax max */ qSplitSize * 4 + - /* query_padding_data */ qSplitSize * qk_gemm_K; - - at::Tensor total_buf = at::empty( - {num_thread, total_size_uint8_per_thread}, - query.options()).zero_(); - scalar_t* total_buf_data = total_buf.data_ptr(); - - int64_t kv_sum_size_per_BH = - /* key_sum */ kvSize + - /* value_sum */ headSize; - - at::Tensor kv_sum_buf = at::empty( - {batchSize, num_head, kv_sum_size_per_BH}, - query.options().dtype(at::kInt)).zero_(); - int32_t* kv_sum_buf_data = kv_sum_buf.data_ptr(); - - int64_t kv_reorder_size_per_BH = - /* key_t_reorder */ qk_gemm_K * rndkvSize + - /* value_t_reorder */ kvSlice * v_reorder_strideL; - - at::Tensor kv_reorder_buf = at::empty( - {batchSize, num_head, kv_reorder_size_per_BH}, - query.options()).zero_(); - scalar_t* kv_reorder_buf_data = kv_reorder_buf.data_ptr(); - scalar_t* key_reorder_ptr = kv_reorder_buf_data; - scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; - - // sum k and v - at::parallel_for( - 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head); - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - int32_t* kv_sum_ptr = kv_sum_buf_data - + i * num_head * kv_sum_size_per_BH - + j * kv_sum_size_per_BH; - int32_t* k_sum_ptr = kv_sum_ptr; - int32_t* v_sum_ptr = kv_sum_ptr + kvSize; - if (q_zp == 0) { - fill_stub(k_sum_ptr, static_cast(0), kvSize); - } else { - _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, - k_sum_ptr, - kvSize, headSize, kStrideN, q_zp); - } - if (a_zp == 0) { - fill_stub(v_sum_ptr, static_cast(0), headSize); - } else { - _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, - v_sum_ptr, - headSize, kvSize, vStrideN, a_zp); - } - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head); - 
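The k_sum_ptr / v_sum_ptr buffers filled in the loop above exist because the kernels run plain u8u8s32 GEMMs and fold the zero points back in afterwards (the "dequant compensation" referred to in the comments). Below is a minimal Python sketch of that identity for the Q @ K^T step; the shapes, zero points and scales are made up, and the real kernel additionally folds the 1/sqrt(headSize) softmax scaling into alpha. The same trick, with v_sum_ptr and beta = a_zp * v_zp * kvSize, is applied after the probability-times-value GEMM.

import torch

# Illustrative only: the algebra behind q_sum_ptr / k_sum_ptr and beta = zp_a*zp_b*K.
M, N, K = 4, 6, 64                       # q rows, k rows, head size (made up)
q_zp, k_zp = 3, 7                        # zero points (made up)
q_scale, k_scale = 0.02, 0.05            # scales (made up)

q_u8 = torch.randint(0, 256, (M, K), dtype=torch.int32)
k_u8 = torch.randint(0, 256, (N, K), dtype=torch.int32)

# Reference: dequantize first, then matmul.
ref = (q_scale * (q_u8 - q_zp).double()) @ (k_scale * (k_u8 - k_zp).double()).T

# Kernel order: integer GEMM first, then compensate with precomputed row sums.
qk_s32 = q_u8 @ k_u8.T                              # u8u8s32 GEMM output
q_sum = k_zp * q_u8.sum(dim=1)                      # role of q_sum_ptr (scaled by k_zp)
k_sum = q_zp * k_u8.sum(dim=1)                      # role of k_sum_ptr (scaled by q_zp)
beta = q_zp * k_zp * K                              # zp_a * zp_b * K
out = q_scale * k_scale * (qk_s32 - q_sum[:, None] - k_sum[None, :] + beta).double()

assert torch.allclose(ref, out)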
} - }); - - // packing - at::parallel_for( - 0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0, l = 0, n = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head, l, kvSlice); - uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - n = l * kvSplitSize; - auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K; - auto v_reorder = value_reorder_ptr + - i * num_head * kvSlice * v_reorder_strideL + - j * kvSlice * v_reorder_strideL + n * rndHeadSize; - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - std::min(int(kvSplitSize - b), block_64), - headSize, - kStrideN, - block_64); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_64, - qk_gemm_K, - block_64, - block_64 - ); - } - at::native::cpublas::pack( - qk_gemm_K, // K - block_64, // N - block_64, // ld_in - block_64, // ld_out - u8_dt, // dt_in - u8_dt, // dt_out - B_blocked_xform_u8, - k_reorder + b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - v_reorder + av_gemm_K * b); - } - } else { - // tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < rndkvTail) { - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - std::min(kvTail - b, block_size), - headSize, - kStrideN, - block_size); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_size, - qk_gemm_K, - block_size, - block_size - ); - } - // Pack - at::native::cpublas::pack( - qk_gemm_K, - block_size, - block_size, - block_size, - u8_dt, - u8_dt, - B_blocked_xform_u8, - k_reorder + b * qk_gemm_K); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - // split headSize to block_64, block_64, block_64 ... - // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] 
- for (int64_t b = 0; b < headSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - v_reorder + av_gemm_K * b); - } - } - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); - } - }); - - at::parallel_for( - 0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0, k = 0; - at::native::data_index_init( - begin, i, batchSize, j, num_head, k, qSlice); - int ompIdx = at::get_thread_num(); - scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; - int32_t offset = 0; - accum_t* qk_data = reinterpret_cast(total_buf_ptr); - offset += kvSlice * qSplitSize * rndkvSplitSize * 4; - accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * av_gemm_K * 4; - scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * qk_reduce_strideL; - int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndkvSplitSize * 4; - int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * rndHeadSize * 4; - accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); - - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable - - int32_t* kv_sum_ptr = kv_sum_buf_data - + i * num_head * kv_sum_size_per_BH - + j * kv_sum_size_per_BH; - int32_t* k_sum_ptr = kv_sum_ptr; - int32_t* v_sum_ptr = kv_sum_ptr + kvSize; - - // sdpa core - int64_t m = k * qSplitSize; - int64_t qBlockSize = std::min(qSplitSize, qSize - m); - // Initialize sum and max - fill_stub( - sfm_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - a_sum_ptr, static_cast(0), qSplitSize); - fill_stub( - sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); - int64_t num_keys = - is_causal ? 
std::min(m + qBlockSize, kvSize) : kvSize; - copy_value_with_pad( - q_data + i * qStrideB + j * qStrideH + m * qStrideM, - query_t_padding_ptr, - qBlockSize, - headSize, - qBlockSize, - qk_gemm_K, - qStrideM); - - if (k_zp != 0) { - _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, - q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); - } else { - fill_stub( - q_sum_ptr, static_cast(0), qSplitSize); - } - const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; - for (int64_t l = 0; l < rkvSlice; l++) { - int64_t n = l * kvSplitSize; - int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K; - // Calculate sums for dequant compensation item - if (qBlockSize == qSplitSize) { - // q main - if (n + kvSplitSize < kvSize) { - // k main - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvSplitSize, //ldc, - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } - } else { - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - at::native::cpublas::brgemm( - qSplitSize, block_size, qk_gemm_K, - qk_gemm_K, // lda - block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - } - } else { - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - at::native::cpublas::brgemm( - qTail, block_64, qk_gemm_K, - qk_gemm_K,// lda - block_64, //ldb - rndkvSplitSize, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } - } else { - // k tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - at::native::cpublas::brgemm( - qTail, block_size, qk_gemm_K, - qk_gemm_K, // lda - block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - } - } - - // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 - int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; - accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; - if (has_attn_mask) { - mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); - _dequant_mask_max_fusion_kernel( - qk_s32_data, //in - mask_data_offset, //mask_ptr - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvBlockSize, //ldi - mStrideM, //ldm - rndkvSplitSize,//kvBlockSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } else { - _dequant_max_fusion_kernel( - qk_s32_data, //in - q_sum_ptr, //sum_a_ptr - k_sum_ptr + n, //sum_b_ptr - qBlockSize, //M - kvBlockSize, //N - rndkvBlockSize, //ldi - rndkvSplitSize,//kvBlockSize, //ldo - q_zp * k_zp * headSize, //zp_a*zp_b*k=beta - q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha - qk_block_data, //out - sfm_max_ptr // sfm_max_ptr - ); - } - } - // sub max, exp, sum reduce, div sum for softmax - // and quant - // and sum for attention - if (v_zp == 0) { - _sub_exp_sum_div_quant_fusion_kernel( - qk_data, //in - qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlices - qSplitSize * rndkvSplitSize, //ldi - qk_reduce_strideL, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr //sfm_sum_ptr - ); - } else { - _sub_exp_sum_div_quant_sum_fusion_kernel( - qk_data, //in - qBlockSize, //M - kvSplitSize, //N_step - rkvSlice, //NSlice - qSplitSize * rndkvSplitSize, //ldi - qk_reduce_strideL, //ldo - kvSize, //kvSize - rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K - a_zp, // zp_a=beta1 - v_zp, // zp_b=beta2 - a_scale, // scale_a=alpha - qk_local_data, //local - qk_reduced_data, //out - sfm_max_ptr, //sfm_max_ptr - sfm_sum_ptr, //sfm_sum_ptr - a_sum_ptr //a_sum_ptr - ); - } - // Calculate Softmax(q @ k.T) @ v - auto v_reorder = value_reorder_ptr + - i * num_head * kvSlice * v_reorder_strideL + - j * kvSlice * v_reorder_strideL; - for (int64_t b = 0; b < headSize; b += block_64) { - auto value_reorder_b = v_reorder + b * av_gemm_K; - auto dst_s32_b = dst_s32_data + b; - for (int64_t s = 0; s < kvSlice; s++) { - at::native::cpublas::brgemm( - qSplitSize, block_64, av_gemm_K, - av_gemm_K, // lda - rndHeadSize, //block_64, //ldb - rndHeadSize, //ldc - s != 0, - qk_reduced_data + s * qk_reduce_strideL, - value_reorder_b + s * v_reorder_strideL, - dst_s32_b); - } - } - - // After the last gemm, - // do dequant compensation, quant and convert from s32 to int8 - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, //sum_a_ptr - v_sum_ptr, //sum_b_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); - } - }); - // Once all computations are done, need to release HW context. - at::native::cpublas::brgemm_release(); -} - -#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) 
\ - AT_DISPATCH_SWITCH( \ - TYPE, \ - NAME, \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Bool, mask_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Float, mask_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Double, mask_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::BFloat16, mask_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Half, mask_t, __VA_ARGS__)) - -void sdpa_int8_fused_kernel( - const at::Tensor& output, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - double dropout_p, - bool is_causal, - at::Tensor& attn_mask, - double scale, - long q_zp, - double q_scale, - long k_zp, - double k_scale, - long v_zp, - double v_scale, - long a_zp, - double a_scale, - long o_zp, - double o_scale) { - TORCH_CHECK(query.scalar_type() == c10::kByte); - int64_t batchSize = query.size(0); - int64_t num_head = query.size(1); - int64_t q_seq_len = query.size(2); - int64_t kv_seq_len = key.size(2); - int64_t q_split_size = 32; - if (q_seq_len >= 768) { - q_split_size = 256; - } else if (q_seq_len >= 192) { - q_split_size = 64; - } - // Heuristic to decide whether to use one parallel loop or not - uint32_t l2_cache_size = at::cpu::L2_cache_size(); - int64_t num_thread = at::get_num_threads(); - int64_t attn_size = q_split_size * kv_seq_len * sizeof(int32_t) * num_thread; - bool use_one_parallel_loop = (batchSize * num_head > num_thread) && - (attn_size > l2_cache_size); - if (use_one_parallel_loop) { - if (!attn_mask.defined()) { - if (q_split_size == 256) { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else if (q_split_size == 64) { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } - } else { - AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { - if (q_split_size == 256) { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else if (q_split_size == 64) { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } - }); - } - } else { - if (!attn_mask.defined()) { - if (q_split_size == 256) { - sdpa_int8_kernel_several_loops_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else if (q_split_size == 64) { - sdpa_int8_kernel_several_loops_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else { - sdpa_int8_kernel_several_loops_impl( - output, query, key, 
value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } - } else { - AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { - if (q_split_size == 256) { - sdpa_int8_kernel_several_loops_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else if (q_split_size == 64) { - sdpa_int8_kernel_several_loops_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else { - sdpa_int8_kernel_several_loops_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } - }); - } - } -} -#endif // CPU_CAPABILITY_AVX512 - -at::Tensor sdpa_int8_math_kernel( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - double dropout_p, - bool is_causal, - at::Tensor& attn_mask, - double scale, - int32_t q_zp, - float q_scale, - int32_t k_zp, - float k_scale, - int32_t v_zp, - float v_scale, - int32_t a_zp, - float a_scale, - int32_t o_zp, - float o_scale) { - // dequant q/k/v - auto q = (query.to(at::kFloat) - q_zp) * q_scale; - auto k = (key.to(at::kFloat) - k_zp) * k_scale; - auto v = (value.to(at::kFloat) - v_zp) * v_scale; - const auto scaling_factor = calculate_scale(q, scale); - auto attn = at::matmul(q, k.transpose(-2, -1)) * scaling_factor; - if (attn_mask.defined() && attn_mask.numel()) { - attn = attn.add(attn_mask.to(at::kFloat)); - } - attn = at::softmax(attn, -1); - // quant attn - attn = at::clamp_max( - at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255 - ); - // dequant attn - attn = (attn - a_zp) * a_scale; - auto output = at::matmul(attn, v); - // quant output - output = at::clamp_max( - at::clamp_min(at::round(output / o_scale) + o_zp, 0), 255 - ).to(at::kByte); - return output; -} - - -at::Tensor _scaled_dot_product_int8_cpu( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - at::Tensor& attn_mask, - double dropout_p, - bool is_causal, - double scale, - int64_t q_zp, - double q_scale, - int64_t k_zp, - double k_scale, - int64_t v_zp, - double v_scale, - int64_t a_zp, - double a_scale, - int64_t o_zp, - double o_scale) { - const auto dtype = query.scalar_type(); - TORCH_CHECK(!query.is_nested() && !key.is_nested() && !value.is_nested(), - "_scaled_dot_product_int8_cpu: Only accept plain inputs"); - TORCH_CHECK(!is_causal, - "_scaled_dot_product_int8_cpu: is_causal not supported."); - TORCH_CHECK(dtype == at::ScalarType::Byte, - "_scaled_dot_product_int8_cpu: Expected data type be U8, but got ", dtype, " instead."); - TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, - "_scaled_dot_product_int8_cpu: Accept only 4 dims inputs shape of {B, H, T, K}"); - TORCH_CHECK(dropout_p == 0.0, - "_scaled_dot_product_int8_cpu: Currently do not support dropout > 0"); - TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), - "_scaled_dot_product_int8_cpu: Q/K/V should have the same head size"); - TORCH_CHECK(!attn_mask.defined() || - attn_mask.scalar_type() == at::kFloat || - attn_mask.scalar_type() == at::kBFloat16, - "_scaled_dot_product_int8_cpu: Expected attention mask be float or bf16"); - TORCH_CHECK(!attn_mask.defined() || - (attn_mask.dim() == 2 || attn_mask.dim() == 4), - 
"_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); - - if (!at::native::cpublas::could_pack(dtype)) { - return sdpa_int8_math_kernel(query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } - - #ifdef CPU_CAPABILITY_AVX512 - at::Tensor output = at::empty_like(query, query.options()); - sdpa_int8_fused_kernel(output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - return output; - #else - return sdpa_int8_math_kernel(query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - #endif // CPU_CAPABILITY_AVX512 -} - - -} // anonymous namespace - -TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::scaled_dot_product_int8", &_scaled_dot_product_int8_cpu); -} - -// } // at::native -} // namespace torchao diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 9ef71a44fc..46740a81fd 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -94,9 +94,6 @@ smooth_fq_linear_to_inference, swap_linear_with_smooth_fq_linear, ) -from .sfdp_int8_fx_pass import ( - _sfdp_int8_init, -) from .subclass import * # noqa: F403 from .transform_module import register_quantize_module_handler from .unified import Quantizer, TwoStepQuantizer diff --git a/torchao/quantization/sfdp_int8_fx_pass.py b/torchao/quantization/sfdp_int8_fx_pass.py deleted file mode 100644 index 9359ae57b5..0000000000 --- a/torchao/quantization/sfdp_int8_fx_pass.py +++ /dev/null @@ -1,541 +0,0 @@ -import functools -from typing import Callable - -import torch -from torch._inductor import config -from torch._inductor.pattern_matcher import ( - Arg, - CallFunction, - filter_nodes, - KeywordArg, - ListOf, - Match, - PatternMatcherPass, -) -from torch._inductor.fx_passes.post_grad import register_lowering_pattern -from torch._inductor.lowering import lowerings as L, make_fallback -from torch._dynamo.utils import counters - -__all__ = [ - "_sfdp_int8_post_grad_init", -] - -make_fallback(torch.ops.torchao.scaled_dot_product_int8.default) - -aten = torch.ops.aten -patterns = PatternMatcherPass() - -def _is_valid_int8_sdpa_pattern(): - def fn(match): - assert all(k in match.kwargs for k in ("query", "key", "value")) - query = match.kwargs["query"].meta["val"] - key = match.kwargs["key"].meta["val"] - value = match.kwargs["value"].meta["val"] - return ( - query.dtype == torch.uint8 - and key.dtype == torch.uint8 - and value.dtype == torch.uint8 - and query.device.type == "cpu" - and key.device == query.device - and value.device == query.device - ) - - return fn - - -def _register_int8_sdpa_pattern(pattern): - @register_lowering_pattern( - pattern, - extra_check=_is_valid_int8_sdpa_pattern(), - ) - def int8_sdpa(match: Match, *args, **kwargs): - print("\n***hit int8_sdpa_pattern***\n") - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - inv_scale = kwargs["inv_scale"] - attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None - q_zp = kwargs["q_zp"] - q_scale = kwargs["q_scale"] - k_zp = kwargs["k_zp"] - k_scale = kwargs["k_scale"] - v_zp = kwargs["v_zp"] - v_scale = kwargs["v_scale"] - a_zp = kwargs["a_zp"] - a_scale = kwargs["a_scale"] - o_zp = kwargs["o_zp"] - o_scale = kwargs["o_scale"] - counters["inductor"]["int8_fuse_attention"] += 1 - 
counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes) - - return L[torch.ops.torchao.scaled_dot_product_int8.default]( - query, - key, - value, - attn_mask, - 0.0, #dropout - False, #is_causal - 1.0 / inv_scale, #scale - q_zp, - q_scale, - k_zp, - k_scale, - v_zp, - v_scale, - a_zp, - a_scale, - o_zp, - o_scale, - ) - - return int8_sdpa - - -def _get_int8_sdpa_q_pattern(is_batch_size_1: bool, has_convert: bool): - int8_sdpa_q_basic_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_3 - CallFunction( - aten.permute.default, # permute_3 - KeywordArg("query"), - Arg(), - ), - KeywordArg("q_scale"), - KeywordArg("q_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_q_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_q_basic_pattern, - Arg(), - ) - int8_sdpa_q_basic_pattern = CallFunction( - aten.expand.default, - int8_sdpa_q_basic_pattern, - Arg(), - ) - if is_batch_size_1: - # pattern is different for bs=1 - return CallFunction( - aten.reshape.default, # view_9 - int8_sdpa_q_basic_pattern, - Arg(), - ) - else: - return CallFunction( - aten.reshape.default, # view_9 - CallFunction( - aten.clone.default, # clone - int8_sdpa_q_basic_pattern, - memory_format=Arg(), - ), - Arg(), - ) - - -def _get_int8_sdpa_k_pattern(is_batch_size_1: bool, has_convert: bool): - int8_sdpa_k_basic_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_5 - CallFunction( - aten.permute.default, # permute_6 - CallFunction( - aten.permute.default, # permute_4 - KeywordArg("key"), - Arg(), - ), - Arg(), - ), - KeywordArg("k_scale"), - KeywordArg("k_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_k_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_k_basic_pattern, - Arg(), - ) - int8_sdpa_k_basic_pattern = CallFunction( - aten.expand.default, - int8_sdpa_k_basic_pattern, - Arg(), - ) - if is_batch_size_1: - # pattern is different for bs=1 - return CallFunction( - aten.reshape.default, # view_10 - int8_sdpa_k_basic_pattern, - Arg(), - ) - else: - return CallFunction( - aten.reshape.default, # view_10 - CallFunction( - aten.clone.default, # clone_1 - int8_sdpa_k_basic_pattern, - memory_format=Arg(), - ), - Arg(), - ) - - -def _get_int8_sdpa_v_pattern(is_batch_size_1: bool, has_convert: bool): - int8_sdpa_v_basic_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_4 - CallFunction( - aten.permute.default, # permute_5 - KeywordArg("value"), - Arg(), - ), - KeywordArg("v_scale"), - KeywordArg("v_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_v_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_v_basic_pattern, - Arg(), - ) - int8_sdpa_v_basic_pattern = CallFunction( - aten.expand.default, # expand_3 - int8_sdpa_v_basic_pattern, - Arg(), - ) - if is_batch_size_1: - # pattern is different for bs=1 - return CallFunction( - aten.reshape.default, # view_13 - int8_sdpa_v_basic_pattern, - Arg(), - ) - else: - return CallFunction( - aten.reshape.default, # view_13 - CallFunction( - aten.clone.default, # clone_3 - int8_sdpa_v_basic_pattern, - memory_format=Arg(), - ), - Arg(), - ) - - -def _get_int8_sdpa_score_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_q_pattern = _get_int8_sdpa_q_pattern(is_batch_size_1, has_convert) - 
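The int8_sdpa replacement above collapses the matched dequantize -> bmm -> softmax -> quantize subgraph into a single torchao.scaled_dot_product_int8 call, passing 1.0 / inv_scale as the scale. For reference, the numerics that call is expected to reproduce can be written in a few lines of eager PyTorch; this sketch mirrors the sdpa_int8_math_kernel fallback from the C++ file above, with a made-up function name, uint8 inputs of shape [B, H, T, D], and no dropout or causal masking since the CPU op rejects both.

import math
import torch

def int8_sdpa_reference(q, k, v, attn_mask, scale,
                        q_zp, q_scale, k_zp, k_scale,
                        v_zp, v_scale, a_zp, a_scale, o_zp, o_scale):
    # q/k/v: uint8 tensors [B, H, T, D]; attn_mask: float tensor or None.
    qf = (q.float() - q_zp) * q_scale
    kf = (k.float() - k_zp) * k_scale
    vf = (v.float() - v_zp) * v_scale
    scaling = scale if scale != 0.0 else 1.0 / math.sqrt(q.size(-1))
    attn = qf @ kf.transpose(-2, -1) * scaling
    if attn_mask is not None:
        attn = attn + attn_mask.float()
    attn = torch.softmax(attn, dim=-1)
    # fake-quantize the attention probabilities to u8 (a_zp, a_scale)
    attn = torch.clamp(torch.round(attn / a_scale) + a_zp, 0, 255)
    attn = (attn - a_zp) * a_scale
    out = attn @ vf
    # quantize the output to u8 (o_zp, o_scale)
    return torch.clamp(torch.round(out / o_scale) + o_zp, 0, 255).to(torch.uint8)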
int8_sdpa_k_pattern = _get_int8_sdpa_k_pattern(is_batch_size_1, has_convert) - int8_sdpa_score_basic_pattern = CallFunction( - aten.reshape.default, # view_11 - CallFunction( - aten.bmm.default, # bmm - int8_sdpa_q_pattern, # view_9 - int8_sdpa_k_pattern, # view_10 - ), - Arg(), - ) - if is_reduced_type and not has_mask: - int8_sdpa_score_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_score_basic_pattern, - Arg(), - ) - if has_mask: - return CallFunction( - aten.add.Tensor, # add - CallFunction( - aten.div.Tensor, # div - int8_sdpa_score_basic_pattern, - KeywordArg("inv_scale"), - ), - KeywordArg("attn_mask"), - _users=2, - ) - else: - return CallFunction( - aten.mul.Tensor, # mul_tensor - int8_sdpa_score_basic_pattern, - Arg(), - _users=2, - ) - - -def _get_int8_sdpa_exp_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_score_pattern = _get_int8_sdpa_score_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - int8_sdpa_exp_basic_pattern = CallFunction( - aten.sub.Tensor, # sub - int8_sdpa_score_pattern, # add - CallFunction( - aten.amax.default, # amax - int8_sdpa_score_pattern, # add - Arg(), - Arg(), - ), - ) - if has_mask: - return CallFunction( - aten.exp.default, # exp - int8_sdpa_exp_basic_pattern, - _users=2, - ) - else: - return CallFunction( - aten.exp.default, # exp - CallFunction( - aten.div.Tensor, # div_tensor - int8_sdpa_exp_basic_pattern, - KeywordArg("inv_scale"), - ), - _users=2, - ) - - -def _get_int8_sdpa_attn_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_exp_pattern = _get_int8_sdpa_exp_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - int8_sdpa_div_pattern = CallFunction( - aten.div.Tensor, # div_1 - int8_sdpa_exp_pattern, # exp - CallFunction( - aten.sum.dim_IntList, # sum_1 - int8_sdpa_exp_pattern, # exp - Arg(), - Arg(), - ), - ) - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_6 - CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, # quantize_per_tensor_4 - int8_sdpa_div_pattern, - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ) - if is_reduced_type: - if has_mask: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_softmax_pattern, - Arg(), - ) - else: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, # dequantize_per_tensor_6 - CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, # quantize_per_tensor_4 - CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_div_pattern, - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_softmax_pattern, - Arg(), - ) - return CallFunction( - aten.reshape.default, # view_12 - CallFunction( - aten.expand.default, # expand_2 - int8_sdpa_softmax_pattern, - Arg(), - ), - Arg(), - ) - - -def _get_int8_sdpa_final_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_v_pattern = 
_get_int8_sdpa_v_pattern(is_batch_size_1, has_convert) - int8_sdpa_attn_pattern = _get_int8_sdpa_attn_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - return CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, # quantize_per_tensor_5 - CallFunction( - aten.clone.default, # clone_4 - CallFunction( - aten.permute.default, # permute_7 - CallFunction( - aten.reshape.default, # view_14 - CallFunction( - aten.bmm.default, # bmm_1 - int8_sdpa_attn_pattern, # view_12 - int8_sdpa_v_pattern, # view_13 - ), - Arg(), - ), - Arg(), - ), - memory_format=Arg(), - ), - KeywordArg("o_scale"), - KeywordArg("o_zp"), - Arg(), - Arg(), - Arg(), - ) - - -def _register_int8_sdpa_fp32_lowering(): - # dtype = float32, without attention mask, batch size > 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=False, is_reduced_type=False, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=False, is_reduced_type=False, has_convert=False - ) - ) - - -def _register_int8_sdpa_fp32_mask_lowering(): - # dtype = float32, with attention mask, batch size > 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=False, is_reduced_type=False, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=False, is_reduced_type=False, has_convert=False - ) - ) - - -def _register_int8_sdpa_fp32_bs1_lowering(): - # dtype = float32, without attention mask, batch size == 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=True, is_reduced_type=False, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=True, is_reduced_type=False, has_convert=False - ) - ) - - -def _register_int8_sdpa_fp32_mask_bs1_lowering(): - # dtype = float32, with attention mask, batch size == 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=True, is_reduced_type=False, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=True, is_reduced_type=False, has_convert=False - ) - ) - - -def _register_int8_sdpa_bf16_lowering(): - # dtype = bfloat16, without attention mask, batch size > 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=False, is_reduced_type=True, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=False, is_reduced_type=True, has_convert=False - ) - ) - - -def _register_int8_sdpa_bf16_mask_lowering(): - # dtype = bfloat16, with attention mask, batch size > 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=False, is_reduced_type=True, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=False, is_reduced_type=True, has_convert=False - ) - ) - - -def _register_int8_sdpa_bf16_bs1_lowering(): - # dtype = bfloat16, without attention mask, batch size == 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=True, is_reduced_type=True, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=True, is_reduced_type=True, 
has_convert=False - ) - ) - - -def _register_int8_sdpa_bf16_mask_bs1_lowering(): - # dtype = bfloat16, with attention mask, batch size == 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=True, is_reduced_type=True, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=True, is_reduced_type=True, has_convert=False - ) - ) - -def _register_quantized_sdpa_lowerings(): - _register_int8_sdpa_fp32_lowering() - _register_int8_sdpa_fp32_mask_lowering() - _register_int8_sdpa_fp32_bs1_lowering() - _register_int8_sdpa_fp32_mask_bs1_lowering() - _register_int8_sdpa_bf16_lowering() - _register_int8_sdpa_bf16_mask_lowering() - _register_int8_sdpa_bf16_bs1_lowering() - _register_int8_sdpa_bf16_mask_bs1_lowering() - -@functools.lru_cache(None) -def _sfdp_int8_post_grad_init(): - _register_quantized_sdpa_lowerings() - config.post_grad_custom_pre_pass = patterns.apply From bd9ae063a9220eae98c7014147f8381a02e60f03 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Wed, 26 Feb 2025 03:30:49 -0500 Subject: [PATCH 12/36] update --- .../inductor/test_int8_sdpa_fusion.py | 224 ++ torchao/csrc/cpu/int8_sdpa.cpp | 2382 +++++++++++++++++ .../prototype/inductor/fx_passes/README.md | 34 + .../inductor/fx_passes/int8_sdpa_fusion.py | 539 ++++ 4 files changed, 3179 insertions(+) create mode 100644 test/prototype/inductor/test_int8_sdpa_fusion.py create mode 100644 torchao/csrc/cpu/int8_sdpa.cpp create mode 100644 torchao/prototype/inductor/fx_passes/README.md create mode 100644 torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py new file mode 100644 index 0000000000..315e5b90f9 --- /dev/null +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -0,0 +1,224 @@ +import torchao + +import contextlib +import functools +import itertools +import math + +import torch +import torch.utils.checkpoint +from torch._dynamo.debug_utils import aot_graph_input_parser +from torch._dynamo.utils import counters +from torch._inductor import config +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import run_and_get_code +from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA + +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +from torch.export import export_for_training +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( + X86InductorQuantizer, +) +from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init + +class SelfAttnLikeModule(torch.nn.Module): + def __init__( + self, + input_dim, + has_mask, + num_attention_heads=None, + attention_head_size=None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.softmax = torch.nn.Softmax(dim=-1) + assert num_attention_heads is not None + assert attention_head_size is not None + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dense = torch.nn.Linear(self.all_head_size, 
self.all_head_size) + self.dropout = torch.nn.Dropout(0) + self.has_mask = has_mask + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute([0, 2, 1, 3]) + + def forward(self, x, mask): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + q = self.transpose_for_scores(q) + k = self.transpose_for_scores(k) + v = self.transpose_for_scores(v) + scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) + if self.has_mask and mask.dtype != scores.dtype: + scores = scores + mask + attention = self.softmax(scores) + attention = self.dropout(attention) + context_layer = torch.matmul(attention, v) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + context_layer = context_layer.view( + context_layer.size()[:-2] + (self.all_head_size,) + ) + return self.dense(context_layer) + +class TestSDPAPatternRewriterTemplate(TestCase): + def _clone_inputs(self, inputs): + def clone(x): + if not isinstance(x, torch.Tensor): + return x + return x.clone() + + return [clone(x) for x in inputs] + + def _check_common( + self, + dot_prod_attention, + args1=None, + contains=True, + atol=1e-5, + has_fuse_pattern=True, + has_dropout=False, + check_train=True, + override_check_equal=False, + dtype=torch.float, + rtol=1.3e-6, + ): + if args1 is None: + tensor_shape = (4, 2, 16, 32) + args1 = [ + torch.randn(tensor_shape, device=self.device, dtype=dtype), + torch.randn(tensor_shape, device=self.device, dtype=dtype), + torch.randn(tensor_shape, device=self.device, dtype=dtype), + ] + else: + args1 = list(args1) + args2 = self._clone_inputs(args1) + + for training in [False, True] if check_train else [False]: + for x in itertools.chain(args1[:], args2[:]): + if isinstance(x, torch.Tensor) and x.is_floating_point(): + x.requires_grad = training + + dropout_arg = [training] if has_dropout else [] + torch.manual_seed(1234) + result1 = dot_prod_attention(*(args1 + dropout_arg)) + + counters.clear() + torch.manual_seed(1234) + compiled_model = torch.compile(dot_prod_attention, fullgraph=True) + result2, source_code = run_and_get_code( + compiled_model, + *(args2 + dropout_arg), + ) + source_code = "\n".join(source_code) + if has_fuse_pattern: + self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1) + if contains: + # many of the patterns get re-expanded in dispatcher + self.assertIn( + "torchao.scaled_dot_product_int8", + source_code, + ) + + # some tests configured with very low dropout where we still want to check equality + if not has_dropout or override_check_equal: + self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6) + + if training: + result1.sum().backward() + result2.sum().backward() + for arg1, arg2 in zip(args1, args2): + if ( + isinstance(arg1, torch.Tensor) + and arg1.is_floating_point() + and (not has_dropout or override_check_equal) + ): + self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) + + iter_n = 20 + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule(wait=2, warmup=iter_n, active=20), + ) as prof: + for _ in range(iter_n + 22): + r = compiled_model(*(args2 + dropout_arg)) + prof.step() + print(prof.key_averages().table(sort_by="self_cpu_time_total")) + + @skipIfRocm + @config.patch({"freezing": True}) + def _test_sdpa_rewriter_int8_1_to_4(self): + # pattern is different for bs=1 + for dtype, has_mask, bs in 
itertools.product( + [torch.float32, torch.bfloat16], [True, False], [56, 1] + ): + seqlen, numhead, headsize = 197, 16, 64 + # dtype = torch.bfloat16 + # has_mask = True + # is_bs_1 = 0 + # if is_bs_1: + # candidates = [[1, 384, 16, 64], [1, 197, 12, 64]] + # else: + # candidates = [[120, 384, 16, 64], [224, 197, 12, 64]] + # candidates = [[120, 384, 16, 64]] + # for bs, seqlen, numhead, headsize in candidates: + mod = SelfAttnLikeModule( + input_dim=headsize * numhead, + has_mask=has_mask, + num_attention_heads=numhead, + attention_head_size=headsize, + ).eval() + maybe_autocast = ( + torch.cpu.amp.autocast() + if dtype == torch.bfloat16 + else contextlib.nullcontext() + ) + print("\nTEST shape", bs, numhead, seqlen, headsize) + inputs = ( + torch.randn( + (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype + ) + * 10, + torch.randn((bs, 1, 1, seqlen), device=self.device) * 10 + if has_mask + else None, + ) + with torch.no_grad(), maybe_autocast: + _int8_sdpa_init() + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) + export_model = export_for_training( + mod, + inputs, + ).module() + prepare_model = prepare_pt2e(export_model, quantizer) + prepare_model(*inputs) + convert_model = convert_pt2e(prepare_model) + torch.ao.quantization.move_exported_model_to_eval(convert_model) + self._check_common( + convert_model, args1=inputs, check_train=False, atol=1.0 + ) + +if HAS_CPU: + class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): + device = "cpu" + test_sdpa_rewriter_int8_1_to_4_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_int8_1_to_4 + +if __name__ == "__main__": + if IS_LINUX: + run_tests() diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp new file mode 100644 index 0000000000..4f5ca3fcaf --- /dev/null +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -0,0 +1,2382 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include + +namespace torchao { + +namespace { + +template +struct is_reduced_floating_point: + std::integral_constant || + std::is_same_v> { +}; + +template +constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; + +inline double calculate_scale( + const at::Tensor& query, + double scale) { + return scale == 0.0 ? 1.0 / std::sqrt(query.size(-1)) : scale; +} + +#ifdef CPU_CAPABILITY_AVX512 +// out = val * a + b +// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), +// take b as a scalar pointer. +template +inline void _scale_attn_mask_fusion_kernel( + T1* a, + T2* b, + const int& size, + T1* out, + T1& val) { + const auto vec_size1 = at::vec::Vectorized::size(); + const auto vec_size2 = at::vec::Vectorized::size(); + constexpr int64_t T1_n = + (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v) ? 
2 : 1; + constexpr int64_t T2_n = 1; + auto vec_scale = at::vec::VectorizedN(val); + int64_t i = 0; + for (; i < size - (size % vec_size2); i += vec_size2) { + auto a_n = at::vec::VectorizedN::loadu(a + i); + at::vec::VectorizedN b_n; + if constexpr(is_b_stride_zero) { + b_n = at::vec::VectorizedN((T1)b[0]); + } else { + b_n = at::vec::VectorizedN::loadu(b + i); + } + auto b_n_convert = at::vec::convert(b_n); + auto res = a_n * vec_scale + b_n_convert; + res.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = a[i]; + T1 tmp1; + if constexpr(is_b_stride_zero) { + tmp1 = (T1)b[0]; + } else { + tmp1 = (T1)b[i]; + } + out[i] = tmp0 * val + tmp1; + } +} + +// 1) out = exp(a - val) +// 2) val = sum(out) +template +inline void _exp_reduce_sum_fusion_kernel( + T1* a, + const int& size, + T2* out, + T1& val) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + T1 tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(out + i, tmp2); + } + tmp_sum = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return x + y; + }, + vec_tmp_sum); + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 - val; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + out[i] = tmp2; + } + val = tmp_sum; +} + +// 1) out = a * scale +// 2) max = max(out) +template +inline void _mul_reduce_max_fusion_kernel( + const scalar_t* a, + const scalar_t& scale, + const int& size, + scalar_t* out, + scalar_t& max) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + scalar_t tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + _store(out + i, tmp1); + } + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + max = std::max(tmp_max, vec_tmp_max.reduce_max()); +} + +template +static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { + TORCH_CHECK(ptr2 == nullptr); + return ptr; +} + +template , int> = 0> +static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { + return ptr2; +} + +template +inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + Vec data_vec = Vec(val); + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + data_vec.store(data + d); + } + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (; d < size; d++) { + data[d] = val; + } +} + +void reshape_attn_mask_to_4d( + at::Tensor& attn_mask, + int64_t batchSize, + int64_t num_head, + int64_t qSize, + int64_t kvSize) { + // Support mask shapes: + // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) + // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) + // Guaranteed in check_attn_mask_shape + int64_t attn_mask_size_0 = 1; + int64_t attn_mask_size_1 = 1; + if (attn_mask.dim() == 4) { + if (attn_mask.size(0) == batchSize) { + attn_mask_size_0 = batchSize; + } + if 
(attn_mask.size(1) == num_head) { + attn_mask_size_1 = num_head; + } + } + attn_mask = attn_mask + .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)}) + .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); +} + +// TODO: Use at::native::_store instead when it supports Half. +template +inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { + src.store(dst, size); +} + +template +inline typename std::enable_if_t, void> +_store(scalar_t* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_from_float(src, src); + res.store(dst, at::vec::Vectorized::size()); +} + +template +inline typename std::enable_if_t || std::is_same_v, void> +_store(scalar_t* dst, at::vec::Vectorized src) { + auto res = at::vec::convert(src); + res.store(dst, at::vec::Vectorized::size()); +} + +template +inline void pad_row_zero( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi) { + auto vec_size = at::vec::Vectorized::size(); + int i = 0; + for (; i < rows - 1; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } + } + + // zero padding + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = at::vec::Vectorized(0); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized(0); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } +} + +template +inline void pad_row_128_padding( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi, + int padding) { + auto vec_size = at::vec::Vectorized::size(); + int i = 0; + for (; i < rows - padding; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } + } + + // 128 padding + for (; i < rows; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = at::vec::Vectorized(128); + vec_v.store(padding_value_ptr + i * cols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized(128); + vec_v.store(padding_value_ptr + i * cols + j, cols - j); + } + } +} + +template +inline void pad_col_zero( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi) { + auto vec_size = at::vec::Vectorized::size(); + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < cols - 1 - ((cols - 1) % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + if (j < cols - 1) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - 1 - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - 1 - j); + *(padding_value_ptr + i * cols + cols - 1) = scalar_t(0); + } + } +} + +template +inline void pad_col_zero_padding( + scalar_t* value_ptr, + scalar_t* padding_value_ptr, + int rows, + int cols, + int ldi, + int padding) { + auto vec_size = at::vec::Vectorized::size(); + 
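+  // Copies the first (cols - padding) columns of each row into the packed buffer and
+  // writes a zero at the start of the padded region (in the remainder branch below),
+  // presumably so the padded tail does not feed stale data into the u8u8s32 GEMM.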
for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < cols - padding - ((cols - padding) % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(padding_value_ptr + i * cols + j); + } + if (j < cols - padding) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - padding - j); + vec_v.store(padding_value_ptr + i * cols + j, cols - padding - j); + *(padding_value_ptr + i * cols + cols - padding) = scalar_t(0); + } + } +} + +/* +1. dequant +2. add mask +3. max reduce for softmax +*/ +template +inline void _dequant_mask_max_fusion_kernel( + const int32_t* in, + const mask_t* mask_ptr, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldm, // leading dimension mask + const int& ldo, + const int32_t& beta, // zp_a*zp_b*k + const float& alpha, // scale_a*scale_b*scale_sdpa + float* out, + float* sfm_max_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_beta = at::vec::Vectorized(beta); + auto vec_alpha = at::vec::Vectorized(alpha); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + float* tmp_out = out + row * ldo; + const mask_t* mask_data_ptr = mask_ptr + row * ldm; + float tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col); + auto tmp7 = at::vec::convert(tmp6); + auto tmp8 = tmp5 + tmp7; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp8); + _store(tmp_out + col, tmp8); + } + tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); + for (long col = vec_size * (N / vec_size); col < N; col++) { + auto sum_b = sum_b_ptr[col]; + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sum_b; + auto tmp2 = tmp1 - sum_a; + auto tmp3 = tmp2 + beta; + auto tmp4 = (float) tmp3; + auto tmp5 = tmp4 * alpha; + auto tmp6 = mask_data_ptr[col]; + auto tmp7 = (float) tmp6; + auto tmp8 = tmp5 + tmp7; + tmp_max = std::max(tmp_max, tmp8); + tmp_out[col] = tmp8; + } + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); + } +} + +/* +1. dequant +2. 
max reduce for softmax +*/ +inline void _dequant_max_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta, // zp_a*zp_b*k + const float& alpha, // scale_a*scale_b*scale_sdpa + float* out, + float* sfm_max_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_beta = at::vec::Vectorized(beta); + auto vec_alpha = at::vec::Vectorized(alpha); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + float* tmp_out = out + row * ldo; + float tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp5); + _store(tmp_out + col, tmp5); + } + tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); + for (long col = vec_size * (N / vec_size); col < N; col++) { + auto sum_b = sum_b_ptr[col]; + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sum_b; + auto tmp2 = tmp1 - sum_a; + auto tmp3 = tmp2 + beta; + auto tmp4 = (float) tmp3; + auto tmp5 = tmp4 * alpha; + tmp_max = std::max(tmp_max, tmp5); + tmp_out[col] = tmp5; + } + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); + } +} + +/* +1. Softmax: sub max, exp, sum reduce, div sum +2. quant +3. sum for attention +*/ +template +inline void _sub_exp_sum_div_quant_sum_fusion_kernel( + const float* in, + const int64_t& M, + const int64_t& N_step, + const int64_t& NSlice, + const int& ldi, + const int& ldo, + const int& kvSize, + const int& rndkvSplitSize, + const int& av_gemm_K, + const int32_t& beta1, // zp_a + const int32_t& beta2, // zp_b + const float& alpha, // scale_a + float* local, + scalar_t* out, + float* sfm_max_ptr, + float* sfm_sum_ptr, + int32_t* sum_a_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + scalar_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + float beta1_float = (float) beta1; + auto vec_beta1 = at::vec::Vectorized(beta1_float); + for (int64_t row = 0; row < M; ++row) { + auto sfm_max = sfm_max_ptr[row]; + auto vec_max = at::vec::Vectorized(sfm_max); + // sub max, exp, sum reduce + const float* qk_block_data = in + row * rndkvSplitSize; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + const float* tmp_in = qk_block_data + l * ldi; + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_out = local + n; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(tmp_out + col, tmp2); + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sfm_max; + auto tmp2 = exp(tmp1); + tmp_sum += 
tmp2; + tmp_out[col] = tmp2; + } + sfm_sum_ptr[row] += tmp_sum; + } + // div sum, sum for attention + auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; + auto vec_sum_scale = at::vec::Vectorized(sum_scale); + scalar_t* qk_reduced_block_data = out + row * av_gemm_K; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + int32_t tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_in = local + n; + scalar_t* tmp_out = qk_reduced_block_data + l * ldo; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::maximum(tmp3, vec_min_val); + auto tmp5 = at::vec::minimum(tmp4, vec_max_val); + _store(tmp_out + col, tmp5); + auto tmp6 = at::vec::convert(tmp5); + vec_tmp_sum += tmp6; + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 * sum_scale; + auto tmp2 = std::nearbyint(tmp1); + auto tmp3 = tmp2 + beta1_float; + auto tmp4 = std::max(tmp3, min_val); + auto tmp5 = std::min(tmp4, max_val); + tmp_out[col] = tmp5; + auto tmp6 = (int32_t) tmp5; + tmp_sum += tmp6; + } + sum_a_ptr[row] += tmp_sum * beta2; + // set zero + for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { + _store(tmp_out + col, vec_zero); + } + for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { + tmp_out[col] = zero; + } + } + } +} + +template +inline void _sub_exp_sum_div_quant_fusion_kernel( + const float* in, + const int64_t& M, + const int64_t& N_step, + const int64_t& NSlice, + const int& ldi, + const int& ldo, + const int& kvSize, + const int& rndkvSplitSize, + const int& av_gemm_K, + const int32_t& beta1, // zp_a + const float& alpha, // scale_a + float* local, + scalar_t* out, + float* sfm_max_ptr, + float* sfm_sum_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + scalar_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + float beta1_float = (float) beta1; + auto vec_beta1 = at::vec::Vectorized(beta1_float); + for (int64_t row = 0; row < M; ++row) { + auto sfm_max = sfm_max_ptr[row]; + auto vec_max = at::vec::Vectorized(sfm_max); + // sub max, exp, sum reduce + const float* qk_block_data = in + row * rndkvSplitSize; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + const float* tmp_in = qk_block_data + l * ldi; + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_out = local + n; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(tmp_out + col, tmp2); + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sfm_max; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + tmp_out[col] = tmp2; + } + sfm_sum_ptr[row] += tmp_sum; + } + // div sum, sum for attention + auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; + 
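+      // sum_scale folds the softmax normalization (divide by the row sum) and the
+      // requantization to u8 (divide by scale_a) into one multiplier; the loops below
+      // then only need to round, add the zero point, and clamp to [0, 255].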
auto vec_sum_scale = at::vec::Vectorized(sum_scale); + scalar_t* qk_reduced_block_data = out + row * av_gemm_K; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + float* tmp_in = local + n; + scalar_t* tmp_out = qk_reduced_block_data + l * ldo; + for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::maximum(tmp3, vec_min_val); + auto tmp5 = at::vec::minimum(tmp4, vec_max_val); + _store(tmp_out + col, tmp5); + } + for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 * sum_scale; + auto tmp2 = std::nearbyint(tmp1); + auto tmp3 = tmp2 + beta1_float; + auto tmp4 = std::max(tmp3, min_val); + auto tmp5 = std::min(tmp4, max_val); + tmp_out[col] = tmp5; + } + // set zero + for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { + _store(tmp_out + col, vec_zero); + } + for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { + tmp_out[col] = zero; + } + } + } +} + +/* +1. dequant +2. quant +*/ +template +inline void _dequant_quant_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta1, // zp_a*zp_b*k + const int32_t& beta2, // zp_c + const float& alpha, // scale_a*scale_b/scale_c + scalar_t* out) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + auto vec_beta1 = at::vec::Vectorized(beta1); + auto vec_alpha = at::vec::Vectorized(alpha); + float beta2_float = (float) beta2; + auto vec_beta2 = at::vec::Vectorized(beta2_float); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + scalar_t* tmp_out = out + row * ldo; + for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::maximum(tmp7, vec_min_val); + auto tmp9 = at::vec::minimum(tmp8, vec_max_val); + _store(tmp_out + col, tmp9); + } + for (long col = vec_size * (N / vec_size); col < N; col++) { + auto sum_b = sum_b_ptr[col]; + auto tmp0 = tmp_in[col]; + auto tmp1 = tmp0 - sum_b; + auto tmp2 = tmp1 - sum_a; + auto tmp3 = tmp2 + beta1; + auto tmp4 = (float) tmp3; + auto tmp5 = tmp4 * alpha; + auto tmp6 = std::nearbyint(tmp5); + auto tmp7 = tmp6 + beta2_float; + auto tmp8 = std::max(tmp7, min_val); + auto tmp9 = std::min(tmp8, max_val); + tmp_out[col] = tmp9; + } + } +} + +template +inline void _int_sum_b_contiguous_kernel_helper( + const scalar_t* in, + int32_t* out, + const int& N, + const int32_t& scale) { + const int32_t vec_size = at::vec::Vectorized::size(); + int32_t tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + for (long i = 0; i < vec_size * (N / vec_size); i += 
vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(in + i); + auto tmp1 = at::vec::convert(tmp0); + vec_tmp_sum = vec_tmp_sum + tmp1; + } + tmp_sum += vec_tmp_sum.reduce_add(); + for (long i = vec_size * (N / vec_size); i < N; i++) { + tmp_sum += static_cast(in[i]); + } + out[0] = tmp_sum * scale; +} + +template +inline void _int_sum_b_contiguous_kernel( + const scalar_t* in, + int32_t* out, + const int& M, + const int& N, + const int& ld, + const int32_t& scale) { + for (long r = 0; r < M; r += 1) { + _int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); + } +} + +template +inline void _int_sum_a_contiguous_kernel( + const scalar_t* in, + int32_t* out, + const int& M, + const int& N, + const int& ld, + const int32_t& scale) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + // initialization with 0 + int32_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + _store(out + i, vec_zero); + } + for (long i = vec_size * (M / vec_size); i < M; i++) { + out[i] = zero; + } + // sum + for (long j = 0; j < N; j++) { + const scalar_t* tmp_in = in + j * ld; + for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + i); + auto tmp1 = at::vec::Vectorized::loadu(out + i); + auto tmp2 = at::vec::convert(tmp0); + auto tmp3 = tmp1 + tmp2; + _store(out + i, tmp3); + } + for (long i = vec_size * (M / vec_size); i < M; i++) { + auto tmp0 = tmp_in[i]; + auto tmp1 = out[i]; + auto tmp2 = static_cast(tmp0); + auto tmp3 = tmp1 + tmp2; + out[i] = tmp3; + } + } + // scale + for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(out + i); + auto tmp1 = tmp0 * vec_scale; + _store(out + i, tmp1); + } + for (long i = vec_size * (M / vec_size); i < M; i++) { + auto tmp0 = out[i]; + auto tmp1 = tmp0 * scale; + out[i] = tmp1; + } +} + +inline void do_convert_u8_s8( + unsigned char* src, + signed char* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_128 = at::vec::Vectorized(128); + for (int64_t r = 0; r < in_rows; r++) { + const unsigned char* tmp_src = src + r * ldi; + signed char* tmp_dst = dst + r * ldo; + for (int64_t c = 0; c < vec_size * (in_cols / vec_size); c += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_src + c, vec_size); + auto tmp1 = at::vec::convert(tmp0); + auto tmp2 = tmp1 - vec_128; + auto tmp3 = at::vec::convert(tmp2); + _store(tmp_dst + c, tmp3, vec_size); + } + for (int64_t c = vec_size * (in_cols / vec_size); c < in_cols; c++) { + auto tmp0 = tmp_src[c]; + auto tmp1 = (int16_t) tmp0; + auto tmp2 = tmp1 - 128; + auto tmp3 = (signed char) tmp2; + tmp_dst[c] = tmp3; + } + } +} + +template +inline void do_transpose( + scalar_t* src, + scalar_t* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + for (int64_t r=0; r +inline void do_copy( + scalar_t* src, + scalar_t* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + for (int64_t r=0; r +inline void pad_remain_row_col( + scalar_t* value_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi, + scalar_t pad_val=0) { + auto psize = pcols - cols; + if (psize == 0 && prows == rows) { + return; + } + const int32_t vec_size = at::vec::Vectorized::size(); + auto pad = at::vec::Vectorized(pad_val); + if (psize > 0) { + for (int i = 0; i < rows; i++) { + int j 
= 0; + for (; j < psize - (psize % vec_size); j += vec_size) { + pad.store(value_ptr + i * ldi + cols + j); + } + if (j < psize) { + pad.store(value_ptr + i * ldi + cols + j, psize - j); + } + } + } + + for (int i = rows; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + pad.store(value_ptr + i * ldi + j); + } + if (j < pcols) { + pad.store(value_ptr + i * ldi + j, pcols - j); + } + } +} + +template +inline void copy_value_with_pad( + scalar_t* value_ptr, + scalar_t* dst_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi, + scalar_t pad_val=0) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto pad = at::vec::Vectorized(pad_val); + int i = 0; + for (; i < rows; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(dst_ptr + i * pcols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(dst_ptr + i * pcols + j, cols - j); + } + + // col padding + auto psize = pcols - cols; + if (psize > 0) { + int pj = 0; + for (; pj < psize - (psize % vec_size); pj += vec_size) { + pad.store(dst_ptr + i * pcols + cols + pj); + } + if (pj < psize) { + pad.store(dst_ptr + i * pcols + cols + pj, psize - pj); + } + } + } + + // row padding + for (; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + pad.store(dst_ptr + i * pcols + j); + } + if (j < pcols) { + pad.store(dst_ptr + i * pcols + j, pcols - j); + } + + } + +} + +// UINT8 - one parallel loop with u8u8s32 GEMM +template +inline typename std::enable_if_t, void> +sdpa_int8_kernel_one_loop_impl( + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + at::Tensor& attention_mask, + double scale, + int32_t q_zp, + float q_scale, + int32_t k_zp, + float k_scale, + int32_t v_zp, + float v_scale, + int32_t a_zp, + float a_scale, + int32_t o_zp, + float o_scale) { + // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x KV_seq_len x Num_heads x Dim_per_head) + + const auto accumulate_dtype = at::kFloat; + + using accum_t = float; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = calculate_scale(query, scale); + int block_64 = 64; + // Sizes + TORCH_CHECK( + (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); + TORCH_CHECK( + kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); + + int64_t batchSize = query.size(0); + int64_t qSize = query.size(1); + int64_t kvSize = value.size(1); + int64_t num_head = query.size(2); + int64_t headSize = query.size(3); + + + bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); + if (has_attn_mask) { + reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); + } + + // Strides + int64_t qStrideB = query.stride(0); + int64_t qStrideM = query.stride(1); + int64_t qStrideH = query.stride(2); + int64_t kStrideB = key.stride(0); + int64_t kStrideN = key.stride(1); + int64_t kStrideH = key.stride(2); + int64_t vStrideB = value.stride(0); + int64_t vStrideN = value.stride(1); + int64_t vStrideH = value.stride(2); + int64_t oStrideB = output.stride(0); + int64_t oStrideM = output.stride(1); + int64_t oStrideH = 
output.stride(2); + int64_t mStrideB = + (attention_mask.defined() && attention_mask.size(0) > 1) + ? attention_mask.stride(0) + : 0; + int64_t mStrideH = + (attention_mask.defined() && attention_mask.size(1) > 1) + ? attention_mask.stride(1) + : 0; + int64_t mStrideM = + (attention_mask.defined() && attention_mask.size(2) > 1) + ? attention_mask.stride(2) + : 0; + int64_t mStrideN = + (attention_mask.defined() && attention_mask.size(3) > 1) + ? attention_mask.stride(3) + : 0; + + int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; + int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; + int64_t qSlice = (qSize - 1) / qSplitSize + 1; + int64_t qTail = (qSize - 1) % qSplitSize + 1; + int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + int64_t num_thread = at::get_num_threads(); + + int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; + // one of 16, 32, 48, 64 + auto select_tail_tail_block_size = [](int64_t size) -> int64_t { + if (size == 0) { + return 0; + } else if (size <= 16) { + return 16; + } else if (size <= 32) { + return 32; + } else if (size <= 48) { + return 48; + } else { + return 64; + } + }; + int64_t kv_tail_tail_block_size = select_tail_tail_block_size(kvTail % block_64); + int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; + int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; + int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; + + bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; + int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; + // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int av_gemm_K = kvSplitSize + av_gemm_K_padding; + bool av_gemm_K_tail_mul4 = kvTail % 4 == 0; + int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; + int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; + + auto u8_dt = at::ScalarType::Byte; + auto s8_dt = at::ScalarType::Int; + auto f32_dt = at::ScalarType::Float; + + // Data ptrs + scalar_t* q_data = query.data_ptr(); + scalar_t* k_data = key.data_ptr(); + scalar_t* v_data = value.data_ptr(); + mask_t* mask_data = attention_mask.defined() + ? attention_mask.data_ptr() + : nullptr; + scalar_t* out_data = output.data_ptr(); + + // Create tpp kernels for Query @ Key + bool headSize_mul4 = headSize % 4 == 0; + // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int qk_gemm_K_padding = headSize_mul4 ? 
0 : 4 - headSize % 4; + int qk_gemm_K = headSize + qk_gemm_K_padding; + + int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; + int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; + + int64_t total_size_uint8_per_thread = + /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + + /* qk_local */ kvSlice * av_gemm_K * 4 + + /* qk_reduce */ kvSlice * qk_reduce_strideL + + /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + + /* dst_s32 */ qSplitSize * rndHeadSize * 4 + + /* softmax_sum */ qSplitSize * 4 + + /* query_sum */ qSplitSize * 4 + + /* attention_sum */ qSplitSize * 4 + + /* softmax max */ qSplitSize * 4 + + /* query_padding_data */ qSplitSize * qk_gemm_K + + /* key_sum */ kvSize * 4 + + /* value_sum */ headSize * 4 + + /* key_t_reorder */ qk_gemm_K * rndkvSize + + /* value_t_reorder */ kvSlice * v_reorder_strideL; + + at::Tensor total_buf = at::empty( + {num_thread, total_size_uint8_per_thread}, + query.options()).zero_(); + scalar_t* total_buf_data = total_buf.data_ptr(); + + at::parallel_for( + 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head); + int ompIdx = at::get_thread_num(); + scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; + int32_t offset = 0; + accum_t* qk_data = reinterpret_cast(total_buf_ptr); + offset += kvSlice * qSplitSize * rndkvSplitSize * 4; + accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * av_gemm_K * 4; + scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * qk_reduce_strideL; + int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndkvSplitSize * 4; + int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndHeadSize * 4; + accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * qk_gemm_K; + + int32_t* k_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += kvSize * 4; + int32_t* v_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += headSize * 4; + scalar_t* key_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qk_gemm_K * rndkvSize; + scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + + uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; + + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + + // sum k and v + if (q_zp == 0) { + fill_stub(k_sum_ptr, static_cast(0), kvSize); + } else { + _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + k_sum_ptr, + kvSize, headSize, kStrideN, q_zp); + } + if (a_zp == 0) { + fill_stub(v_sum_ptr, static_cast(0), headSize); + } else { + _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + v_sum_ptr, + headSize, kvSize, vStrideN, a_zp); + } + + // pack + for (int64_t n = 0; n < kvSize; n += kvSplitSize) { + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * 
kStrideN, + B_blocked_xform_u8, + std::min(int(kvSplitSize - b), block_64), + headSize, + kStrideN, + block_64); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_64, + qk_gemm_K, + block_64, + block_64 + ); + } + // Pack + at::native::cpublas::pack( + qk_gemm_K, // K + block_64, // N + block_64, // ld_in + block_64, // ld_out + u8_dt, // dt_in + u8_dt, // dt_out + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); + } + } else { + // tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < rndkvTail) { + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + std::min(kvTail - b, block_size), + headSize, + kStrideN, + block_size); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_size, + qk_gemm_K, + block_size, + block_size + ); + } + // Pack + at::native::cpublas::pack( + qk_gemm_K, + block_size, + block_size, + block_size, + u8_dt, + u8_dt, + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + // split headSize to block_64, block_64, block_64 ... + // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] + for (int64_t b = 0; b < headSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); + } + } + } + + // sdpa core + for (int64_t k = 0; k < qSlice; k++) { + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize sum and max + fill_stub( + sfm_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + a_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); + int64_t num_keys = + is_causal ? 
std::min(m + qBlockSize, kvSize) : kvSize; + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + qk_gemm_K, + qStrideM); + + if (k_zp != 0) { + _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); + } else { + fill_stub( + q_sum_ptr, static_cast(0), qSplitSize); + } + const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { + int64_t n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + // Calculate sums for dequant compensation item + if (qBlockSize == qSplitSize) { + // q main + if (n + kvSplitSize < kvSize) { + // k main + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc, + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + } + } else { + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + at::native::cpublas::brgemm( + qSplitSize, block_size, qk_gemm_K, + qk_gemm_K, // lda + block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } else { + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( + qTail, block_64, qk_gemm_K, + qk_gemm_K,// lda + block_64, //ldb + rndkvSplitSize, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + } + } else { + // k tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + at::native::cpublas::brgemm( + qTail, block_size, qk_gemm_K, + qk_gemm_K, // lda + block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } + + // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 + int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; + accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; + if (has_attn_mask) { + mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); + _dequant_mask_max_fusion_kernel( + qk_s32_data, //in + mask_data_offset, //mask_ptr + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + mStrideM, //ldm + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } else { + _dequant_max_fusion_kernel( + qk_s32_data, //in + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } + } + // sub max, exp, sum reduce, div sum for softmax + // and quant + // and sum for attention + if (v_zp == 0) { + _sub_exp_sum_div_quant_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlices + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr //sfm_sum_ptr + ); + } else { + _sub_exp_sum_div_quant_sum_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlice + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + v_zp, // zp_b=beta2 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr, //sfm_sum_ptr + a_sum_ptr //a_sum_ptr + ); + } + // Calculate Softmax(q @ k.T) @ v + for (int64_t b = 0; b < headSize; b += block_64) { + auto value_reorder_b = value_reorder_ptr + b * av_gemm_K; + auto dst_s32_b = dst_s32_data + b; + for (int64_t s = 0; s < kvSlice; s++) { + at::native::cpublas::brgemm( + qSplitSize, block_64, av_gemm_K, + av_gemm_K, // lda + rndHeadSize, //block_64, //ldb + rndHeadSize, //ldc + s != 0, + qk_reduced_data + s * qk_reduce_strideL, + value_reorder_b + s * v_reorder_strideL, + dst_s32_b); + } + } + + // After the last gemm, + // do dequant compensation, quant and convert from s32 to int8 + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); + // Once all computations are done, need to release HW context. 
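+  // brgemm_release presumably frees the per-thread hardware state configured by
+  // cpublas::brgemm (e.g. the AMX tile configuration), so later kernels start clean.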
+ at::native::cpublas::brgemm_release(); +} + +// UINT8 - several parallel loops with u8u8s32 GEMM +template +inline typename std::enable_if_t, void> +sdpa_int8_kernel_several_loops_impl( + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + at::Tensor& attention_mask, + double scale, + int32_t q_zp, + float q_scale, + int32_t k_zp, + float k_scale, + int32_t v_zp, + float v_scale, + int32_t a_zp, + float a_scale, + int32_t o_zp, + float o_scale) { + // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x KV_seq_len x Num_heads x Dim_per_head) + + const auto accumulate_dtype = at::kFloat; + + using accum_t = float; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = calculate_scale(query, scale); + int block_64 = 64; + // Sizes + TORCH_CHECK( + (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); + TORCH_CHECK( + kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); + + int64_t batchSize = query.size(0); + int64_t qSize = query.size(1); + int64_t kvSize = value.size(1); + int64_t num_head = query.size(2); + int64_t headSize = query.size(3); + + + bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); + if (has_attn_mask) { + reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); + } + + // Strides + int64_t qStrideB = query.stride(0); + int64_t qStrideM = query.stride(1); + int64_t qStrideH = query.stride(2); + int64_t kStrideB = key.stride(0); + int64_t kStrideN = key.stride(1); + int64_t kStrideH = key.stride(2); + int64_t vStrideB = value.stride(0); + int64_t vStrideN = value.stride(1); + int64_t vStrideH = value.stride(2); + int64_t oStrideB = output.stride(0); + int64_t oStrideM = output.stride(1); + int64_t oStrideH = output.stride(2); + int64_t mStrideB = + (attention_mask.defined() && attention_mask.size(0) > 1) + ? attention_mask.stride(0) + : 0; + int64_t mStrideH = + (attention_mask.defined() && attention_mask.size(1) > 1) + ? attention_mask.stride(1) + : 0; + int64_t mStrideM = + (attention_mask.defined() && attention_mask.size(2) > 1) + ? attention_mask.stride(2) + : 0; + int64_t mStrideN = + (attention_mask.defined() && attention_mask.size(3) > 1) + ? attention_mask.stride(3) + : 0; + + int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; + int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; + int64_t qSlice = (qSize - 1) / qSplitSize + 1; + int64_t qTail = (qSize - 1) % qSplitSize + 1; + int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + int64_t num_thread = at::get_num_threads(); + + int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; + // one of 16, 32, 48, 64 + auto select_tail_tail_block_size = [](int64_t size) -> int64_t { + if (size == 0) { + return 0; + } else if (size <= 16) { + return 16; + } else if (size <= 32) { + return 32; + } else if (size <= 48) { + return 48; + } else { + return 64; + } + }; + int64_t kv_tail_tail_block_size = select_tail_tail_block_size(kvTail % block_64); + int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; + int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; + int64_t rndkvSize = kv_split_size > kvSize ? 
rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; + + bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; + int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; + // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int av_gemm_K = kvSplitSize + av_gemm_K_padding; + bool av_gemm_K_tail_mul4 = kvTail % 4 == 0; + int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; + int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; + + auto u8_dt = at::ScalarType::Byte; + auto s8_dt = at::ScalarType::Int; + auto f32_dt = at::ScalarType::Float; + + // Data ptrs + scalar_t* q_data = query.data_ptr(); + scalar_t* k_data = key.data_ptr(); + scalar_t* v_data = value.data_ptr(); + mask_t* mask_data = attention_mask.defined() + ? attention_mask.data_ptr() + : nullptr; + scalar_t* out_data = output.data_ptr(); + + // Create tpp kernels for Query @ Key + bool headSize_mul4 = headSize % 4 == 0; + // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 + int qk_gemm_K_padding = headSize_mul4 ? 0 : 4 - headSize % 4; + int qk_gemm_K = headSize + qk_gemm_K_padding; + + int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; + int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; + + int64_t total_size_uint8_per_thread = + /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + + /* qk_local */ kvSlice * av_gemm_K * 4 + + /* qk_reduce */ kvSlice * qk_reduce_strideL + + /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + + /* dst_s32 */ qSplitSize * rndHeadSize * 4 + + /* softmax_sum */ qSplitSize * 4 + + /* query_sum */ qSplitSize * 4 + + /* attention_sum */ qSplitSize * 4 + + /* softmax max */ qSplitSize * 4 + + /* query_padding_data */ qSplitSize * qk_gemm_K; + + at::Tensor total_buf = at::empty( + {num_thread, total_size_uint8_per_thread}, + query.options()).zero_(); + scalar_t* total_buf_data = total_buf.data_ptr(); + + int64_t kv_sum_size_per_BH = + /* key_sum */ kvSize + + /* value_sum */ headSize; + + at::Tensor kv_sum_buf = at::empty( + {batchSize, num_head, kv_sum_size_per_BH}, + query.options().dtype(at::kInt)).zero_(); + int32_t* kv_sum_buf_data = kv_sum_buf.data_ptr(); + + int64_t kv_reorder_size_per_BH = + /* key_t_reorder */ qk_gemm_K * rndkvSize + + /* value_t_reorder */ kvSlice * v_reorder_strideL; + + at::Tensor kv_reorder_buf = at::empty( + {batchSize, num_head, kv_reorder_size_per_BH}, + query.options()).zero_(); + scalar_t* kv_reorder_buf_data = kv_reorder_buf.data_ptr(); + scalar_t* key_reorder_ptr = kv_reorder_buf_data; + scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; + + // sum k and v + at::parallel_for( + 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head); + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + int32_t* kv_sum_ptr = kv_sum_buf_data + + i * num_head * kv_sum_size_per_BH + + j * kv_sum_size_per_BH; + int32_t* k_sum_ptr = kv_sum_ptr; + int32_t* v_sum_ptr = kv_sum_ptr + kvSize; + if (q_zp == 0) { + fill_stub(k_sum_ptr, static_cast(0), kvSize); + } else { + _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + k_sum_ptr, + kvSize, headSize, kStrideN, q_zp); + } + if (a_zp == 0) { + fill_stub(v_sum_ptr, static_cast(0), headSize); + } else { + _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + v_sum_ptr, + headSize, kvSize, vStrideN, a_zp); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + 
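+        // Note: k_sum is scaled by the query zero point and v_sum by the attention
+        // zero point because they serve as the cross terms (zp_q * sum(k) and
+        // zp_a * sum(v)) of the u8u8s32 dequantization compensation; they are
+        // computed once per (batch, head) here and reused by every query block below.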
} + }); + + // packing + at::parallel_for( + 0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, l = 0, n = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, l, kvSlice); + uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + n = l * kvSplitSize; + auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K; + auto v_reorder = value_reorder_ptr + + i * num_head * kvSlice * v_reorder_strideL + + j * kvSlice * v_reorder_strideL + n * rndHeadSize; + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + std::min(int(kvSplitSize - b), block_64), + headSize, + kStrideN, + block_64); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_64, + qk_gemm_K, + block_64, + block_64 + ); + } + at::native::cpublas::pack( + qk_gemm_K, // K + block_64, // N + block_64, // ld_in + block_64, // ld_out + u8_dt, // dt_in + u8_dt, // dt_out + B_blocked_xform_u8, + k_reorder + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + v_reorder + av_gemm_K * b); + } + } else { + // tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < rndkvTail) { + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + std::min(kvTail - b, block_size), + headSize, + kStrideN, + block_size); + if (!headSize_mul4) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + block_size, + qk_gemm_K, + block_size, + block_size + ); + } + // Pack + at::native::cpublas::pack( + qk_gemm_K, + block_size, + block_size, + block_size, + u8_dt, + u8_dt, + B_blocked_xform_u8, + k_reorder + b * qk_gemm_K); + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + // split headSize to block_64, block_64, block_64 ... + // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] 
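+          // av_gemm_K is the kv block length rounded up to a multiple of 4, most likely
+          // because the packed u8u8s32 brgemm expects a VNNI-style layout with K % 4 == 0.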
+ for (int64_t b = 0; b < headSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + v_reorder + av_gemm_K * b); + } + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); + } + }); + + at::parallel_for( + 0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, k, qSlice); + int ompIdx = at::get_thread_num(); + scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; + int32_t offset = 0; + accum_t* qk_data = reinterpret_cast(total_buf_ptr); + offset += kvSlice * qSplitSize * rndkvSplitSize * 4; + accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * av_gemm_K * 4; + scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * qk_reduce_strideL; + int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndkvSplitSize * 4; + int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndHeadSize * 4; + accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + + int32_t* kv_sum_ptr = kv_sum_buf_data + + i * num_head * kv_sum_size_per_BH + + j * kv_sum_size_per_BH; + int32_t* k_sum_ptr = kv_sum_ptr; + int32_t* v_sum_ptr = kv_sum_ptr + kvSize; + + // sdpa core + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize sum and max + fill_stub( + sfm_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + a_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); + int64_t num_keys = + is_causal ? 
std::min(m + qBlockSize, kvSize) : kvSize; + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + qk_gemm_K, + qStrideM); + + if (k_zp != 0) { + _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); + } else { + fill_stub( + q_sum_ptr, static_cast(0), qSplitSize); + } + const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { + int64_t n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K; + // Calculate sums for dequant compensation item + if (qBlockSize == qSplitSize) { + // q main + if (n + kvSplitSize < kvSize) { + // k main + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc, + false, + query_t_padding_ptr, + k_reorder + b * qk_gemm_K, + qk_s32_data + b); + } + } else { + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + at::native::cpublas::brgemm( + qSplitSize, block_size, qk_gemm_K, + qk_gemm_K, // lda + block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + k_reorder + b * qk_gemm_K, + qk_s32_data + b); + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } else { + if (n + kvSplitSize < kvSize) { + for (int64_t b = 0; b < kvSplitSize; b += block_64) { + at::native::cpublas::brgemm( + qTail, block_64, qk_gemm_K, + qk_gemm_K,// lda + block_64, //ldb + rndkvSplitSize, //ldc + false, + query_t_padding_ptr, + k_reorder + b * qk_gemm_K, + qk_s32_data + b); + } + } else { + // k tail + auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; + int64_t b = 0; + while (b < kvTail) { + at::native::cpublas::brgemm( + qTail, block_size, qk_gemm_K, + qk_gemm_K, // lda + block_size, //ldb + rndkvTail, //ldc + false, + query_t_padding_ptr, + k_reorder + b * qk_gemm_K, + qk_s32_data + b); + b += block_size; + block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; + } + } + } + + // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 + int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; + accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; + if (has_attn_mask) { + mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); + _dequant_mask_max_fusion_kernel( + qk_s32_data, //in + mask_data_offset, //mask_ptr + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + mStrideM, //ldm + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } else { + _dequant_max_fusion_kernel( + qk_s32_data, //in + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvBlockSize, //ldi + rndkvSplitSize,//kvBlockSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } + } + // sub max, exp, sum reduce, div sum for softmax + // and quant + // and sum for attention + if (v_zp == 0) { + _sub_exp_sum_div_quant_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlices + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr //sfm_sum_ptr + ); + } else { + _sub_exp_sum_div_quant_sum_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlice + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + v_zp, // zp_b=beta2 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr, //sfm_sum_ptr + a_sum_ptr //a_sum_ptr + ); + } + // Calculate Softmax(q @ k.T) @ v + auto v_reorder = value_reorder_ptr + + i * num_head * kvSlice * v_reorder_strideL + + j * kvSlice * v_reorder_strideL; + for (int64_t b = 0; b < headSize; b += block_64) { + auto value_reorder_b = v_reorder + b * av_gemm_K; + auto dst_s32_b = dst_s32_data + b; + for (int64_t s = 0; s < kvSlice; s++) { + at::native::cpublas::brgemm( + qSplitSize, block_64, av_gemm_K, + av_gemm_K, // lda + rndHeadSize, //block_64, //ldb + rndHeadSize, //ldc + s != 0, + qk_reduced_data + s * qk_reduce_strideL, + value_reorder_b + s * v_reorder_strideL, + dst_s32_b); + } + } + + // After the last gemm, + // do dequant compensation, quant and convert from s32 to int8 + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); + } + }); + // Once all computations are done, need to release HW context. + at::native::cpublas::brgemm_release(); +} + +#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Bool, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Float, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Double, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::BFloat16, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Half, mask_t, __VA_ARGS__)) + +void sdpa_int8_fused_kernel( + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + at::Tensor& attn_mask, + double scale, + long q_zp, + double q_scale, + long k_zp, + double k_scale, + long v_zp, + double v_scale, + long a_zp, + double a_scale, + long o_zp, + double o_scale) { + TORCH_CHECK(query.scalar_type() == c10::kByte); + int64_t batchSize = query.size(0); + int64_t num_head = query.size(1); + int64_t q_seq_len = query.size(2); + int64_t kv_seq_len = key.size(2); + int64_t q_split_size = 32; + if (q_seq_len >= 768) { + q_split_size = 256; + } else if (q_seq_len >= 192) { + q_split_size = 64; + } + // Heuristic to decide whether to use one parallel loop or not + uint32_t l2_cache_size = at::cpu::L2_cache_size(); + int64_t num_thread = at::get_num_threads(); + int64_t attn_size = q_split_size * kv_seq_len * sizeof(int32_t) * num_thread; + bool use_one_parallel_loop = (batchSize * num_head > num_thread) && + (attn_size > l2_cache_size); + if (use_one_parallel_loop) { + if (!attn_mask.defined()) { + if (q_split_size == 256) { + sdpa_int8_kernel_one_loop_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_split_size == 64) { + sdpa_int8_kernel_one_loop_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_one_loop_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + } else { + AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { + if (q_split_size == 256) { + sdpa_int8_kernel_one_loop_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_split_size == 64) { + sdpa_int8_kernel_one_loop_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_one_loop_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + }); + } + } else { + if (!attn_mask.defined()) { + if (q_split_size == 256) { + sdpa_int8_kernel_several_loops_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_split_size == 64) { + sdpa_int8_kernel_several_loops_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_several_loops_impl( + output, query, key, 
value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + } else { + AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { + if (q_split_size == 256) { + sdpa_int8_kernel_several_loops_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_split_size == 64) { + sdpa_int8_kernel_several_loops_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_kernel_several_loops_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + }); + } + } +} +#endif // CPU_CAPABILITY_AVX512 + +at::Tensor sdpa_int8_math_kernel( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + at::Tensor& attn_mask, + double scale, + int32_t q_zp, + float q_scale, + int32_t k_zp, + float k_scale, + int32_t v_zp, + float v_scale, + int32_t a_zp, + float a_scale, + int32_t o_zp, + float o_scale) { + // dequant q/k/v + auto q = (query.to(at::kFloat) - q_zp) * q_scale; + auto k = (key.to(at::kFloat) - k_zp) * k_scale; + auto v = (value.to(at::kFloat) - v_zp) * v_scale; + const auto scaling_factor = calculate_scale(q, scale); + auto attn = at::matmul(q, k.transpose(-2, -1)) * scaling_factor; + if (attn_mask.defined() && attn_mask.numel()) { + attn = attn.add(attn_mask.to(at::kFloat)); + } + attn = at::softmax(attn, -1); + // quant attn + attn = at::clamp_max( + at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255 + ); + // dequant attn + attn = (attn - a_zp) * a_scale; + auto output = at::matmul(attn, v); + // quant output + output = at::clamp_max( + at::clamp_min(at::round(output / o_scale) + o_zp, 0), 255 + ).to(at::kByte); + return output; +} + + +at::Tensor _scaled_dot_product_int8_cpu( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + at::Tensor& attn_mask, + double dropout_p, + bool is_causal, + double scale, + int64_t q_zp, + double q_scale, + int64_t k_zp, + double k_scale, + int64_t v_zp, + double v_scale, + int64_t a_zp, + double a_scale, + int64_t o_zp, + double o_scale) { + const auto dtype = query.scalar_type(); + TORCH_CHECK(!query.is_nested() && !key.is_nested() && !value.is_nested(), + "_scaled_dot_product_int8_cpu: Only accept plain inputs"); + TORCH_CHECK(!is_causal, + "_scaled_dot_product_int8_cpu: is_causal not supported."); + TORCH_CHECK(dtype == at::ScalarType::Byte, + "_scaled_dot_product_int8_cpu: Expected data type be U8, but got ", dtype, " instead."); + TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "_scaled_dot_product_int8_cpu: Accept only 4 dims inputs shape of {B, H, T, K}"); + TORCH_CHECK(dropout_p == 0.0, + "_scaled_dot_product_int8_cpu: Currently do not support dropout > 0"); + TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "_scaled_dot_product_int8_cpu: Q/K/V should have the same head size"); + TORCH_CHECK(!attn_mask.defined() || + attn_mask.scalar_type() == at::kFloat || + attn_mask.scalar_type() == at::kBFloat16, + "_scaled_dot_product_int8_cpu: Expected attention mask be float or bf16"); + TORCH_CHECK(!attn_mask.defined() || + (attn_mask.dim() == 2 || attn_mask.dim() == 4), + 
"_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); + + if (!at::native::cpublas::could_pack(dtype)) { + return sdpa_int8_math_kernel(query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } + + #ifdef CPU_CAPABILITY_AVX512 + at::Tensor output = at::empty_like(query, query.options()); + sdpa_int8_fused_kernel(output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + return output; + #else + return sdpa_int8_math_kernel(query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + #endif // CPU_CAPABILITY_AVX512 +} + + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::scaled_dot_product_int8", &_scaled_dot_product_int8_cpu); +} + +// } // at::native +} // namespace torchao diff --git a/torchao/prototype/inductor/fx_passes/README.md b/torchao/prototype/inductor/fx_passes/README.md new file mode 100644 index 0000000000..71535f87fb --- /dev/null +++ b/torchao/prototype/inductor/fx_passes/README.md @@ -0,0 +1,34 @@ +# Inductor FX Passes + +This directory contains the FX passes of Inductor. FX passes are transformations applied to the computational graph to optimize and modify it for better performance and functionality. + +You can replace the following customized graph passes of Inductor in TorchAO: +- `pre_grad_custom_pass` +- `joint_custom_pre_pass` +- `joint_custom_post_pass` +- `post_grad_custom_post_pass` +- `post_grad_custom_pre_pass` + +## Directory Structure + +- `int8_sdpa_fusion`: Pattern match for int8 sdpa fusion. + +## Getting Started + +To get started with using the FX passes in TorchAO, you can register and apply them to your computational graph as follows: + +```python +from torch._inductor import config +from torch._inductor.pattern_matcher import PatternMatcherPass + +# Example usage +patterns = PatternMatcherPass() # create a pattern matcher pass +_register_patterns() # register your own patterns +config.custom_pass = patterns.apply # define the custom pass with the patterns + +``` + +## Limitations + +For now, we can only register one pass as the custom pass. +In the future, it is better to extend it to a list. 
diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py new file mode 100644 index 0000000000..e463256057 --- /dev/null +++ b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py @@ -0,0 +1,539 @@ +import functools +from typing import Callable + +import torch +from torch._inductor import config +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + KeywordArg, + ListOf, + Match, + PatternMatcherPass, +) +from torch._inductor.fx_passes.post_grad import register_lowering_pattern +from torch._inductor.lowering import lowerings as L, make_fallback +from torch._dynamo.utils import counters + +make_fallback(torch.ops.torchao.scaled_dot_product_int8.default) + +aten = torch.ops.aten +patterns = PatternMatcherPass() + +def _is_valid_int8_sdpa_pattern(): + def fn(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + return ( + query.dtype == torch.uint8 + and key.dtype == torch.uint8 + and value.dtype == torch.uint8 + and query.device.type == "cpu" + and key.device == query.device + and value.device == query.device + ) + + return fn + + +def _register_int8_sdpa_pattern(pattern): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_int8_sdpa_pattern(), + ) + def int8_sdpa(match: Match, *args, **kwargs): + print("\n***hit int8_sdpa_pattern***\n") + query = kwargs["query"] + key = kwargs["key"] + value = kwargs["value"] + inv_scale = kwargs["inv_scale"] + attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None + q_zp = kwargs["q_zp"] + q_scale = kwargs["q_scale"] + k_zp = kwargs["k_zp"] + k_scale = kwargs["k_scale"] + v_zp = kwargs["v_zp"] + v_scale = kwargs["v_scale"] + a_zp = kwargs["a_zp"] + a_scale = kwargs["a_scale"] + o_zp = kwargs["o_zp"] + o_scale = kwargs["o_scale"] + counters["inductor"]["int8_fuse_attention"] += 1 + counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes) + + return L[torch.ops.torchao.scaled_dot_product_int8.default]( + query, + key, + value, + attn_mask, + 0.0, #dropout + False, #is_causal + 1.0 / inv_scale, #scale + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + ) + + return int8_sdpa + + +def _get_int8_sdpa_q_pattern(is_batch_size_1: bool, has_convert: bool): + int8_sdpa_q_basic_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + CallFunction( + aten.permute.default, + KeywordArg("query"), + Arg(), + ), + KeywordArg("q_scale"), + KeywordArg("q_zp"), + Arg(), + Arg(), + Arg(), + ) + if has_convert: + int8_sdpa_q_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_q_basic_pattern, + Arg(), + ) + int8_sdpa_q_basic_pattern = CallFunction( + aten.expand.default, + int8_sdpa_q_basic_pattern, + Arg(), + ) + if is_batch_size_1: + # pattern is different for bs=1 + return CallFunction( + aten.reshape.default, + int8_sdpa_q_basic_pattern, + Arg(), + ) + else: + return CallFunction( + aten.reshape.default, + CallFunction( + aten.clone.default, + int8_sdpa_q_basic_pattern, + memory_format=Arg(), + ), + Arg(), + ) + + +def _get_int8_sdpa_k_pattern(is_batch_size_1: bool, has_convert: bool): + int8_sdpa_k_basic_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + CallFunction( + aten.permute.default, + CallFunction( + aten.permute.default, + 
KeywordArg("key"), + Arg(), + ), + Arg(), + ), + KeywordArg("k_scale"), + KeywordArg("k_zp"), + Arg(), + Arg(), + Arg(), + ) + if has_convert: + int8_sdpa_k_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_k_basic_pattern, + Arg(), + ) + int8_sdpa_k_basic_pattern = CallFunction( + aten.expand.default, + int8_sdpa_k_basic_pattern, + Arg(), + ) + if is_batch_size_1: + # pattern is different for bs=1 + return CallFunction( + aten.reshape.default, + int8_sdpa_k_basic_pattern, + Arg(), + ) + else: + return CallFunction( + aten.reshape.default, + CallFunction( + aten.clone.default, + int8_sdpa_k_basic_pattern, + memory_format=Arg(), + ), + Arg(), + ) + + +def _get_int8_sdpa_v_pattern(is_batch_size_1: bool, has_convert: bool): + int8_sdpa_v_basic_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + CallFunction( + aten.permute.default, + KeywordArg("value"), + Arg(), + ), + KeywordArg("v_scale"), + KeywordArg("v_zp"), + Arg(), + Arg(), + Arg(), + ) + if has_convert: + int8_sdpa_v_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_v_basic_pattern, + Arg(), + ) + int8_sdpa_v_basic_pattern = CallFunction( + aten.expand.default, + int8_sdpa_v_basic_pattern, + Arg(), + ) + if is_batch_size_1: + # pattern is different for bs=1 + return CallFunction( + aten.reshape.default, + int8_sdpa_v_basic_pattern, + Arg(), + ) + else: + return CallFunction( + aten.reshape.default, + CallFunction( + aten.clone.default, + int8_sdpa_v_basic_pattern, + memory_format=Arg(), + ), + Arg(), + ) + + +def _get_int8_sdpa_score_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_q_pattern = _get_int8_sdpa_q_pattern(is_batch_size_1, has_convert) + int8_sdpa_k_pattern = _get_int8_sdpa_k_pattern(is_batch_size_1, has_convert) + int8_sdpa_score_basic_pattern = CallFunction( + aten.reshape.default, + CallFunction( + aten.bmm.default, + int8_sdpa_q_pattern, + int8_sdpa_k_pattern, + ), + Arg(), + ) + if is_reduced_type and not has_mask: + int8_sdpa_score_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_score_basic_pattern, + Arg(), + ) + if has_mask: + return CallFunction( + aten.add.Tensor, + CallFunction( + aten.div.Tensor, + int8_sdpa_score_basic_pattern, + KeywordArg("inv_scale"), + ), + KeywordArg("attn_mask"), + _users=2, + ) + else: + return CallFunction( + aten.mul.Tensor, + int8_sdpa_score_basic_pattern, + Arg(), + _users=2, + ) + + +def _get_int8_sdpa_exp_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_score_pattern = _get_int8_sdpa_score_pattern( + has_mask, is_batch_size_1, is_reduced_type, has_convert + ) + int8_sdpa_exp_basic_pattern = CallFunction( + aten.sub.Tensor, + int8_sdpa_score_pattern, + CallFunction( + aten.amax.default, + int8_sdpa_score_pattern, + Arg(), + Arg(), + ), + ) + if has_mask: + return CallFunction( + aten.exp.default, + int8_sdpa_exp_basic_pattern, + _users=2, + ) + else: + return CallFunction( + aten.exp.default, + CallFunction( + aten.div.Tensor, + int8_sdpa_exp_basic_pattern, + KeywordArg("inv_scale"), + ), + _users=2, + ) + + +def _get_int8_sdpa_attn_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_exp_pattern = _get_int8_sdpa_exp_pattern( + has_mask, is_batch_size_1, is_reduced_type, has_convert + ) + int8_sdpa_div_pattern = CallFunction( + aten.div.Tensor, + 
int8_sdpa_exp_pattern, + CallFunction( + aten.sum.dim_IntList, + int8_sdpa_exp_pattern, + Arg(), + Arg(), + ), + ) + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + int8_sdpa_div_pattern, + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ) + if is_reduced_type: + if has_mask: + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_softmax_pattern, + Arg(), + ) + else: + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_div_pattern, + Arg(), + ), + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ) + if has_convert: + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_softmax_pattern, + Arg(), + ) + return CallFunction( + aten.reshape.default, + CallFunction( + aten.expand.default, + int8_sdpa_softmax_pattern, + Arg(), + ), + Arg(), + ) + + +def _get_int8_sdpa_final_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_v_pattern = _get_int8_sdpa_v_pattern(is_batch_size_1, has_convert) + int8_sdpa_attn_pattern = _get_int8_sdpa_attn_pattern( + has_mask, is_batch_size_1, is_reduced_type, has_convert + ) + return CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + CallFunction( + aten.clone.default, + CallFunction( + aten.permute.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.bmm.default, + int8_sdpa_attn_pattern, + int8_sdpa_v_pattern, + ), + Arg(), + ), + Arg(), + ), + memory_format=Arg(), + ), + KeywordArg("o_scale"), + KeywordArg("o_zp"), + Arg(), + Arg(), + Arg(), + ) + + +def _register_int8_sdpa_fp32_lowering(): + # dtype = float32, without attention mask, batch size > 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=False, is_reduced_type=False, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=False, is_reduced_type=False, has_convert=False + ) + ) + + +def _register_int8_sdpa_fp32_mask_lowering(): + # dtype = float32, with attention mask, batch size > 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=False, is_reduced_type=False, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=False, is_reduced_type=False, has_convert=False + ) + ) + + +def _register_int8_sdpa_fp32_bs1_lowering(): + # dtype = float32, without attention mask, batch size == 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=True, is_reduced_type=False, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=True, is_reduced_type=False, has_convert=False + ) + ) + + +def _register_int8_sdpa_fp32_mask_bs1_lowering(): + # dtype = float32, with attention mask, batch size == 1 + _register_int8_sdpa_pattern( + 
_get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=True, is_reduced_type=False, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=True, is_reduced_type=False, has_convert=False + ) + ) + + +def _register_int8_sdpa_bf16_lowering(): + # dtype = bfloat16, without attention mask, batch size > 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=False, is_reduced_type=True, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=False, is_reduced_type=True, has_convert=False + ) + ) + + +def _register_int8_sdpa_bf16_mask_lowering(): + # dtype = bfloat16, with attention mask, batch size > 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=False, is_reduced_type=True, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=False, is_reduced_type=True, has_convert=False + ) + ) + + +def _register_int8_sdpa_bf16_bs1_lowering(): + # dtype = bfloat16, without attention mask, batch size == 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=True, is_reduced_type=True, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=False, is_batch_size_1=True, is_reduced_type=True, has_convert=False + ) + ) + + +def _register_int8_sdpa_bf16_mask_bs1_lowering(): + # dtype = bfloat16, with attention mask, batch size == 1 + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=True, is_reduced_type=True, has_convert=True + ) + ) + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=True, is_batch_size_1=True, is_reduced_type=True, has_convert=False + ) + ) + + +def _register_int8_sdpa_lowerings(): + _register_int8_sdpa_fp32_lowering() + _register_int8_sdpa_fp32_mask_lowering() + _register_int8_sdpa_fp32_bs1_lowering() + _register_int8_sdpa_fp32_mask_bs1_lowering() + _register_int8_sdpa_bf16_lowering() + _register_int8_sdpa_bf16_mask_lowering() + _register_int8_sdpa_bf16_bs1_lowering() + _register_int8_sdpa_bf16_mask_bs1_lowering() + + +@functools.lru_cache(None) +def _int8_sdpa_init(): + _register_int8_sdpa_lowerings() + config.post_grad_custom_pre_pass = patterns.apply From c6478880d11a54891798b8c89ba13d1942ebee9b Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Sun, 2 Mar 2025 22:21:15 -0500 Subject: [PATCH 13/36] update --- torchao/csrc/cpu/int8_sdpa.cpp | 44 ++++++++++++------- .../inductor/fx_passes/int8_sdpa_fusion.py | 13 ++++-- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index 4f5ca3fcaf..0aea924968 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -967,9 +967,9 @@ template , void> sdpa_int8_kernel_one_loop_impl( const at::Tensor& output, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, double dropout_p, bool is_causal, at::Tensor& attention_mask, @@ -984,9 +984,15 @@ sdpa_int8_kernel_one_loop_impl( float a_scale, int32_t o_zp, float o_scale) { - // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) - // Key (Batch x KV_seq_len x Num_heads x Dim_per_head) - // Value (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Query (Batch x 
Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + at::Tensor query = q.transpose(1, 2); + at::Tensor key = k.transpose(1, 2); + at::Tensor value = v.transpose(1, 2); const auto accumulate_dtype = at::kFloat; @@ -1507,9 +1513,9 @@ template , void> sdpa_int8_kernel_several_loops_impl( const at::Tensor& output, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, double dropout_p, bool is_causal, at::Tensor& attention_mask, @@ -1524,9 +1530,15 @@ sdpa_int8_kernel_several_loops_impl( float a_scale, int32_t o_zp, float o_scale) { - // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) - // Key (Batch x KV_seq_len x Num_heads x Dim_per_head) - // Value (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + at::Tensor query = q.transpose(1, 2); + at::Tensor key = k.transpose(1, 2); + at::Tensor value = v.transpose(1, 2); const auto accumulate_dtype = at::kFloat; @@ -2347,11 +2359,11 @@ at::Tensor _scaled_dot_product_int8_cpu( k_zp, k_scale, v_zp, v_scale, a_zp, a_scale, - o_zp, o_scale); + o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2); } #ifdef CPU_CAPABILITY_AVX512 - at::Tensor output = at::empty_like(query, query.options()); + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); sdpa_int8_fused_kernel(output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -2359,7 +2371,7 @@ at::Tensor _scaled_dot_product_int8_cpu( v_zp, v_scale, a_zp, a_scale, o_zp, o_scale); - return output; + return output.transpose(1, 2); #else return sdpa_int8_math_kernel(query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -2367,7 +2379,7 @@ at::Tensor _scaled_dot_product_int8_cpu( k_zp, k_scale, v_zp, v_scale, a_zp, a_scale, - o_zp, o_scale); + o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2); #endif // CPU_CAPABILITY_AVX512 } diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py index e463256057..128ae60766 100644 --- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py +++ b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py @@ -64,10 +64,13 @@ def int8_sdpa(match: Match, *args, **kwargs): counters["inductor"]["int8_fuse_attention"] += 1 counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes) - return L[torch.ops.torchao.scaled_dot_product_int8.default]( - query, - key, - value, + trans_query = L[aten.permute.default](query, [0, 2, 1, 3]) + trans_key = L[aten.permute.default](key, [0, 2, 1, 3]) + trans_value = L[aten.permute.default](value, [0, 2, 1, 3]) + output = L[torch.ops.torchao.scaled_dot_product_int8.default]( + trans_query, + trans_key, + trans_value, attn_mask, 0.0, #dropout False, #is_causal @@ -83,6 +86,8 @@ def int8_sdpa(match: Match, *args, **kwargs): o_zp, o_scale, ) + trans_output = L[aten.permute.default](output, [0, 2, 1, 3]) + return 
L[aten.clone.default](trans_output, memory_format=torch.contiguous_format) return int8_sdpa From 6a119f6e5c40634401af2e8ee07b171c0ebbb884 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Mon, 3 Mar 2025 21:35:21 -0500 Subject: [PATCH 14/36] update --- test/test_ops.py | 43 --- torchao/csrc/cpu/int8_sdpa.cpp | 357 ++---------------- .../prototype/inductor/fx_passes/README.md | 8 +- 3 files changed, 28 insertions(+), 380 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index c7222c662c..62f85b822e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -154,21 +154,6 @@ def _scaled_dot_product_int8_op_ref( SDPA_INT8_HEAD_DIM = [32, 64] SDPA_INT8_MASK_DTYPE = [None, torch.float32, torch.bfloat16] - # VIT - # SDPA_INT8_BATCH_SIZE = [224] - # SDPA_INT8_NUM_HEADS = [12] - # SDPA_INT8_Q_SEQ_LEN = [197] - # SDPA_INT8_KV_SEQ_LEN = [197] - # SDPA_INT8_HEAD_DIM = [64] - # SDPA_INT8_MASK_DTYPE = [torch.bfloat16] - # BERTLARGE - # SDPA_INT8_BATCH_SIZE = [120] - # SDPA_INT8_NUM_HEADS = [16] - # SDPA_INT8_Q_SEQ_LEN = [384] - # SDPA_INT8_KV_SEQ_LEN = [384] - # SDPA_INT8_HEAD_DIM = [64] - # SDPA_INT8_MASK_DTYPE = [torch.bfloat16] - @parametrize("batch_size", SDPA_INT8_BATCH_SIZE) @parametrize("n_head", SDPA_INT8_NUM_HEADS) @parametrize("q_seq_len", SDPA_INT8_Q_SEQ_LEN) @@ -191,7 +176,6 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ q_shape = [batch_size, q_seq_len, n_head, head_dim] kv_shape = [batch_size, kv_seq_len, n_head, head_dim] mask_shape = [batch_size, 1, 1, kv_seq_len] - print(f"q_shape: {q_shape}") q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 @@ -239,33 +223,6 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ ) self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) - - iter_n = 20 - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU], - schedule=torch.profiler.schedule(wait=2, warmup=iter_n, active=20), - ) as prof: - for _ in range(iter_n + 22): - r = torch.ops.torchao.scaled_dot_product_int8( - q, - k, - v, - attn_mask=attn_mask_2, - dropout_p=0.0, - is_causal=False, - q_zp=q_zp, - q_scale=q_scale, - k_zp=k_zp, - k_scale=k_scale, - v_zp=v_zp, - v_scale=v_scale, - a_zp=a_zp, - a_scale=a_scale, - o_zp=o_zp, - o_scale=o_scale - ) - prof.step() - print(prof.key_averages().table(sort_by="self_cpu_time_total")) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index 0aea924968..3872107e7d 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -45,121 +45,6 @@ inline double calculate_scale( } #ifdef CPU_CAPABILITY_AVX512 -// out = val * a + b -// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), -// take b as a scalar pointer. -template -inline void _scale_attn_mask_fusion_kernel( - T1* a, - T2* b, - const int& size, - T1* out, - T1& val) { - const auto vec_size1 = at::vec::Vectorized::size(); - const auto vec_size2 = at::vec::Vectorized::size(); - constexpr int64_t T1_n = - (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v) ? 
2 : 1; - constexpr int64_t T2_n = 1; - auto vec_scale = at::vec::VectorizedN(val); - int64_t i = 0; - for (; i < size - (size % vec_size2); i += vec_size2) { - auto a_n = at::vec::VectorizedN::loadu(a + i); - at::vec::VectorizedN b_n; - if constexpr(is_b_stride_zero) { - b_n = at::vec::VectorizedN((T1)b[0]); - } else { - b_n = at::vec::VectorizedN::loadu(b + i); - } - auto b_n_convert = at::vec::convert(b_n); - auto res = a_n * vec_scale + b_n_convert; - res.store(out + i); - } - for (; i < size; i++) { - auto tmp0 = a[i]; - T1 tmp1; - if constexpr(is_b_stride_zero) { - tmp1 = (T1)b[0]; - } else { - tmp1 = (T1)b[i]; - } - out[i] = tmp0 * val + tmp1; - } -} - -// 1) out = exp(a - val) -// 2) val = sum(out) -template -inline void _exp_reduce_sum_fusion_kernel( - T1* a, - const int& size, - T2* out, - T1& val) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_max = at::vec::Vectorized(val); - T1 tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = tmp0 - vec_max; - auto tmp2 = tmp1.exp_u20(); - vec_tmp_sum += tmp2; - _store(out + i, tmp2); - } - tmp_sum = at::vec::vec_reduce_all( - [](at::vec::Vectorized& x, at::vec::Vectorized& y) { - return x + y; - }, - vec_tmp_sum); - for (long i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = tmp0 - val; - auto tmp2 = exp(tmp1); - tmp_sum += tmp2; - out[i] = tmp2; - } - val = tmp_sum; -} - -// 1) out = a * scale -// 2) max = max(out) -template -inline void _mul_reduce_max_fusion_kernel( - const scalar_t* a, - const scalar_t& scale, - const int& size, - scalar_t* out, - scalar_t& max) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_scale = at::vec::Vectorized(scale); - scalar_t tmp_max = -std::numeric_limits::infinity(); - auto vec_tmp_max = at::vec::Vectorized(tmp_max); - for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = tmp0 * vec_scale; - vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); - _store(out + i, tmp1); - } - for (long i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = tmp0 * scale; - tmp_max = std::max(tmp_max, tmp1); - out[i] = tmp1; - } - max = std::max(tmp_max, vec_tmp_max.reduce_max()); -} - -template -static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { - TORCH_CHECK(ptr2 == nullptr); - return ptr; -} - -template , int> = 0> -static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { - return ptr2; -} - template inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { using Vec = at::vec::Vectorized; @@ -221,132 +106,6 @@ _store(scalar_t* dst, at::vec::Vectorized src) { res.store(dst, at::vec::Vectorized::size()); } -template -inline void pad_row_zero( - scalar_t* value_ptr, - scalar_t* padding_value_ptr, - int rows, - int cols, - int ldi) { - auto vec_size = at::vec::Vectorized::size(); - int i = 0; - for (; i < rows - 1; i++) { - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(padding_value_ptr + i * cols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - j); - vec_v.store(padding_value_ptr + i * cols + j, cols - j); - } - } - - // zero padding - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - 
auto vec_v = at::vec::Vectorized(0); - vec_v.store(padding_value_ptr + i * cols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized(0); - vec_v.store(padding_value_ptr + i * cols + j, cols - j); - } -} - -template -inline void pad_row_128_padding( - scalar_t* value_ptr, - scalar_t* padding_value_ptr, - int rows, - int cols, - int ldi, - int padding) { - auto vec_size = at::vec::Vectorized::size(); - int i = 0; - for (; i < rows - padding; i++) { - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(padding_value_ptr + i * cols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - j); - vec_v.store(padding_value_ptr + i * cols + j, cols - j); - } - } - - // 128 padding - for (; i < rows; i++) { - int j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = at::vec::Vectorized(128); - vec_v.store(padding_value_ptr + i * cols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized(128); - vec_v.store(padding_value_ptr + i * cols + j, cols - j); - } - } -} - -template -inline void pad_col_zero( - scalar_t* value_ptr, - scalar_t* padding_value_ptr, - int rows, - int cols, - int ldi) { - auto vec_size = at::vec::Vectorized::size(); - for (int i = 0; i < rows; i++) { - int j = 0; - for (; j < cols - 1 - ((cols - 1) % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(padding_value_ptr + i * cols + j); - } - if (j < cols - 1) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - 1 - j); - vec_v.store(padding_value_ptr + i * cols + j, cols - 1 - j); - *(padding_value_ptr + i * cols + cols - 1) = scalar_t(0); - } - } -} - -template -inline void pad_col_zero_padding( - scalar_t* value_ptr, - scalar_t* padding_value_ptr, - int rows, - int cols, - int ldi, - int padding) { - auto vec_size = at::vec::Vectorized::size(); - for (int i = 0; i < rows; i++) { - int j = 0; - for (; j < cols - padding - ((cols - padding) % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(padding_value_ptr + i * cols + j); - } - if (j < cols - padding) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - padding - j); - vec_v.store(padding_value_ptr + i * cols + j, cols - padding - j); - *(padding_value_ptr + i * cols + cols - padding) = scalar_t(0); - } - } -} - /* 1. dequant 2. 
add mask @@ -389,7 +148,7 @@ inline void _dequant_mask_max_fusion_kernel( auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col); auto tmp7 = at::vec::convert(tmp6); auto tmp8 = tmp5 + tmp7; - vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp8); + vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp8); _store(tmp_out + col, tmp8); } tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); @@ -445,7 +204,7 @@ inline void _dequant_max_fusion_kernel( auto tmp3 = tmp2 + vec_beta; auto tmp4 = at::vec::convert(tmp3); auto tmp5 = tmp4 * vec_alpha; - vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp5); + vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp5); _store(tmp_out + col, tmp5); } tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); @@ -542,10 +301,9 @@ inline void _sub_exp_sum_div_quant_sum_fusion_kernel( auto tmp1 = tmp0 * vec_sum_scale; auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::maximum(tmp3, vec_min_val); - auto tmp5 = at::vec::minimum(tmp4, vec_max_val); - _store(tmp_out + col, tmp5); - auto tmp6 = at::vec::convert(tmp5); + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp4); + auto tmp6 = at::vec::convert(tmp4); vec_tmp_sum += tmp6; } tmp_sum += vec_tmp_sum.reduce_add(); @@ -554,10 +312,9 @@ inline void _sub_exp_sum_div_quant_sum_fusion_kernel( auto tmp1 = tmp0 * sum_scale; auto tmp2 = std::nearbyint(tmp1); auto tmp3 = tmp2 + beta1_float; - auto tmp4 = std::max(tmp3, min_val); - auto tmp5 = std::min(tmp4, max_val); - tmp_out[col] = tmp5; - auto tmp6 = (int32_t) tmp5; + auto tmp4 = std::clamp(tmp3, min_val, max_val); + tmp_out[col] = tmp4; + auto tmp6 = (int32_t) tmp4; tmp_sum += tmp6; } sum_a_ptr[row] += tmp_sum * beta2; @@ -641,18 +398,16 @@ inline void _sub_exp_sum_div_quant_fusion_kernel( auto tmp1 = tmp0 * vec_sum_scale; auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; - auto tmp4 = at::vec::maximum(tmp3, vec_min_val); - auto tmp5 = at::vec::minimum(tmp4, vec_max_val); - _store(tmp_out + col, tmp5); + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp4); } for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { auto tmp0 = tmp_in[col]; auto tmp1 = tmp0 * sum_scale; auto tmp2 = std::nearbyint(tmp1); auto tmp3 = tmp2 + beta1_float; - auto tmp4 = std::max(tmp3, min_val); - auto tmp5 = std::min(tmp4, max_val); - tmp_out[col] = tmp5; + auto tmp4 = std::clamp(tmp3, min_val, max_val); + tmp_out[col] = tmp4; } // set zero for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { @@ -706,9 +461,8 @@ inline void _dequant_quant_fusion_kernel( auto tmp5 = tmp4 * vec_alpha; auto tmp6 = tmp5.round(); auto tmp7 = tmp6 + vec_beta2; - auto tmp8 = at::vec::maximum(tmp7, vec_min_val); - auto tmp9 = at::vec::minimum(tmp8, vec_max_val); - _store(tmp_out + col, tmp9); + auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp8); } for (long col = vec_size * (N / vec_size); col < N; col++) { auto sum_b = sum_b_ptr[col]; @@ -720,9 +474,8 @@ inline void _dequant_quant_fusion_kernel( auto tmp5 = tmp4 * alpha; auto tmp6 = std::nearbyint(tmp5); auto tmp7 = tmp6 + beta2_float; - auto tmp8 = std::max(tmp7, min_val); - auto tmp9 = std::min(tmp8, max_val); - tmp_out[col] = tmp9; + auto tmp8 = std::clamp(tmp7, min_val, max_val); + tmp_out[col] = tmp8; } } } @@ -811,35 +564,6 @@ inline void _int_sum_a_contiguous_kernel( } } -inline void do_convert_u8_s8( - unsigned char* src, - signed char* dst, - int64_t in_rows, - 
int64_t in_cols, - int64_t ldi, - int64_t ldo) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_128 = at::vec::Vectorized(128); - for (int64_t r = 0; r < in_rows; r++) { - const unsigned char* tmp_src = src + r * ldi; - signed char* tmp_dst = dst + r * ldo; - for (int64_t c = 0; c < vec_size * (in_cols / vec_size); c += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_src + c, vec_size); - auto tmp1 = at::vec::convert(tmp0); - auto tmp2 = tmp1 - vec_128; - auto tmp3 = at::vec::convert(tmp2); - _store(tmp_dst + c, tmp3, vec_size); - } - for (int64_t c = vec_size * (in_cols / vec_size); c < in_cols; c++) { - auto tmp0 = tmp_src[c]; - auto tmp1 = (int16_t) tmp0; - auto tmp2 = tmp1 - 128; - auto tmp3 = (signed char) tmp2; - tmp_dst[c] = tmp3; - } - } -} - template inline void do_transpose( scalar_t* src, @@ -855,21 +579,6 @@ inline void do_transpose( } } -template -inline void do_copy( - scalar_t* src, - scalar_t* dst, - int64_t in_rows, - int64_t in_cols, - int64_t ldi, - int64_t ldo) { - for (int64_t r=0; r inline void pad_remain_row_col( scalar_t* value_ptr, @@ -994,12 +703,11 @@ sdpa_int8_kernel_one_loop_impl( at::Tensor key = k.transpose(1, 2); at::Tensor value = v.transpose(1, 2); - const auto accumulate_dtype = at::kFloat; - using accum_t = float; - using Vec = at::vec::Vectorized; accum_t scaling_factor = calculate_scale(query, scale); int block_64 = 64; + auto u8_dt = at::ScalarType::Byte; + // Sizes TORCH_CHECK( (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), @@ -1013,7 +721,6 @@ sdpa_int8_kernel_one_loop_impl( int64_t num_head = query.size(2); int64_t headSize = query.size(3); - bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); if (has_attn_mask) { reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); @@ -1081,13 +788,6 @@ sdpa_int8_kernel_one_loop_impl( int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 int av_gemm_K = kvSplitSize + av_gemm_K_padding; - bool av_gemm_K_tail_mul4 = kvTail % 4 == 0; - int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; - int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; - - auto u8_dt = at::ScalarType::Byte; - auto s8_dt = at::ScalarType::Int; - auto f32_dt = at::ScalarType::Float; // Data ptrs scalar_t* q_data = query.data_ptr(); @@ -1125,7 +825,7 @@ sdpa_int8_kernel_one_loop_impl( at::Tensor total_buf = at::empty( {num_thread, total_size_uint8_per_thread}, - query.options()).zero_(); + query.options()); scalar_t* total_buf_data = total_buf.data_ptr(); at::parallel_for( @@ -1540,12 +1240,11 @@ sdpa_int8_kernel_several_loops_impl( at::Tensor key = k.transpose(1, 2); at::Tensor value = v.transpose(1, 2); - const auto accumulate_dtype = at::kFloat; - using accum_t = float; - using Vec = at::vec::Vectorized; accum_t scaling_factor = calculate_scale(query, scale); int block_64 = 64; + auto u8_dt = at::ScalarType::Byte; + // Sizes TORCH_CHECK( (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), @@ -1559,7 +1258,6 @@ sdpa_int8_kernel_several_loops_impl( int64_t num_head = query.size(2); int64_t headSize = query.size(3); - bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); if (has_attn_mask) { reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); @@ -1627,13 +1325,6 @@ sdpa_int8_kernel_several_loops_impl( int av_gemm_K_padding = av_gemm_K_mul4 ? 
0 : 4 - kvSplitSize % 4; // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 int av_gemm_K = kvSplitSize + av_gemm_K_padding; - bool av_gemm_K_tail_mul4 = kvTail % 4 == 0; - int av_gemm_K_tail_padding = av_gemm_K_tail_mul4 ? 0 : 4 - kvTail % 4; - int av_gemm_K_tail = kvTail + av_gemm_K_tail_padding; - - auto u8_dt = at::ScalarType::Byte; - auto s8_dt = at::ScalarType::Int; - auto f32_dt = at::ScalarType::Float; // Data ptrs scalar_t* q_data = query.data_ptr(); @@ -1667,7 +1358,7 @@ sdpa_int8_kernel_several_loops_impl( at::Tensor total_buf = at::empty( {num_thread, total_size_uint8_per_thread}, - query.options()).zero_(); + query.options()); scalar_t* total_buf_data = total_buf.data_ptr(); int64_t kv_sum_size_per_BH = @@ -1676,7 +1367,7 @@ sdpa_int8_kernel_several_loops_impl( at::Tensor kv_sum_buf = at::empty( {batchSize, num_head, kv_sum_size_per_BH}, - query.options().dtype(at::kInt)).zero_(); + query.options().dtype(at::kInt)); int32_t* kv_sum_buf_data = kv_sum_buf.data_ptr(); int64_t kv_reorder_size_per_BH = @@ -1685,7 +1376,7 @@ sdpa_int8_kernel_several_loops_impl( at::Tensor kv_reorder_buf = at::empty( {batchSize, num_head, kv_reorder_size_per_BH}, - query.options()).zero_(); + query.options()); scalar_t* kv_reorder_buf_data = kv_reorder_buf.data_ptr(); scalar_t* key_reorder_ptr = kv_reorder_buf_data; scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; diff --git a/torchao/prototype/inductor/fx_passes/README.md b/torchao/prototype/inductor/fx_passes/README.md index 71535f87fb..9171f508a8 100644 --- a/torchao/prototype/inductor/fx_passes/README.md +++ b/torchao/prototype/inductor/fx_passes/README.md @@ -1,8 +1,8 @@ # Inductor FX Passes -This directory contains the FX passes of Inductor. FX passes are transformations applied to the computational graph to optimize and modify it for better performance and functionality. +This directory contains the FX passes of Inductor. FX passes are transformations applied to the FX graph to optimize and modify it for better performance and functionality. -You can replace the following customized graph passes of Inductor in TorchAO: +In TorchAO, you can replace the following customized graph passes of Inductor: - `pre_grad_custom_pass` - `joint_custom_pre_pass` - `joint_custom_post_pass` @@ -15,7 +15,7 @@ You can replace the following customized graph passes of Inductor in TorchAO: ## Getting Started -To get started with using the FX passes in TorchAO, you can register and apply them to your computational graph as follows: +To get started with using the FX passes in TorchAO, you can register and apply them to your FX graph as follows: ```python from torch._inductor import config @@ -23,7 +23,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass # Example usage patterns = PatternMatcherPass() # create a pattern matcher pass -_register_patterns() # register your own patterns +_register_patterns(...) 
# register your own patterns config.custom_pass = patterns.apply # define the custom pass with the patterns ``` From ecb0516d6afa44329b8a065239103d2d1f1259d6 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Mon, 3 Mar 2025 22:08:38 -0500 Subject: [PATCH 15/36] update --- .../inductor/test_int8_sdpa_fusion.py | 35 +++---------------- .../inductor/fx_passes/int8_sdpa_fusion.py | 1 - 2 files changed, 5 insertions(+), 31 deletions(-) diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index 315e5b90f9..3a8c0b7e2b 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -147,55 +147,30 @@ def _check_common( ): self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) - iter_n = 20 - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU], - schedule=torch.profiler.schedule(wait=2, warmup=iter_n, active=20), - ) as prof: - for _ in range(iter_n + 22): - r = compiled_model(*(args2 + dropout_arg)) - prof.step() - print(prof.key_averages().table(sort_by="self_cpu_time_total")) - @skipIfRocm @config.patch({"freezing": True}) - def _test_sdpa_rewriter_int8_1_to_4(self): + def _test_sdpa_int8_rewriter(self): # pattern is different for bs=1 for dtype, has_mask, bs in itertools.product( [torch.float32, torch.bfloat16], [True, False], [56, 1] ): seqlen, numhead, headsize = 197, 16, 64 - # dtype = torch.bfloat16 - # has_mask = True - # is_bs_1 = 0 - # if is_bs_1: - # candidates = [[1, 384, 16, 64], [1, 197, 12, 64]] - # else: - # candidates = [[120, 384, 16, 64], [224, 197, 12, 64]] - # candidates = [[120, 384, 16, 64]] - # for bs, seqlen, numhead, headsize in candidates: mod = SelfAttnLikeModule( input_dim=headsize * numhead, has_mask=has_mask, num_attention_heads=numhead, attention_head_size=headsize, ).eval() - maybe_autocast = ( - torch.cpu.amp.autocast() - if dtype == torch.bfloat16 - else contextlib.nullcontext() - ) - print("\nTEST shape", bs, numhead, seqlen, headsize) inputs = ( torch.randn( (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype - ) - * 10, + ) * 10, torch.randn((bs, 1, 1, seqlen), device=self.device) * 10 if has_mask else None, ) - with torch.no_grad(), maybe_autocast: + enable_autocast = (dtype == torch.bfloat16) + with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast, dtype=torch.bfloat16): _int8_sdpa_init() quantizer = X86InductorQuantizer() quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) @@ -217,7 +192,7 @@ def _test_sdpa_rewriter_int8_1_to_4(self): if HAS_CPU: class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): device = "cpu" - test_sdpa_rewriter_int8_1_to_4_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_int8_1_to_4 + test_sdpa_int8_rewriter_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter if __name__ == "__main__": if IS_LINUX: diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py index 128ae60766..c860522794 100644 --- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py +++ b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py @@ -45,7 +45,6 @@ def _register_int8_sdpa_pattern(pattern): extra_check=_is_valid_int8_sdpa_pattern(), ) def int8_sdpa(match: Match, *args, **kwargs): - print("\n***hit int8_sdpa_pattern***\n") query = kwargs["query"] key = kwargs["key"] value = kwargs["value"] From 170499eea78e2fc0e0adcb30a294133d4a7d89cd Mon Sep 17 
00:00:00 2001 From: Valentine233 Date: Tue, 4 Mar 2025 01:14:31 -0500 Subject: [PATCH 16/36] fix issue --- setup.py | 4 +++- .../inductor/test_int8_sdpa_fusion.py | 19 +++++++------------ test/test_ops.py | 2 +- .../prototype/inductor/fx_passes/__init__.py | 5 +++++ .../inductor/fx_passes/int8_sdpa_fusion.py | 10 ++++------ 5 files changed, 20 insertions(+), 20 deletions(-) create mode 100644 torchao/prototype/inductor/fx_passes/__init__.py diff --git a/setup.py b/setup.py index 0b758b5e81..d031229268 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,7 @@ def use_debug_mode(): import torch from torch.utils.cpp_extension import ( CUDA_HOME, + IS_MACOS, IS_WINDOWS, ROCM_HOME, BuildExtension, @@ -297,8 +298,9 @@ def get_extensions(): "-DCPU_CAPABILITY_AVX512", "-march=native", "-mfma", - "-fopenmp", ]) + if not IS_MACOS: + extra_compile_args["cxx"].append("-fopenmp") if debug_mode: extra_compile_args["cxx"].append("-g") diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index 3a8c0b7e2b..f4ed7a19cd 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -1,28 +1,23 @@ -import torchao - -import contextlib -import functools import itertools -import math import torch +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.utils.checkpoint -from torch._dynamo.debug_utils import aot_graph_input_parser from torch._dynamo.utils import counters from torch._inductor import config -from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.test_case import TestCase, run_tests from torch._inductor.utils import run_and_get_code -from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA - -import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq -from torch.export import export_for_training from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( X86InductorQuantizer, ) +from torch.export import export_for_training +from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm +from torch.testing._internal.inductor_utils import HAS_CPU + from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init + class SelfAttnLikeModule(torch.nn.Module): def __init__( self, diff --git a/test/test_ops.py b/test/test_ops.py index 62f85b822e..be3c1ce373 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -4,10 +4,10 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
import itertools +import math import sys import pytest -import math import torch from torch.testing._internal.common_utils import ( TestCase, diff --git a/torchao/prototype/inductor/fx_passes/__init__.py b/torchao/prototype/inductor/fx_passes/__init__.py new file mode 100644 index 0000000000..aae6d5348a --- /dev/null +++ b/torchao/prototype/inductor/fx_passes/__init__.py @@ -0,0 +1,5 @@ +from .int8_sdpa_fusion import _int8_sdpa_init + +__all__ = [ + "_int8_sdpa_init", +] diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py index c860522794..2d5772772b 100644 --- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py +++ b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py @@ -1,20 +1,18 @@ import functools -from typing import Callable import torch +from torch._dynamo.utils import counters from torch._inductor import config +from torch._inductor.fx_passes.post_grad import register_lowering_pattern +from torch._inductor.lowering import lowerings as L +from torch._inductor.lowering import make_fallback from torch._inductor.pattern_matcher import ( Arg, CallFunction, - filter_nodes, KeywordArg, - ListOf, Match, PatternMatcherPass, ) -from torch._inductor.fx_passes.post_grad import register_lowering_pattern -from torch._inductor.lowering import lowerings as L, make_fallback -from torch._dynamo.utils import counters make_fallback(torch.ops.torchao.scaled_dot_product_int8.default) From b9804fa9c1ea0e73294f5c7b5e5902abdafa04b9 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 4 Mar 2025 02:46:16 -0500 Subject: [PATCH 17/36] fix issue --- setup.py | 25 +++--- .../inductor/test_int8_sdpa_fusion.py | 25 +++++- test/test_ops.py | 82 ++++++++++++------- torchao/csrc/cpu/int8_sdpa.cpp | 41 +++++----- torchao/ops.py | 26 ++++-- .../inductor/fx_passes/int8_sdpa_fusion.py | 70 ++++++++++++---- 6 files changed, 183 insertions(+), 86 deletions(-) diff --git a/setup.py b/setup.py index d031229268..2fe4c7e912 100644 --- a/setup.py +++ b/setup.py @@ -45,8 +45,7 @@ def read_version(file_path="version.txt"): if version_suffix is None: version_suffix = f"+git{get_git_commit_id()}" -use_cpp = os.getenv('USE_CPP') -use_cpp_avx512 = os.getenv('USE_AVX512', '1') == '1' +use_cpp = os.getenv("USE_CPP") import platform @@ -56,6 +55,10 @@ def read_version(file_path="version.txt"): and platform.system() == "Darwin" ) +use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and build_torchao_experimental + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 + version_prefix = read_version() # Version is version.dev year month date if using nightlies and version if not version = ( @@ -72,7 +75,6 @@ def use_debug_mode(): import torch from torch.utils.cpp_extension import ( CUDA_HOME, - IS_MACOS, IS_WINDOWS, ROCM_HOME, BuildExtension, @@ -293,14 +295,15 @@ def get_extensions(): ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"] ) - if use_cpp_avx512: - extra_compile_args["cxx"].extend([ - "-DCPU_CAPABILITY_AVX512", - "-march=native", - "-mfma", - ]) - if not IS_MACOS: - extra_compile_args["cxx"].append("-fopenmp") + if use_cpp_avx512 and TORCH_VERSION_AT_LEAST_2_7: + extra_compile_args["cxx"].extend( + [ + "-DCPU_CAPABILITY_AVX512", + "-march=native", + "-mfma", + "-fopenmp", + ] + ) if debug_mode: extra_compile_args["cxx"].append("-g") diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index f4ed7a19cd..0dad80cfa4 100644 --- 
a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -1,5 +1,6 @@ import itertools +import pytest import torch import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.utils.checkpoint @@ -16,6 +17,7 @@ from torch.testing._internal.inductor_utils import HAS_CPU from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 class SelfAttnLikeModule(torch.nn.Module): @@ -68,6 +70,7 @@ def forward(self, x, mask): ) return self.dense(context_layer) + class TestSDPAPatternRewriterTemplate(TestCase): def _clone_inputs(self, inputs): def clone(x): @@ -143,6 +146,9 @@ def _check_common( self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) @skipIfRocm + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" + ) @config.patch({"freezing": True}) def _test_sdpa_int8_rewriter(self): # pattern is different for bs=1 @@ -159,13 +165,19 @@ def _test_sdpa_int8_rewriter(self): inputs = ( torch.randn( (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype - ) * 10, + ) + * 10, torch.randn((bs, 1, 1, seqlen), device=self.device) * 10 if has_mask else None, ) - enable_autocast = (dtype == torch.bfloat16) - with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast, dtype=torch.bfloat16): + enable_autocast = dtype == torch.bfloat16 + with ( + torch.no_grad(), + torch.amp.autocast( + "cpu", enabled=enable_autocast, dtype=torch.bfloat16 + ), + ): _int8_sdpa_init() quantizer = X86InductorQuantizer() quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) @@ -184,10 +196,15 @@ def _test_sdpa_int8_rewriter(self): convert_model, args1=inputs, check_train=False, atol=1.0 ) + if HAS_CPU: + class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): device = "cpu" - test_sdpa_int8_rewriter_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter + test_sdpa_int8_rewriter_cpu = ( + TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter + ) + if __name__ == "__main__": if IS_LINUX: diff --git a/test/test_ops.py b/test/test_ops.py index be3c1ce373..3d165514a4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -24,7 +24,11 @@ ) from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_7, + compute_max_diff, +) IS_CUDA = torch.cuda.is_available() and torch.version.cuda IS_ROCM = torch.cuda.is_available() and torch.version.hip @@ -111,23 +115,24 @@ def test_quant_llm_linear_correctness( assert relative_error < rtol def _scaled_dot_product_int8_op_ref( - self, - q, - k, - v, - attn_mask=None, - dropout_p=0, - is_causal=False, - q_zp=0, - q_scale=1.0, - k_zp=0, - k_scale=1.0, - v_zp=0, - v_scale=1.0, - a_zp=0, - a_scale=1.0, - o_zp=0, - o_scale=1.0): + self, + q, + k, + v, + attn_mask=None, + dropout_p=0, + is_causal=False, + q_zp=0, + q_scale=1.0, + k_zp=0, + k_scale=1.0, + v_zp=0, + v_scale=1.0, + a_zp=0, + a_scale=1.0, + o_zp=0, + o_scale=1.0, + ): q = (q.to(torch.float) - q_zp) * q_scale k = (k.to(torch.float) - k_zp) * k_scale v = (v.to(torch.float) - v_zp) * v_scale @@ -140,7 +145,7 @@ def _scaled_dot_product_int8_op_ref( attn = attn - attn_max attn = torch.exp(attn) attn_sum = 
torch.sum(attn, dim=-1, keepdim=True) - attn = attn / attn_sum + attn = attn / attn_sum attn = torch.clamp(torch.round(attn / a_scale) + a_zp, min=0, max=255) attn = (attn - a_zp) * a_scale out = attn @ v @@ -154,13 +159,18 @@ def _scaled_dot_product_int8_op_ref( SDPA_INT8_HEAD_DIM = [32, 64] SDPA_INT8_MASK_DTYPE = [None, torch.float32, torch.bfloat16] + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" + ) @parametrize("batch_size", SDPA_INT8_BATCH_SIZE) @parametrize("n_head", SDPA_INT8_NUM_HEADS) @parametrize("q_seq_len", SDPA_INT8_Q_SEQ_LEN) @parametrize("kv_seq_len", SDPA_INT8_KV_SEQ_LEN) @parametrize("head_dim", SDPA_INT8_HEAD_DIM) @parametrize("mask_dtype", SDPA_INT8_MASK_DTYPE) - def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype): + def test_scaled_dot_product_int8_op( + self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype + ): torch.manual_seed(1234) device = "cpu" q_zp = int(127) @@ -177,14 +187,29 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ kv_shape = [batch_size, kv_seq_len, n_head, head_dim] mask_shape = [batch_size, 1, 1, kv_seq_len] q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 - k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 - v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 + k = ( + torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) + * 100 + ) + v = ( + torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) + * 100 + ) q = q.to(torch.uint8) k = k.to(torch.uint8) v = v.to(torch.uint8) - attn_mask = torch.randn(mask_shape, dtype=mask_dtype, device=device) if mask_dtype is not None else None - q2, k2, v2, attn_mask_2 = q.clone(), k.clone(), v.clone(), attn_mask.clone() if mask_dtype is not None else None - + attn_mask = ( + torch.randn(mask_shape, dtype=mask_dtype, device=device) + if mask_dtype is not None + else None + ) + q2, k2, v2, attn_mask_2 = ( + q.clone(), + k.clone(), + v.clone(), + attn_mask.clone() if mask_dtype is not None else None, + ) + math_ref = self._scaled_dot_product_int8_op_ref( q2, k2, @@ -201,8 +226,8 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ a_zp=a_zp, a_scale=a_scale, o_zp=o_zp, - o_scale=o_scale - ) + o_scale=o_scale, + ) actual = torch.ops.torchao.scaled_dot_product_int8( q, k, @@ -219,13 +244,12 @@ def test_scaled_dot_product_int8_op(self, batch_size, n_head, q_seq_len, kv_seq_ a_zp=a_zp, a_scale=a_scale, o_zp=o_zp, - o_scale=o_scale + o_scale=o_scale, ) self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) - instantiate_parametrized_tests(TestOps) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index 3872107e7d..591d1796df 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -7,7 +7,6 @@ #include #include -#include #include #include @@ -45,6 +44,8 @@ inline double calculate_scale( } #ifdef CPU_CAPABILITY_AVX512 +#include + template inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { using Vec = at::vec::Vectorized; @@ -2043,26 +2044,26 @@ at::Tensor _scaled_dot_product_int8_cpu( (attn_mask.dim() == 2 || attn_mask.dim() == 4), "_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); - if (!at::native::cpublas::could_pack(dtype)) { - return sdpa_int8_math_kernel(query, key, value, - dropout_p, is_causal, attn_mask, scale, - 
q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2); - } - #ifdef CPU_CAPABILITY_AVX512 - at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); - sdpa_int8_fused_kernel(output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - return output.transpose(1, 2); + if (at::native::cpublas::could_pack(dtype)) { + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + sdpa_int8_fused_kernel(output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + return output.transpose(1, 2); + } else { + return sdpa_int8_math_kernel(query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2); + } #else return sdpa_int8_math_kernel(query, key, value, dropout_p, is_causal, attn_mask, scale, diff --git a/torchao/ops.py b/torchao/ops.py index 40bbc39ba3..38a341f435 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -205,13 +205,25 @@ def scaled_dot_product_int8( Returns output of quantized SDPA """ - return torch.ops.torchao.scaled_dot_product_int8.default(query, key, value, - attn_mask, dropout_p, is_causal, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale) + return torch.ops.torchao.scaled_dot_product_int8.default( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + q_zp, + q_scale, + k_zp, + k_scale, + v_zp, + v_scale, + a_zp, + a_scale, + o_zp, + o_scale, + ) @register_custom_op("torchao::scaled_dot_product_int8") diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py index 2d5772772b..bcf59ad35b 100644 --- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py +++ b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py @@ -14,11 +14,16 @@ PatternMatcherPass, ) +__all__ = [ + "_int8_sdpa_init", +] + make_fallback(torch.ops.torchao.scaled_dot_product_int8.default) aten = torch.ops.aten patterns = PatternMatcherPass() + def _is_valid_int8_sdpa_pattern(): def fn(match): assert all(k in match.kwargs for k in ("query", "key", "value")) @@ -69,9 +74,9 @@ def int8_sdpa(match: Match, *args, **kwargs): trans_key, trans_value, attn_mask, - 0.0, #dropout - False, #is_causal - 1.0 / inv_scale, #scale + 0.0, # dropout + False, # is_causal + 1.0 / inv_scale, # scale q_zp, q_scale, k_zp, @@ -84,7 +89,9 @@ def int8_sdpa(match: Match, *args, **kwargs): o_scale, ) trans_output = L[aten.permute.default](output, [0, 2, 1, 3]) - return L[aten.clone.default](trans_output, memory_format=torch.contiguous_format) + return L[aten.clone.default]( + trans_output, memory_format=torch.contiguous_format + ) return int8_sdpa @@ -416,12 +423,18 @@ def _register_int8_sdpa_fp32_lowering(): # dtype = float32, without attention mask, batch size > 1 _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=False, is_reduced_type=False, has_convert=True + has_mask=False, + is_batch_size_1=False, + is_reduced_type=False, + has_convert=True, ) ) _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=False, is_reduced_type=False, has_convert=False + has_mask=False, + is_batch_size_1=False, + 
is_reduced_type=False, + has_convert=False, ) ) @@ -430,12 +443,18 @@ def _register_int8_sdpa_fp32_mask_lowering(): # dtype = float32, with attention mask, batch size > 1 _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=False, is_reduced_type=False, has_convert=True + has_mask=True, + is_batch_size_1=False, + is_reduced_type=False, + has_convert=True, ) ) _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=False, is_reduced_type=False, has_convert=False + has_mask=True, + is_batch_size_1=False, + is_reduced_type=False, + has_convert=False, ) ) @@ -444,12 +463,18 @@ def _register_int8_sdpa_fp32_bs1_lowering(): # dtype = float32, without attention mask, batch size == 1 _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=True, is_reduced_type=False, has_convert=True + has_mask=False, + is_batch_size_1=True, + is_reduced_type=False, + has_convert=True, ) ) _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=True, is_reduced_type=False, has_convert=False + has_mask=False, + is_batch_size_1=True, + is_reduced_type=False, + has_convert=False, ) ) @@ -463,7 +488,10 @@ def _register_int8_sdpa_fp32_mask_bs1_lowering(): ) _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=True, is_reduced_type=False, has_convert=False + has_mask=True, + is_batch_size_1=True, + is_reduced_type=False, + has_convert=False, ) ) @@ -472,12 +500,18 @@ def _register_int8_sdpa_bf16_lowering(): # dtype = bfloat16, without attention mask, batch size > 1 _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=False, is_reduced_type=True, has_convert=True + has_mask=False, + is_batch_size_1=False, + is_reduced_type=True, + has_convert=True, ) ) _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=False, is_reduced_type=True, has_convert=False + has_mask=False, + is_batch_size_1=False, + is_reduced_type=True, + has_convert=False, ) ) @@ -491,7 +525,10 @@ def _register_int8_sdpa_bf16_mask_lowering(): ) _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=False, is_reduced_type=True, has_convert=False + has_mask=True, + is_batch_size_1=False, + is_reduced_type=True, + has_convert=False, ) ) @@ -505,7 +542,10 @@ def _register_int8_sdpa_bf16_bs1_lowering(): ) _register_int8_sdpa_pattern( _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=True, is_reduced_type=True, has_convert=False + has_mask=False, + is_batch_size_1=True, + is_reduced_type=True, + has_convert=False, ) ) From 45ed3ccccc97ae61af4195249b9117dd3f76b755 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 4 Mar 2025 03:35:45 -0500 Subject: [PATCH 18/36] fix issue --- test/prototype/inductor/test_int8_sdpa_fusion.py | 13 +++++++------ torchao/prototype/inductor/__init__.py | 0 2 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 torchao/prototype/inductor/__init__.py diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index 0dad80cfa4..985a3af544 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -2,17 +2,11 @@ import pytest import torch -import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.utils.checkpoint from torch._dynamo.utils import counters from torch._inductor 
import config from torch._inductor.test_case import TestCase, run_tests from torch._inductor.utils import run_and_get_code -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( - X86InductorQuantizer, -) -from torch.export import export_for_training from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm from torch.testing._internal.inductor_utils import HAS_CPU @@ -151,6 +145,13 @@ def _check_common( ) @config.patch({"freezing": True}) def _test_sdpa_int8_rewriter(self): + import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq + from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( + X86InductorQuantizer, + ) + from torch.export import export_for_training + # pattern is different for bs=1 for dtype, has_mask, bs in itertools.product( [torch.float32, torch.bfloat16], [True, False], [56, 1] diff --git a/torchao/prototype/inductor/__init__.py b/torchao/prototype/inductor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From b3b9b3980a1e6a47095cb656df1faaa6be5858dd Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 4 Mar 2025 04:12:56 -0500 Subject: [PATCH 19/36] fix issue --- test/prototype/inductor/test_int8_sdpa_fusion.py | 1 + torchao/csrc/cpu/int8_sdpa.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index 985a3af544..b2506e3d5f 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -153,6 +153,7 @@ def _test_sdpa_int8_rewriter(self): from torch.export import export_for_training # pattern is different for bs=1 + torch.manual_seed(1234) for dtype, has_mask, bs in itertools.product( [torch.float32, torch.bfloat16], [True, False], [56, 1] ): diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index 591d1796df..fbffdc871a 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -1,3 +1,4 @@ +#pragma once #include #include #include From 1054e888656a92e1cdaf3b74317e739ece2cc14b Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 4 Mar 2025 21:32:49 -0500 Subject: [PATCH 20/36] fix issue --- setup.py | 6 ++++++ .../inductor/test_int8_sdpa_fusion.py | 9 ++++---- test/test_ops.py | 21 +++++++------------ 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/setup.py b/setup.py index 2fe4c7e912..6b01308194 100644 --- a/setup.py +++ b/setup.py @@ -358,6 +358,12 @@ def get_extensions(): # Collect C++ source files sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) + if IS_WINDOWS: + # Remove csrc/cpu/*.cpp on Windows due to the link issue: unresolved external symbol PyInit__C + cpp_sources = list( + glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True) + ) + sources = [s for s in sources if s not in cpp_sources] # Collect CUDA source files extensions_cuda_dir = os.path.join(extensions_dir, "cuda") diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index b2506e3d5f..a701dde1a9 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -12,6 +12,7 @@ from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init from torchao.utils import 
TORCH_VERSION_AT_LEAST_2_7 +from torch.utils.cpp_extension import IS_WINDOWS class SelfAttnLikeModule(torch.nn.Module): @@ -143,6 +144,7 @@ def _check_common( @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" ) + @pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet") @config.patch({"freezing": True}) def _test_sdpa_int8_rewriter(self): import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq @@ -167,9 +169,8 @@ def _test_sdpa_int8_rewriter(self): inputs = ( torch.randn( (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype - ) - * 10, - torch.randn((bs, 1, 1, seqlen), device=self.device) * 10 + ), + torch.randn((bs, 1, 1, seqlen), device=self.device) if has_mask else None, ) @@ -177,7 +178,7 @@ def _test_sdpa_int8_rewriter(self): with ( torch.no_grad(), torch.amp.autocast( - "cpu", enabled=enable_autocast, dtype=torch.bfloat16 + self.device, enabled=enable_autocast, dtype=torch.bfloat16 ), ): _int8_sdpa_init() diff --git a/test/test_ops.py b/test/test_ops.py index 3d165514a4..e0e2e7106b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -29,6 +29,7 @@ TORCH_VERSION_AT_LEAST_2_7, compute_max_diff, ) +from torch.utils.cpp_extension import IS_WINDOWS IS_CUDA = torch.cuda.is_available() and torch.version.cuda IS_ROCM = torch.cuda.is_available() and torch.version.hip @@ -152,22 +153,16 @@ def _scaled_dot_product_int8_op_ref( out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255) return out.to(torch.uint8) - SDPA_INT8_BATCH_SIZE = [56, 120] - SDPA_INT8_NUM_HEADS = [2, 16] - SDPA_INT8_Q_SEQ_LEN = [18, 89] - SDPA_INT8_KV_SEQ_LEN = [100, 253] - SDPA_INT8_HEAD_DIM = [32, 64] - SDPA_INT8_MASK_DTYPE = [None, torch.float32, torch.bfloat16] - @pytest.mark.skipif( not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" ) - @parametrize("batch_size", SDPA_INT8_BATCH_SIZE) - @parametrize("n_head", SDPA_INT8_NUM_HEADS) - @parametrize("q_seq_len", SDPA_INT8_Q_SEQ_LEN) - @parametrize("kv_seq_len", SDPA_INT8_KV_SEQ_LEN) - @parametrize("head_dim", SDPA_INT8_HEAD_DIM) - @parametrize("mask_dtype", SDPA_INT8_MASK_DTYPE) + @pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet") + @parametrize("batch_size", [56, 120]) + @parametrize("n_head", [2, 16]) + @parametrize("q_seq_len", [18, 89]) + @parametrize("kv_seq_len", [100, 253]) + @parametrize("head_dim", [32, 64]) + @parametrize("mask_dtype", [None, torch.float32, torch.bfloat16]) def test_scaled_dot_product_int8_op( self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype ): From 3d192a7f47c1e9621d3abd32ce9c4e538f0bc919 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 4 Mar 2025 21:53:44 -0500 Subject: [PATCH 21/36] update --- setup.py | 6 +++++- test/prototype/inductor/test_int8_sdpa_fusion.py | 2 +- test/test_ops.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 6b01308194..b726bf681f 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,11 @@ def read_version(file_path="version.txt"): and platform.system() == "Darwin" ) -use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and build_torchao_experimental +use_cpp_avx512 = ( + os.getenv("USE_AVX512", "1") == "1" + and use_cpp == "1" + and platform.system() == "Linux" +) from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index a701dde1a9..5c005b1c79 100644 --- 
a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -9,10 +9,10 @@ from torch._inductor.utils import run_and_get_code from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm from torch.testing._internal.inductor_utils import HAS_CPU +from torch.utils.cpp_extension import IS_WINDOWS from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 -from torch.utils.cpp_extension import IS_WINDOWS class SelfAttnLikeModule(torch.nn.Module): diff --git a/test/test_ops.py b/test/test_ops.py index e0e2e7106b..7cd5c2d00a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -15,6 +15,7 @@ parametrize, ) from torch.testing._internal.optests import opcheck +from torch.utils.cpp_extension import IS_WINDOWS import torchao from torchao.dtypes.floatx import from_scaled_tc_floatx @@ -29,7 +30,6 @@ TORCH_VERSION_AT_LEAST_2_7, compute_max_diff, ) -from torch.utils.cpp_extension import IS_WINDOWS IS_CUDA = torch.cuda.is_available() and torch.version.cuda IS_ROCM = torch.cuda.is_available() and torch.version.hip From 9acf6f0e9dfa50e27514092da9cbac0582ea6097 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 11 Mar 2025 02:44:04 -0400 Subject: [PATCH 22/36] fix issue --- torchao/csrc/cpu/int8_sdpa.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index fbffdc871a..16cda3db7d 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -23,6 +23,7 @@ #include #include #include +#include namespace torchao { @@ -45,7 +46,6 @@ inline double calculate_scale( } #ifdef CPU_CAPABILITY_AVX512 -#include template inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { From dc78455742eb7320b257950e991879b5643cd6c3 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 11 Mar 2025 02:45:00 -0400 Subject: [PATCH 23/36] fix issue --- setup.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/setup.py b/setup.py index b726bf681f..d9cd10dc52 100644 --- a/setup.py +++ b/setup.py @@ -55,11 +55,7 @@ def read_version(file_path="version.txt"): and platform.system() == "Darwin" ) -use_cpp_avx512 = ( - os.getenv("USE_AVX512", "1") == "1" - and use_cpp == "1" - and platform.system() == "Linux" -) +use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and platform.system() == "Linux" from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 From 97485c5a701bb3f16b5b33a421498392c7145696 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Wed, 26 Mar 2025 07:56:34 -0400 Subject: [PATCH 24/36] update --- setup.py | 8 +- torchao/csrc/cpu/int8_sdpa.cpp | 231 +++++++++++++++++---------------- 2 files changed, 126 insertions(+), 113 deletions(-) diff --git a/setup.py b/setup.py index d9cd10dc52..b6a2d77f19 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,12 @@ def read_version(file_path="version.txt"): and platform.system() == "Darwin" ) -use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and platform.system() == "Linux" +import torch +use_cpp_avx512 = ( + os.getenv("USE_AVX512", "1") == "1" + and torch._C._cpu._is_avx512_supported() + and platform.system() == "Linux" +) from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 @@ -72,7 +77,6 @@ def use_debug_mode(): return os.getenv("DEBUG", "0") == "1" -import torch from torch.utils.cpp_extension import ( CUDA_HOME, IS_WINDOWS, diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index 
16cda3db7d..5fdfe56f4a 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -96,16 +96,16 @@ inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at template inline typename std::enable_if_t, void> -_store(scalar_t* dst, at::vec::Vectorized src) { +_store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { auto res = at::vec::convert_from_float(src, src); - res.store(dst, at::vec::Vectorized::size()); + res.store(dst, size); } template inline typename std::enable_if_t || std::is_same_v, void> -_store(scalar_t* dst, at::vec::Vectorized src) { +_store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { auto res = at::vec::convert(src); - res.store(dst, at::vec::Vectorized::size()); + res.store(dst, size); } /* @@ -139,7 +139,8 @@ inline void _dequant_mask_max_fusion_kernel( const mask_t* mask_data_ptr = mask_ptr + row * ldm; float tmp_max = -std::numeric_limits::infinity(); auto vec_tmp_max = at::vec::Vectorized(tmp_max); - for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + long col = 0; + for (; col < vec_size * (N / vec_size); col += vec_size) { auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); auto tmp1 = tmp0 - vec_sum_b; @@ -153,22 +154,21 @@ inline void _dequant_mask_max_fusion_kernel( vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp8); _store(tmp_out + col, tmp8); } - tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); - for (long col = vec_size * (N / vec_size); col < N; col++) { - auto sum_b = sum_b_ptr[col]; - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sum_b; - auto tmp2 = tmp1 - sum_a; - auto tmp3 = tmp2 + beta; - auto tmp4 = (float) tmp3; - auto tmp5 = tmp4 * alpha; - auto tmp6 = mask_data_ptr[col]; - auto tmp7 = (float) tmp6; + if (col < N) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, N - col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col, N - col); + auto tmp7 = at::vec::convert(tmp6); auto tmp8 = tmp5 + tmp7; - tmp_max = std::max(tmp_max, tmp8); - tmp_out[col] = tmp8; + _store(tmp_out + col, tmp8, N - col); + vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp8), N - col); } - sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); } } @@ -198,7 +198,8 @@ inline void _dequant_max_fusion_kernel( float* tmp_out = out + row * ldo; float tmp_max = -std::numeric_limits::infinity(); auto vec_tmp_max = at::vec::Vectorized(tmp_max); - for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + long col = 0; + for (; col < vec_size * (N / vec_size); col += vec_size) { auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); auto tmp1 = tmp0 - vec_sum_b; @@ -209,19 +210,18 @@ inline void _dequant_max_fusion_kernel( vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp5); _store(tmp_out + col, tmp5); } - tmp_max = std::max(tmp_max, vec_tmp_max.reduce_max()); - for (long col = vec_size * (N / vec_size); col < N; col++) { - auto sum_b = sum_b_ptr[col]; - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sum_b; - auto tmp2 = tmp1 - sum_a; - auto tmp3 = tmp2 + beta; 
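    // [annotation, not part of the patch] The hunks in this file replace scalar tail
    // loops like the one removed here with a single masked vector pass: loadu(ptr, n)
    // reads only the last n valid elements, and Vectorized::set(a, b, n) keeps the
    // first n lanes of b and the remaining lanes of a, so the running reduction never
    // picks up lanes past the end of the row. A minimal, self-contained sketch of the
    // idiom, using only at::vec calls that already appear in this file:
    auto reduce_max_with_masked_tail = [](const float* in, long N) -> float {
      const long vec_size = at::vec::Vectorized<float>::size();
      auto vec_max = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
      long col = 0;
      for (; col < vec_size * (N / vec_size); col += vec_size) {
        // full-width lanes
        vec_max = at::vec::clamp_min(vec_max, at::vec::Vectorized<float>::loadu(in + col));
      }
      if (col < N) {
        // masked tail: load only the remaining N - col elements
        auto tail = at::vec::Vectorized<float>::loadu(in + col, N - col);
        // fold the tail into the accumulator, keeping just the valid lanes
        vec_max = at::vec::Vectorized<float>::set(
            vec_max, at::vec::clamp_min(vec_max, tail), N - col);
      }
      return vec_max.reduce_max();
    };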
- auto tmp4 = (float) tmp3; - auto tmp5 = tmp4 * alpha; - tmp_max = std::max(tmp_max, tmp5); - tmp_out[col] = tmp5; + if (col < N) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, N - col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + _store(tmp_out + col, tmp5, N - col); + vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp5), N - col); } - sfm_max_ptr[row] = std::max(sfm_max_ptr[row], tmp_max); + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); } } @@ -270,22 +270,22 @@ inline void _sub_exp_sum_div_quant_sum_fusion_kernel( float tmp_sum = 0; auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); float* tmp_out = local + n; - for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + long col = 0; + for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); vec_tmp_sum += tmp2; _store(tmp_out + col, tmp2); } - tmp_sum += vec_tmp_sum.reduce_add(); - for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sfm_max; - auto tmp2 = exp(tmp1); - tmp_sum += tmp2; - tmp_out[col] = tmp2; + if (col < kvBlockSize) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + _store(tmp_out + col, tmp2, kvBlockSize - col); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); } - sfm_sum_ptr[row] += tmp_sum; + sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); } // div sum, sum for attention auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; @@ -298,7 +298,8 @@ inline void _sub_exp_sum_div_quant_sum_fusion_kernel( auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); float* tmp_in = local + n; scalar_t* tmp_out = qk_reduced_block_data + l * ldo; - for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + long col = 0; + for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); auto tmp1 = tmp0 * vec_sum_scale; auto tmp2 = tmp1.round(); @@ -308,24 +309,24 @@ inline void _sub_exp_sum_div_quant_sum_fusion_kernel( auto tmp6 = at::vec::convert(tmp4); vec_tmp_sum += tmp6; } - tmp_sum += vec_tmp_sum.reduce_add(); - for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 * sum_scale; - auto tmp2 = std::nearbyint(tmp1); - auto tmp3 = tmp2 + beta1_float; - auto tmp4 = std::clamp(tmp3, min_val, max_val); - tmp_out[col] = tmp4; - auto tmp6 = (int32_t) tmp4; - tmp_sum += tmp6; + if (col < kvBlockSize) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp4, kvBlockSize - col); + auto tmp6 = at::vec::convert(tmp4); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp6, kvBlockSize - col); } - sum_a_ptr[row] += tmp_sum * beta2; + sum_a_ptr[row] += vec_tmp_sum.reduce_add() * beta2; // set zero - for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += 
vec_size) { + col = kvBlockSize; + for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { _store(tmp_out + col, vec_zero); } - for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { - tmp_out[col] = zero; + if (col < av_gemm_K) { + _store(tmp_out + col, vec_zero, av_gemm_K - col); } } } @@ -369,22 +370,22 @@ inline void _sub_exp_sum_div_quant_fusion_kernel( float tmp_sum = 0; auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); float* tmp_out = local + n; - for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + long col = 0; + for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); vec_tmp_sum += tmp2; _store(tmp_out + col, tmp2); } - tmp_sum += vec_tmp_sum.reduce_add(); - for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sfm_max; - auto tmp2 = exp(tmp1); - tmp_sum += tmp2; - tmp_out[col] = tmp2; + if (col < kvBlockSize) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); + _store(tmp_out + col, tmp2, kvBlockSize - col); } - sfm_sum_ptr[row] += tmp_sum; + sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); } // div sum, sum for attention auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; @@ -395,7 +396,8 @@ inline void _sub_exp_sum_div_quant_fusion_kernel( int64_t kvBlockSize = std::min(N_step, kvSize - n); float* tmp_in = local + n; scalar_t* tmp_out = qk_reduced_block_data + l * ldo; - for (long col = 0; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + long col = 0; + for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); auto tmp1 = tmp0 * vec_sum_scale; auto tmp2 = tmp1.round(); @@ -403,20 +405,21 @@ inline void _sub_exp_sum_div_quant_fusion_kernel( auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); _store(tmp_out + col, tmp4); } - for (long col = vec_size * (kvBlockSize / vec_size); col < kvBlockSize; col++) { - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 * sum_scale; - auto tmp2 = std::nearbyint(tmp1); - auto tmp3 = tmp2 + beta1_float; - auto tmp4 = std::clamp(tmp3, min_val, max_val); - tmp_out[col] = tmp4; + if (col < kvBlockSize) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp4, kvBlockSize - col); } // set zero - for (long col = kvBlockSize; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { + col = kvBlockSize; + for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { _store(tmp_out + col, vec_zero); } - for (long col = vec_size * (av_gemm_K / vec_size); col < av_gemm_K; col++) { - tmp_out[col] = zero; + if (col < av_gemm_K) { + _store(tmp_out + col, vec_zero, av_gemm_K - col); } } } @@ -453,7 +456,8 @@ inline void _dequant_quant_fusion_kernel( auto vec_sum_a = at::vec::Vectorized(sum_a); const int32_t* tmp_in = in + row * ldi; scalar_t* tmp_out = out + row * ldo; - for (long col = 0; col < vec_size * (N / vec_size); col += vec_size) { + long col = 0; + for (; col < vec_size * (N / vec_size); col += vec_size) { auto vec_sum_b = 
at::vec::Vectorized::loadu(sum_b_ptr + col); auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); auto tmp1 = tmp0 - vec_sum_b; @@ -466,18 +470,18 @@ inline void _dequant_quant_fusion_kernel( auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); _store(tmp_out + col, tmp8); } - for (long col = vec_size * (N / vec_size); col < N; col++) { - auto sum_b = sum_b_ptr[col]; - auto tmp0 = tmp_in[col]; - auto tmp1 = tmp0 - sum_b; - auto tmp2 = tmp1 - sum_a; - auto tmp3 = tmp2 + beta1; - auto tmp4 = (float) tmp3; - auto tmp5 = tmp4 * alpha; - auto tmp6 = std::nearbyint(tmp5); - auto tmp7 = tmp6 + beta2_float; - auto tmp8 = std::clamp(tmp7, min_val, max_val); - tmp_out[col] = tmp8; + if (col < N) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, N - col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp8, N - col); } } } @@ -491,16 +495,18 @@ inline void _int_sum_b_contiguous_kernel_helper( const int32_t vec_size = at::vec::Vectorized::size(); int32_t tmp_sum = 0; auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - for (long i = 0; i < vec_size * (N / vec_size); i += vec_size) { + long i = 0; + for (; i < vec_size * (N / vec_size); i += vec_size) { auto tmp0 = at::vec::Vectorized::loadu(in + i); auto tmp1 = at::vec::convert(tmp0); vec_tmp_sum = vec_tmp_sum + tmp1; } - tmp_sum += vec_tmp_sum.reduce_add(); - for (long i = vec_size * (N / vec_size); i < N; i++) { - tmp_sum += static_cast(in[i]); + if (i < N) { + auto tmp0 = at::vec::Vectorized::loadu(in + i, N - i); + auto tmp1 = at::vec::convert(tmp0); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp1, N - i); } - out[0] = tmp_sum * scale; + out[0] = vec_tmp_sum.reduce_add() * scale; } template @@ -529,40 +535,43 @@ inline void _int_sum_a_contiguous_kernel( // initialization with 0 int32_t zero = 0; auto vec_zero = at::vec::Vectorized(zero); - for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + long i = 0; + for (; i < vec_size * (M / vec_size); i += vec_size) { _store(out + i, vec_zero); } - for (long i = vec_size * (M / vec_size); i < M; i++) { - out[i] = zero; + if (i < M) { + _store(out + i, vec_zero, M - i); } // sum for (long j = 0; j < N; j++) { const scalar_t* tmp_in = in + j * ld; - for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(tmp_in + i); - auto tmp1 = at::vec::Vectorized::loadu(out + i); + long k = 0; + for (; k < vec_size * (M / vec_size); k += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + k); + auto tmp1 = at::vec::Vectorized::loadu(out + k); auto tmp2 = at::vec::convert(tmp0); auto tmp3 = tmp1 + tmp2; - _store(out + i, tmp3); + _store(out + k, tmp3); } - for (long i = vec_size * (M / vec_size); i < M; i++) { - auto tmp0 = tmp_in[i]; - auto tmp1 = out[i]; - auto tmp2 = static_cast(tmp0); + if (k < M) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + k, M - k); + auto tmp1 = at::vec::Vectorized::loadu(out + k, M - k); + auto tmp2 = at::vec::convert(tmp0); auto tmp3 = tmp1 + tmp2; - out[i] = tmp3; + _store(out + k, tmp3, M - k); } } // scale - for (long i = 0; i < vec_size * (M / vec_size); i += vec_size) { + i = 0; + for (; i < vec_size * (M / vec_size); i += vec_size) { auto 
tmp0 = at::vec::Vectorized::loadu(out + i); auto tmp1 = tmp0 * vec_scale; _store(out + i, tmp1); } - for (long i = vec_size * (M / vec_size); i < M; i++) { - auto tmp0 = out[i]; - auto tmp1 = tmp0 * scale; - out[i] = tmp1; + if (i < M) { + auto tmp0 = at::vec::Vectorized::loadu(out + i, M - i); + auto tmp1 = tmp0 * vec_scale; + _store(out + i, tmp1, M - i); } } From 0fb02f7bca159f6dd927f40a6328a2714d2159ba Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Wed, 9 Apr 2025 22:42:59 -0400 Subject: [PATCH 25/36] optm kernel --- torchao/csrc/cpu/int8_sdpa.cpp | 573 +++++++++++---------------------- 1 file changed, 194 insertions(+), 379 deletions(-) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index 5fdfe56f4a..c88e0ae8e8 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -49,17 +49,14 @@ inline double calculate_scale( template inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { - using Vec = at::vec::Vectorized; - Vec data_vec = Vec(val); + const int32_t vec_size = at::vec::Vectorized::size(); + auto data_vec = at::vec::Vectorized(val); int64_t d = 0; - for (; d < size - (size % Vec::size()); d += Vec::size()) { + for (; d < size - (size % vec_size); d += vec_size) { data_vec.store(data + d); } - #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) - # pragma unroll - #endif - for (; d < size; d++) { - data[d] = val; + if (d < size) { + data_vec.store(data + d, size - d); } } @@ -486,6 +483,56 @@ inline void _dequant_quant_fusion_kernel( } } +template +inline void _dequant_quant_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta2, // zp_c + const float& alpha, // scale_a*scale_b/scale_c + scalar_t* out) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + // auto vec_beta1 = at::vec::Vectorized(beta1); + auto vec_alpha = at::vec::Vectorized(alpha); + float beta2_float = (float) beta2; + auto vec_beta2 = at::vec::Vectorized(beta2_float); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + scalar_t* tmp_out = out + row * ldo; + long col = 0; + for (; col < vec_size * (N / vec_size); col += vec_size) { + auto tmp1 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp3 = tmp1 - vec_sum_a; + // auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp8); + } + if (col < N) { + auto tmp1 = at::vec::Vectorized::loadu(tmp_in + col, N - col); + auto tmp3 = tmp1 - vec_sum_a; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp8, N - col); + } + } +} + template inline void _int_sum_b_contiguous_kernel_helper( const scalar_t* in, @@ -770,27 +817,11 @@ sdpa_int8_kernel_one_loop_impl( int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; int64_t kvSplitSize = kv_split_size > kvSize ? 
kvSize : kv_split_size; int64_t qSlice = (qSize - 1) / qSplitSize + 1; - int64_t qTail = (qSize - 1) % qSplitSize + 1; int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; int64_t num_thread = at::get_num_threads(); int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; - // one of 16, 32, 48, 64 - auto select_tail_tail_block_size = [](int64_t size) -> int64_t { - if (size == 0) { - return 0; - } else if (size <= 16) { - return 16; - } else if (size <= 32) { - return 32; - } else if (size <= 48) { - return 48; - } else { - return 64; - } - }; - int64_t kv_tail_tail_block_size = select_tail_tail_block_size(kvTail % block_64); int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; @@ -809,10 +840,8 @@ sdpa_int8_kernel_one_loop_impl( : nullptr; scalar_t* out_data = output.data_ptr(); - // Create tpp kernels for Query @ Key - bool headSize_mul4 = headSize % 4 == 0; - // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 - int qk_gemm_K_padding = headSize_mul4 ? 0 : 4 - headSize % 4; + bool headSize_mul64 = headSize % 64 == 0; + int qk_gemm_K_padding = headSize_mul64 ? 0 : 64 - headSize % 64; int qk_gemm_K = headSize + qk_gemm_K_padding; int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; @@ -876,7 +905,7 @@ sdpa_int8_kernel_one_loop_impl( offset += qk_gemm_K * rndkvSize; scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); - uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; + uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * block_64]; for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable @@ -899,101 +928,51 @@ sdpa_int8_kernel_one_loop_impl( // pack for (int64_t n = 0; n < kvSize; n += kvSplitSize) { - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - std::min(int(kvSplitSize - b), block_64), - headSize, - kStrideN, - block_64); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_64, - qk_gemm_K, - block_64, - block_64 - ); - } - // Pack - at::native::cpublas::pack( - qk_gemm_K, // K - block_64, // N - block_64, // ld_in - block_64, // ld_out - u8_dt, // dt_in - u8_dt, // dt_out - B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); - } - } else { - // tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < rndkvTail) { - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + for (int64_t b = 0; b < kvBlockSize; b += block_64) { + bool istail = kvBlockSize - b < block_64; + int64_t trans_rows = istail ? 
kvBlockSize - b : block_64; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + trans_rows, + headSize, + kStrideN, + block_64); + if (!headSize_mul64 || istail) { + pad_remain_row_col( B_blocked_xform_u8, - std::min(kvTail - b, block_size), headSize, - kStrideN, - block_size); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_size, - qk_gemm_K, - block_size, - block_size - ); - } - // Pack - at::native::cpublas::pack( + trans_rows, qk_gemm_K, - block_size, - block_size, - block_size, - u8_dt, - u8_dt, - B_blocked_xform_u8, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - // split headSize to block_64, block_64, block_64 ... - // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] - for (int64_t b = 0; b < headSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - value_reorder_ptr + n * rndHeadSize + - av_gemm_K * b); + block_64 + ); } + at::native::cpublas::pack( + qk_gemm_K, // K + block_64, // N + block_64, // ld_in + block_64, // ld_out + u8_dt, // dt_in + u8_dt, // dt_out + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, // block_64, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); } } @@ -1030,74 +1009,17 @@ sdpa_int8_kernel_one_loop_impl( for (int64_t l = 0; l < rkvSlice; l++) { int64_t n = l * kvSplitSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - // Calculate sums for dequant compensation item - if (qBlockSize == qSplitSize) { - // q main - if (n + kvSplitSize < kvSize) { - // k main - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvSplitSize, //ldc, - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - } - } else { - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - at::native::cpublas::brgemm( - qSplitSize, block_size, qk_gemm_K, - qk_gemm_K, // lda - block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - } - } else { - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - at::native::cpublas::brgemm( - qTail, block_64, qk_gemm_K, - qk_gemm_K,// lda - block_64, //ldb - rndkvSplitSize, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - } - } else { - // k tail - auto block_size = kvTail >= block_64 ? 
block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - at::native::cpublas::brgemm( - qTail, block_size, qk_gemm_K, - qk_gemm_K, // lda - block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - key_reorder_ptr + n * qk_gemm_K + - b * qk_gemm_K, - qk_s32_data + b); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - } + for (int64_t b = 0; b < kvBlockSize; b += block_64) { + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc, + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); } // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 @@ -1307,34 +1229,17 @@ sdpa_int8_kernel_several_loops_impl( int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; int64_t qSlice = (qSize - 1) / qSplitSize + 1; - int64_t qTail = (qSize - 1) % qSplitSize + 1; int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; int64_t num_thread = at::get_num_threads(); int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; - // one of 16, 32, 48, 64 - auto select_tail_tail_block_size = [](int64_t size) -> int64_t { - if (size == 0) { - return 0; - } else if (size <= 16) { - return 16; - } else if (size <= 32) { - return 32; - } else if (size <= 48) { - return 48; - } else { - return 64; - } - }; - int64_t kv_tail_tail_block_size = select_tail_tail_block_size(kvTail % block_64); int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; - // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 int av_gemm_K = kvSplitSize + av_gemm_K_padding; // Data ptrs @@ -1347,9 +1252,8 @@ sdpa_int8_kernel_several_loops_impl( scalar_t* out_data = output.data_ptr(); // Create tpp kernels for Query @ Key - bool headSize_mul4 = headSize % 4 == 0; - // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 - int qk_gemm_K_padding = headSize_mul4 ? 0 : 4 - headSize % 4; + bool headSize_mul64 = headSize % 64 == 0; + int qk_gemm_K_padding = headSize_mul64 ? 
0 : 64 - headSize % 64; int qk_gemm_K = headSize + qk_gemm_K_padding; int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; @@ -1430,7 +1334,7 @@ sdpa_int8_kernel_several_loops_impl( int64_t i = 0, j = 0, l = 0, n = 0; at::native::data_index_init( begin, i, batchSize, j, num_head, l, kvSlice); - uint8_t* B_blocked_xform_u8 = new uint8_t[std::max(qk_gemm_K, av_gemm_K) * block_64]; + uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * block_64]; for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable n = l * kvSplitSize; @@ -1439,96 +1343,49 @@ sdpa_int8_kernel_several_loops_impl( auto v_reorder = value_reorder_ptr + i * num_head * kvSlice * v_reorder_strideL + j * kvSlice * v_reorder_strideL + n * rndHeadSize; - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < rndkvSplitSize; b += block_64) { - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + for (int64_t b = 0; b < kvBlockSize; b += block_64) { + bool istail = kvBlockSize - b < block_64; + int64_t trans_rows = istail ? kvBlockSize - b : block_64; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + trans_rows, + headSize, + kStrideN, + block_64); + if (!headSize_mul64 || istail) { + pad_remain_row_col( B_blocked_xform_u8, - std::min(int(kvSplitSize - b), block_64), headSize, - kStrideN, - block_64); - if (!headSize_mul4) { - pad_remain_row_col( + trans_rows, + qk_gemm_K, + block_64, + block_64 + ); + } + at::native::cpublas::pack( + qk_gemm_K, // K + block_64, // N + block_64, // ld_in + block_64, // ld_out + u8_dt, // dt_in + u8_dt, // dt_out B_blocked_xform_u8, - headSize, + k_reorder + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, block_64, - qk_gemm_K, + vStrideN, // block_64, block_64, - block_64 - ); - } - at::native::cpublas::pack( - qk_gemm_K, // K - block_64, // N - block_64, // ld_in - block_64, // ld_out - u8_dt, // dt_in - u8_dt, // dt_out - B_blocked_xform_u8, - k_reorder + b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [kvSplitSize, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - v_reorder + av_gemm_K * b); - } - } else { - // tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < rndkvTail) { - do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - std::min(kvTail - b, block_size), - headSize, - kStrideN, - block_size); - if (!headSize_mul4) { - pad_remain_row_col( - B_blocked_xform_u8, - headSize, - block_size, - qk_gemm_K, - block_size, - block_size - ); - } - // Pack - at::native::cpublas::pack( - qk_gemm_K, - block_size, - block_size, - block_size, - u8_dt, - u8_dt, - B_blocked_xform_u8, - k_reorder + b * qk_gemm_K); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - // split headSize to block_64, block_64, block_64 ... - // [kvTail, headSize] -> [av_gemm_K_tail, block_64 ...] 
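    // [annotation, not part of the patch] The packing rewrite in this hunk drops the
    // old 16/32/48/64 tail-block ladder removed here and always packs the transposed
    // K tiles into block_64-wide panels; partial tiles (kvBlockSize - b < 64, or a
    // headSize that is not a multiple of 64) are padded up to a full qk_gemm_K x
    // block_64 tile by pad_remain_row_col before at::native::cpublas::pack, so every
    // brgemm later in this function can use the same M x N x K =
    // qSplitSize x block_64 x qk_gemm_K shape. Worked example for the head_dim = 32
    // case exercised by the tests:
    //   qk_gemm_K   = 32 + (64 - 32 % 64)     = 64   // K padded up to a multiple of 64
    //   rndHeadSize = (32 + 64 - 1) / 64 * 64 = 64   // output columns rounded to block_64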
- for (int64_t b = 0; b < headSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - v_reorder + av_gemm_K * b); - } + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + v_reorder + av_gemm_K * b); } // Move to the next query at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); @@ -1582,98 +1439,42 @@ sdpa_int8_kernel_several_loops_impl( a_sum_ptr, static_cast(0), qSplitSize); fill_stub( sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); - int64_t num_keys = - is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; copy_value_with_pad( q_data + i * qStrideB + j * qStrideH + m * qStrideM, query_t_padding_ptr, qBlockSize, headSize, - qBlockSize, + qSplitSize, //qSplitSize, qk_gemm_K, qStrideM); if (k_zp != 0) { - _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, - q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); + _int_sum_b_contiguous_kernel(query_t_padding_ptr, + q_sum_ptr, qBlockSize, headSize, qk_gemm_K, k_zp); } else { fill_stub( q_sum_ptr, static_cast(0), qSplitSize); } - const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + const int64_t rkvSlice = (kvSize - 1) / kvSplitSize + 1; for (int64_t l = 0; l < rkvSlice; l++) { int64_t n = l * kvSplitSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + j * qk_gemm_K * rndkvSize + n * qk_gemm_K; // Calculate sums for dequant compensation item - if (qBlockSize == qSplitSize) { - // q main - if (n + kvSplitSize < kvSize) { - // k main - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb - rndkvSplitSize, //ldc, - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } - } else { - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - at::native::cpublas::brgemm( - qSplitSize, block_size, qk_gemm_K, - qk_gemm_K, // lda - block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - b += block_size; - block_size = (kvTail - b) >= block_64 ? block_64 : kv_tail_tail_block_size; - } - } - } else { - if (n + kvSplitSize < kvSize) { - for (int64_t b = 0; b < kvSplitSize; b += block_64) { - at::native::cpublas::brgemm( - qTail, block_64, qk_gemm_K, - qk_gemm_K,// lda - block_64, //ldb - rndkvSplitSize, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } - } else { - // k tail - auto block_size = kvTail >= block_64 ? block_64 : kv_tail_tail_block_size; - int64_t b = 0; - while (b < kvTail) { - at::native::cpublas::brgemm( - qTail, block_size, qk_gemm_K, - qk_gemm_K, // lda - block_size, //ldb - rndkvTail, //ldc - false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - b += block_size; - block_size = (kvTail - b) >= block_64 ? 
block_64 : kv_tail_tail_block_size; - } - } + for (int64_t b = 0; b < kvBlockSize; b += block_64) { + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc, + false, + query_t_padding_ptr, + k_reorder + b * qk_gemm_K, + qk_s32_data + b); } // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 - int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; if (has_attn_mask) { mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); @@ -1684,7 +1485,7 @@ sdpa_int8_kernel_several_loops_impl( k_sum_ptr + n, //sum_b_ptr qBlockSize, //M kvBlockSize, //N - rndkvBlockSize, //ldi + rndkvSplitSize, //ldi mStrideM, //ldm rndkvSplitSize,//kvBlockSize, //ldo q_zp * k_zp * headSize, //zp_a*zp_b*k=beta @@ -1699,7 +1500,7 @@ sdpa_int8_kernel_several_loops_impl( k_sum_ptr + n, //sum_b_ptr qBlockSize, //M kvBlockSize, //N - rndkvBlockSize, //ldi + rndkvSplitSize, //ldi rndkvSplitSize,//kvBlockSize, //ldo q_zp * k_zp * headSize, //zp_a*zp_b*k=beta q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha @@ -1772,19 +1573,33 @@ sdpa_int8_kernel_several_loops_impl( // After the last gemm, // do dequant compensation, quant and convert from s32 to int8 - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, //sum_a_ptr - v_sum_ptr, //sum_b_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); + if (a_zp == 0) { + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } else { + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } // Move to the next query at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); } @@ -1843,7 +1658,7 @@ void sdpa_int8_fused_kernel( int64_t num_thread = at::get_num_threads(); int64_t attn_size = q_split_size * kv_seq_len * sizeof(int32_t) * num_thread; bool use_one_parallel_loop = (batchSize * num_head > num_thread) && - (attn_size > l2_cache_size); + (attn_size > 1.5 * l2_cache_size); if (use_one_parallel_loop) { if (!attn_mask.defined()) { if (q_split_size == 256) { From 3856f4957023b42c32d98b815b2b27fc72ce9739 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 15 Apr 2025 02:01:02 +0000 Subject: [PATCH 26/36] rebase and update --- .../inductor/test_int8_sdpa_fusion.py | 5 +- torchao/csrc/cpu/int8_sdpa.cpp | 184 +++++++++--------- torchao/ops.py | 7 +- torchao/quantization/__init__.py | 1 - 4 files changed, 103 insertions(+), 94 deletions(-) diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index 5c005b1c79..a299222324 100644 --- 
a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -169,8 +169,9 @@ def _test_sdpa_int8_rewriter(self): inputs = ( torch.randn( (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype - ), - torch.randn((bs, 1, 1, seqlen), device=self.device) + ) + * 10, + torch.randn((bs, 1, 1, seqlen), device=self.device) * 10 if has_mask else None, ) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index c88e0ae8e8..0b4117c77d 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -29,16 +29,6 @@ namespace torchao { namespace { -template -struct is_reduced_floating_point: - std::integral_constant || - std::is_same_v> { -}; - -template -constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; - inline double calculate_scale( const at::Tensor& query, double scale) { @@ -91,13 +81,6 @@ inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at src.store(dst, size); } -template -inline typename std::enable_if_t, void> -_store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { - auto res = at::vec::convert_from_float(src, src); - res.store(dst, size); -} - template inline typename std::enable_if_t || std::is_same_v, void> _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { @@ -329,6 +312,10 @@ inline void _sub_exp_sum_div_quant_sum_fusion_kernel( } } +/* +1. Softmax: sub max, exp, sum reduce, div sum +2. quant +*/ template inline void _sub_exp_sum_div_quant_fusion_kernel( const float* in, @@ -483,6 +470,10 @@ inline void _dequant_quant_fusion_kernel( } } +/* +1. dequant +2. quant +*/ template inline void _dequant_quant_fusion_kernel( const int32_t* in, @@ -556,6 +547,7 @@ inline void _int_sum_b_contiguous_kernel_helper( out[0] = vec_tmp_sum.reduce_add() * scale; } +// reduce along dim b for shape [a, b], with sum shape [a] template inline void _int_sum_b_contiguous_kernel( const scalar_t* in, @@ -569,6 +561,7 @@ inline void _int_sum_b_contiguous_kernel( } } +// reduce along dim a for shape [a, b], with sum shape [b] template inline void _int_sum_a_contiguous_kernel( const scalar_t* in, @@ -622,6 +615,7 @@ inline void _int_sum_a_contiguous_kernel( } } +// do the transpose: [in_rows, in_cols] -> [in_cols, in_rows] template inline void do_transpose( scalar_t* src, @@ -637,6 +631,7 @@ inline void do_transpose( } } +// padding with pad_val: [rows, cols] -> [prows, pcols] template inline void pad_remain_row_col( scalar_t* value_ptr, @@ -675,6 +670,7 @@ inline void pad_remain_row_col( } } +// copy value_ptr to dst_ptr with padding: [rows, cols] -> [prows, pcols] template inline void copy_value_with_pad( scalar_t* value_ptr, @@ -739,7 +735,7 @@ sdpa_int8_kernel_one_loop_impl( const at::Tensor& v, double dropout_p, bool is_causal, - at::Tensor& attention_mask, + std::optional attention_mask, double scale, int32_t q_zp, float q_scale, @@ -779,9 +775,9 @@ sdpa_int8_kernel_one_loop_impl( int64_t num_head = query.size(2); int64_t headSize = query.size(3); - bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); + bool has_attn_mask = attention_mask.has_value() && attention_mask.value().numel(); if (has_attn_mask) { - reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); + reshape_attn_mask_to_4d(attention_mask.value(), batchSize, num_head, qSize, kvSize); } // Strides @@ -798,20 +794,20 @@ sdpa_int8_kernel_one_loop_impl( int64_t oStrideM = output.stride(1); 
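  // [annotation, not part of the patch] With the std::optional<at::Tensor> mask, the
  // mStride* selection just below maps every broadcast dimension (size 1) to a stride
  // of 0. For the [batchSize, 1, 1, kvSize] mask shape used by the tests this gives
  // mStrideH == mStrideM == 0, so the kernel's mask offset
  //   mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n)
  // collapses to mask_data + i * mStrideB + n: one mask row per batch element, reused
  // for every head and every query row. has_attn_mask is now true only when a mask is
  // both passed in and non-empty.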
int64_t oStrideH = output.stride(2); int64_t mStrideB = - (attention_mask.defined() && attention_mask.size(0) > 1) - ? attention_mask.stride(0) + (has_attn_mask && attention_mask.value().size(0) > 1) + ? attention_mask.value().stride(0) : 0; int64_t mStrideH = - (attention_mask.defined() && attention_mask.size(1) > 1) - ? attention_mask.stride(1) + (has_attn_mask && attention_mask.value().size(1) > 1) + ? attention_mask.value().stride(1) : 0; int64_t mStrideM = - (attention_mask.defined() && attention_mask.size(2) > 1) - ? attention_mask.stride(2) + (has_attn_mask && attention_mask.value().size(2) > 1) + ? attention_mask.value().stride(2) : 0; int64_t mStrideN = - (attention_mask.defined() && attention_mask.size(3) > 1) - ? attention_mask.stride(3) + (has_attn_mask && attention_mask.value().size(3) > 1) + ? attention_mask.value().stride(3) : 0; int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; @@ -828,15 +824,14 @@ sdpa_int8_kernel_one_loop_impl( bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; - // // If K of Gemm is not even, use mkl gemm instead of tpp for BF16 int av_gemm_K = kvSplitSize + av_gemm_K_padding; // Data ptrs scalar_t* q_data = query.data_ptr(); scalar_t* k_data = key.data_ptr(); scalar_t* v_data = value.data_ptr(); - mask_t* mask_data = attention_mask.defined() - ? attention_mask.data_ptr() + mask_t* mask_data = attention_mask.has_value() + ? attention_mask.value().data_ptr() : nullptr; scalar_t* out_data = output.data_ptr(); @@ -926,7 +921,7 @@ sdpa_int8_kernel_one_loop_impl( headSize, kvSize, vStrideN, a_zp); } - // pack + // transpose and packing for (int64_t n = 0; n < kvSize; n += kvSplitSize) { int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); for (int64_t b = 0; b < kvBlockSize; b += block_64) { @@ -966,7 +961,7 @@ sdpa_int8_kernel_one_loop_impl( at::native::cpublas::pack( av_gemm_K, block_64, - vStrideN, // block_64, + vStrideN, block_64, u8_dt, u8_dt, @@ -997,7 +992,7 @@ sdpa_int8_kernel_one_loop_impl( qBlockSize, qk_gemm_K, qStrideM); - + // sum q if (k_zp != 0) { _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); @@ -1009,6 +1004,7 @@ sdpa_int8_kernel_one_loop_impl( for (int64_t l = 0; l < rkvSlice; l++) { int64_t n = l * kvSplitSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + // Calculate q @ k.T for (int64_t b = 0; b < kvBlockSize; b += block_64) { at::native::cpublas::brgemm( qSplitSize, block_64, qk_gemm_K, @@ -1023,7 +1019,6 @@ sdpa_int8_kernel_one_loop_impl( } // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 - int64_t rndkvBlockSize = kvBlockSize == kvSplitSize ? rndkvSplitSize : rndkvTail; accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; if (has_attn_mask) { mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); @@ -1034,9 +1029,9 @@ sdpa_int8_kernel_one_loop_impl( k_sum_ptr + n, //sum_b_ptr qBlockSize, //M kvBlockSize, //N - rndkvBlockSize, //ldi + rndkvSplitSize, //ldi mStrideM, //ldm - rndkvSplitSize,//kvBlockSize, //ldo + rndkvSplitSize, //ldo q_zp * k_zp * headSize, //zp_a*zp_b*k=beta q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha qk_block_data, //out @@ -1049,8 +1044,8 @@ sdpa_int8_kernel_one_loop_impl( k_sum_ptr + n, //sum_b_ptr qBlockSize, //M kvBlockSize, //N - rndkvBlockSize, //ldi - rndkvSplitSize,//kvBlockSize, //ldo + rndkvSplitSize, //ldi + rndkvSplitSize, //ldo q_zp * k_zp * headSize, //zp_a*zp_b*k=beta q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha qk_block_data, //out @@ -1119,19 +1114,33 @@ sdpa_int8_kernel_one_loop_impl( // After the last gemm, // do dequant compensation, quant and convert from s32 to int8 - _dequant_quant_fusion_kernel( - dst_s32_data, //in - a_sum_ptr, //sum_a_ptr - v_sum_ptr, //sum_b_ptr - qBlockSize, //M - headSize, //N - rndHeadSize, //ldi - oStrideM, //ldo - a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 - o_zp, //zp_c=beta2 - a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha - out_data + i * oStrideB + j * oStrideH + m * oStrideM //out - ); + if (a_zp == 0) { + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } else { + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } } // Move to the next query at::native::data_index_step(i, batchSize, j, num_head); @@ -1151,7 +1160,7 @@ sdpa_int8_kernel_several_loops_impl( const at::Tensor& v, double dropout_p, bool is_causal, - at::Tensor& attention_mask, + std::optional attention_mask, double scale, int32_t q_zp, float q_scale, @@ -1191,9 +1200,9 @@ sdpa_int8_kernel_several_loops_impl( int64_t num_head = query.size(2); int64_t headSize = query.size(3); - bool has_attn_mask = attention_mask.defined() && attention_mask.numel(); + bool has_attn_mask = attention_mask.has_value() && attention_mask.value().numel(); if (has_attn_mask) { - reshape_attn_mask_to_4d(attention_mask, batchSize, num_head, qSize, kvSize); + reshape_attn_mask_to_4d(attention_mask.value(), batchSize, num_head, qSize, kvSize); } // Strides @@ -1210,20 +1219,20 @@ sdpa_int8_kernel_several_loops_impl( int64_t oStrideM = output.stride(1); int64_t oStrideH = output.stride(2); int64_t mStrideB = - (attention_mask.defined() && attention_mask.size(0) > 1) - ? attention_mask.stride(0) + (has_attn_mask && attention_mask.value().size(0) > 1) + ? attention_mask.value().stride(0) : 0; int64_t mStrideH = - (attention_mask.defined() && attention_mask.size(1) > 1) - ? attention_mask.stride(1) + (has_attn_mask && attention_mask.value().size(1) > 1) + ? attention_mask.value().stride(1) : 0; int64_t mStrideM = - (attention_mask.defined() && attention_mask.size(2) > 1) - ? attention_mask.stride(2) + (has_attn_mask && attention_mask.value().size(2) > 1) + ? 
attention_mask.value().stride(2) : 0; int64_t mStrideN = - (attention_mask.defined() && attention_mask.size(3) > 1) - ? attention_mask.stride(3) + (has_attn_mask && attention_mask.value().size(3) > 1) + ? attention_mask.value().stride(3) : 0; int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; @@ -1246,12 +1255,11 @@ sdpa_int8_kernel_several_loops_impl( scalar_t* q_data = query.data_ptr(); scalar_t* k_data = key.data_ptr(); scalar_t* v_data = value.data_ptr(); - mask_t* mask_data = attention_mask.defined() - ? attention_mask.data_ptr() + mask_t* mask_data = attention_mask.has_value() + ? attention_mask.value().data_ptr() : nullptr; scalar_t* out_data = output.data_ptr(); - // Create tpp kernels for Query @ Key bool headSize_mul64 = headSize % 64 == 0; int qk_gemm_K_padding = headSize_mul64 ? 0 : 64 - headSize % 64; int qk_gemm_K = headSize + qk_gemm_K_padding; @@ -1328,7 +1336,7 @@ sdpa_int8_kernel_several_loops_impl( } }); - // packing + // transpose and packing at::parallel_for( 0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { int64_t i = 0, j = 0, l = 0, n = 0; @@ -1380,7 +1388,7 @@ sdpa_int8_kernel_several_loops_impl( at::native::cpublas::pack( av_gemm_K, block_64, - vStrideN, // block_64, + vStrideN, block_64, u8_dt, u8_dt, @@ -1444,10 +1452,10 @@ sdpa_int8_kernel_several_loops_impl( query_t_padding_ptr, qBlockSize, headSize, - qSplitSize, //qSplitSize, + qSplitSize, qk_gemm_K, qStrideM); - + // sum q if (k_zp != 0) { _int_sum_b_contiguous_kernel(query_t_padding_ptr, q_sum_ptr, qBlockSize, headSize, qk_gemm_K, k_zp); @@ -1461,7 +1469,7 @@ sdpa_int8_kernel_several_loops_impl( int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + j * qk_gemm_K * rndkvSize + n * qk_gemm_K; - // Calculate sums for dequant compensation item + // Calculate q @ k.T for (int64_t b = 0; b < kvBlockSize; b += block_64) { at::native::cpublas::brgemm( qSplitSize, block_64, qk_gemm_K, @@ -1487,7 +1495,7 @@ sdpa_int8_kernel_several_loops_impl( kvBlockSize, //N rndkvSplitSize, //ldi mStrideM, //ldm - rndkvSplitSize,//kvBlockSize, //ldo + rndkvSplitSize, //ldo q_zp * k_zp * headSize, //zp_a*zp_b*k=beta q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha qk_block_data, //out @@ -1501,7 +1509,7 @@ sdpa_int8_kernel_several_loops_impl( qBlockSize, //M kvBlockSize, //N rndkvSplitSize, //ldi - rndkvSplitSize,//kvBlockSize, //ldo + rndkvSplitSize, //ldo q_zp * k_zp * headSize, //zp_a*zp_b*k=beta q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha qk_block_data, //out @@ -1562,7 +1570,7 @@ sdpa_int8_kernel_several_loops_impl( at::native::cpublas::brgemm( qSplitSize, block_64, av_gemm_K, av_gemm_K, // lda - rndHeadSize, //block_64, //ldb + rndHeadSize, //ldb rndHeadSize, //ldc s != 0, qk_reduced_data + s * qk_reduce_strideL, @@ -1630,7 +1638,7 @@ void sdpa_int8_fused_kernel( const at::Tensor& value, double dropout_p, bool is_causal, - at::Tensor& attn_mask, + std::optional attn_mask, double scale, long q_zp, double q_scale, @@ -1660,7 +1668,7 @@ void sdpa_int8_fused_kernel( bool use_one_parallel_loop = (batchSize * num_head > num_thread) && (attn_size > 1.5 * l2_cache_size); if (use_one_parallel_loop) { - if (!attn_mask.defined()) { + if (!attn_mask.has_value()) { if (q_split_size == 256) { sdpa_int8_kernel_one_loop_impl( output, query, key, value, @@ -1690,7 +1698,7 @@ void sdpa_int8_fused_kernel( o_zp, o_scale); } } else { - AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), 
"sdpa_mask", [&]() { + AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { if (q_split_size == 256) { sdpa_int8_kernel_one_loop_impl( output, query, key, value, @@ -1722,7 +1730,7 @@ void sdpa_int8_fused_kernel( }); } } else { - if (!attn_mask.defined()) { + if (!attn_mask.has_value()) { if (q_split_size == 256) { sdpa_int8_kernel_several_loops_impl( output, query, key, value, @@ -1752,7 +1760,7 @@ void sdpa_int8_fused_kernel( o_zp, o_scale); } } else { - AT_DISPATCH_MASK_TYPES(attn_mask.scalar_type(), "sdpa_mask", [&]() { + AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { if (q_split_size == 256) { sdpa_int8_kernel_several_loops_impl( output, query, key, value, @@ -1793,7 +1801,7 @@ at::Tensor sdpa_int8_math_kernel( const at::Tensor& value, double dropout_p, bool is_causal, - at::Tensor& attn_mask, + std::optional attn_mask, double scale, int32_t q_zp, float q_scale, @@ -1811,8 +1819,8 @@ at::Tensor sdpa_int8_math_kernel( auto v = (value.to(at::kFloat) - v_zp) * v_scale; const auto scaling_factor = calculate_scale(q, scale); auto attn = at::matmul(q, k.transpose(-2, -1)) * scaling_factor; - if (attn_mask.defined() && attn_mask.numel()) { - attn = attn.add(attn_mask.to(at::kFloat)); + if (attn_mask.has_value() && attn_mask.value().numel()) { + attn = attn.add(attn_mask.value().to(at::kFloat)); } attn = at::softmax(attn, -1); // quant attn @@ -1834,7 +1842,7 @@ at::Tensor _scaled_dot_product_int8_cpu( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, - at::Tensor& attn_mask, + std::optional attn_mask, double dropout_p, bool is_causal, double scale, @@ -1861,12 +1869,12 @@ at::Tensor _scaled_dot_product_int8_cpu( "_scaled_dot_product_int8_cpu: Currently do not support dropout > 0"); TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), "_scaled_dot_product_int8_cpu: Q/K/V should have the same head size"); - TORCH_CHECK(!attn_mask.defined() || - attn_mask.scalar_type() == at::kFloat || - attn_mask.scalar_type() == at::kBFloat16, + TORCH_CHECK(!attn_mask.has_value() || + attn_mask.value().scalar_type() == at::kFloat || + attn_mask.value().scalar_type() == at::kBFloat16, "_scaled_dot_product_int8_cpu: Expected attention mask be float or bf16"); - TORCH_CHECK(!attn_mask.defined() || - (attn_mask.dim() == 2 || attn_mask.dim() == 4), + TORCH_CHECK(!attn_mask.has_value() || + (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4), "_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); #ifdef CPU_CAPABILITY_AVX512 diff --git a/torchao/ops.py b/torchao/ops.py index 38a341f435..2c1b2c5368 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -56,9 +56,10 @@ tags=[torch._C.Tag.needs_fixed_stride_order], ) lib.define( - "scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, int q_zp=0, float q_scale=1.0, int k_zp=0, float k_scale=1.0, int v_zp=0, float v_scale=1.0, int a_zp=0, float a_scale=1.0, int o_zp=0, float o_scale=1.0) -> Tensor" + "scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor? 
attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, int q_zp=0, float q_scale=1.0, int k_zp=0, float k_scale=1.0, int v_zp=0, float v_scale=1.0, int a_zp=0, float a_scale=1.0, int o_zp=0, float o_scale=1.0) -> Tensor" ) + def register_custom_op(name): def decorator(func): if TORCH_VERSION_AT_LEAST_2_4: @@ -165,7 +166,7 @@ def scaled_dot_product_int8( query: Tensor, key: Tensor, value: Tensor, - attn_mask: Tensor = None, + attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float = 0.0, @@ -231,7 +232,7 @@ def _( query: Tensor, key: Tensor, value: Tensor, - attn_mask: Tensor = None, + attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float = 0.0, diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 46740a81fd..67a70c5a35 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -192,5 +192,4 @@ "TensorCoreTiledLayout", "CutlassInt4PackedLayout", "Float8MMConfig", - "_sfdp_init_int8", ] From 246d545a4ec870a18a578df1caaf15db62759ac2 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 15 Apr 2025 02:16:59 +0000 Subject: [PATCH 27/36] fix issue --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b6a2d77f19..07a9754a4d 100644 --- a/setup.py +++ b/setup.py @@ -56,9 +56,9 @@ def read_version(file_path="version.txt"): ) import torch + use_cpp_avx512 = ( os.getenv("USE_AVX512", "1") == "1" - and torch._C._cpu._is_avx512_supported() and platform.system() == "Linux" ) From cc3c474219ea90e543d1d654551a94c0eedb2841 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 15 Apr 2025 02:22:16 +0000 Subject: [PATCH 28/36] fix issue --- setup.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 07a9754a4d..d9cd10dc52 100644 --- a/setup.py +++ b/setup.py @@ -55,12 +55,7 @@ def read_version(file_path="version.txt"): and platform.system() == "Darwin" ) -import torch - -use_cpp_avx512 = ( - os.getenv("USE_AVX512", "1") == "1" - and platform.system() == "Linux" -) +use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and platform.system() == "Linux" from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 @@ -77,6 +72,7 @@ def use_debug_mode(): return os.getenv("DEBUG", "0") == "1" +import torch from torch.utils.cpp_extension import ( CUDA_HOME, IS_WINDOWS, From 4da9b6e6dd322416c38b39bacb5e5f9d6d71896a Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 15 Apr 2025 02:35:26 +0000 Subject: [PATCH 29/36] fix issue --- setup.py | 18 +++++++++--------- torchao/csrc/cpu/int8_sdpa.cpp | 1 - 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index d9cd10dc52..27865deab3 100644 --- a/setup.py +++ b/setup.py @@ -95,9 +95,9 @@ def __init__(self): default=(self._is_arm64() and self._is_macos()), ) if self.build_cpu_aarch64: - assert self._is_arm64(), ( - "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine" - ) + assert ( + self._is_arm64() + ), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine" # TORCHAO_BUILD_KLEIDIAI is disabled by default for now because # 1) It increases the build time @@ -106,9 +106,9 @@ def __init__(self): "TORCHAO_BUILD_KLEIDIAI", default=False ) if self.build_kleidi_ai: - assert self.build_cpu_aarch64, ( - "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set" - ) + assert ( + self.build_cpu_aarch64 + ), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set" # 
TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default. self.build_experimental_mps = self._os_bool_var( @@ -117,9 +117,9 @@ def __init__(self): if self.build_experimental_mps: assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS" assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64" - assert torch.mps.is_available(), ( - "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available" - ) + assert ( + torch.mps.is_available() + ), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available" def _is_arm64(self) -> bool: return platform.machine().startswith("arm64") diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index 0b4117c77d..f60b3b19b1 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -23,7 +23,6 @@ #include #include #include -#include namespace torchao { From ea1c75edffc7cebba7a434a5ee5d2f411cbf2d4b Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 15 Apr 2025 03:26:27 +0000 Subject: [PATCH 30/36] fix issue --- setup.py | 17 +++++++++-------- .../prototype/inductor/test_int8_sdpa_fusion.py | 5 ++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 27865deab3..ceeab42661 100644 --- a/setup.py +++ b/setup.py @@ -296,14 +296,15 @@ def get_extensions(): ) if use_cpp_avx512 and TORCH_VERSION_AT_LEAST_2_7: - extra_compile_args["cxx"].extend( - [ - "-DCPU_CAPABILITY_AVX512", - "-march=native", - "-mfma", - "-fopenmp", - ] - ) + if torch._C._cpu._is_avx512_supported(): + extra_compile_args["cxx"].extend( + [ + "-DCPU_CAPABILITY_AVX512", + "-march=native", + "-mfma", + "-fopenmp", + ] + ) if debug_mode: extra_compile_args["cxx"].append("-g") diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index a299222324..5c005b1c79 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -169,9 +169,8 @@ def _test_sdpa_int8_rewriter(self): inputs = ( torch.randn( (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype - ) - * 10, - torch.randn((bs, 1, 1, seqlen), device=self.device) * 10 + ), + torch.randn((bs, 1, 1, seqlen), device=self.device) if has_mask else None, ) From e586b3340cf59460b111519ad61e2ae5c3a6a097 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Wed, 16 Apr 2025 02:13:27 +0000 Subject: [PATCH 31/36] set strict value for export_for_training --- test/prototype/inductor/test_int8_sdpa_fusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index 5c005b1c79..a7a4c00048 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -190,6 +190,7 @@ def _test_sdpa_int8_rewriter(self): export_model = export_for_training( mod, inputs, + strict=True, ).module() prepare_model = prepare_pt2e(export_model, quantizer) prepare_model(*inputs) From b8857ffc025ab65f88e873437593da1b530f5688 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Wed, 16 Apr 2025 02:18:57 +0000 Subject: [PATCH 32/36] modify name in setup --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index ceeab42661..55ec3644d2 100644 --- a/setup.py +++ b/setup.py @@ -361,10 +361,10 @@ def get_extensions(): sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) if IS_WINDOWS: # Remove csrc/cpu/*.cpp on Windows due to the link issue: 
unresolved external symbol PyInit__C - cpp_sources = list( + excluded_sources = list( glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True) ) - sources = [s for s in sources if s not in cpp_sources] + sources = [s for s in sources if s not in excluded_sources] # Collect CUDA source files extensions_cuda_dir = os.path.join(extensions_dir, "cuda") From 53a6cb6235a99bad05c3d1677605e1448ea86659 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Thu, 17 Apr 2025 03:12:21 +0000 Subject: [PATCH 33/36] refactor code according to comments --- torchao/csrc/cpu/int8_sdpa.cpp | 207 ++++++------ .../inductor/fx_passes/int8_sdpa_fusion.py | 305 +++--------------- 2 files changed, 146 insertions(+), 366 deletions(-) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index f60b3b19b1..1f0ea758fe 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -725,9 +725,12 @@ inline void copy_value_with_pad( } // UINT8 - one parallel loop with u8u8s32 GEMM -template +template = 0> inline typename std::enable_if_t, void> -sdpa_int8_kernel_one_loop_impl( +sdpa_int8_fused_kernel_impl( const at::Tensor& output, const at::Tensor& q, const at::Tensor& k, @@ -1150,9 +1153,12 @@ sdpa_int8_kernel_one_loop_impl( } // UINT8 - several parallel loops with u8u8s32 GEMM -template +template = 0> inline typename std::enable_if_t, void> -sdpa_int8_kernel_several_loops_impl( +sdpa_int8_fused_kernel_impl( const at::Tensor& output, const at::Tensor& q, const at::Tensor& k, @@ -1615,6 +1621,53 @@ sdpa_int8_kernel_several_loops_impl( at::native::cpublas::brgemm_release(); } + +template +inline typename std::enable_if_t, void> +sdpa_int8_fused_kernel_impl( + bool use_one_parallel_loop, + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + std::optional attn_mask, + double scale, + int32_t q_zp, + float q_scale, + int32_t k_zp, + float k_scale, + int32_t v_zp, + float v_scale, + int32_t a_zp, + float a_scale, + int32_t o_zp, + float o_scale) { + if (use_one_parallel_loop) { + sdpa_int8_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else { + sdpa_int8_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } +} + + #define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) 
\ AT_DISPATCH_SWITCH( \ TYPE, \ @@ -1661,77 +1714,50 @@ void sdpa_int8_fused_kernel( q_split_size = 64; } // Heuristic to decide whether to use one parallel loop or not + // true: one parallel loop for sum+packing+core + // false: three parallel loops for sum, packing, core uint32_t l2_cache_size = at::cpu::L2_cache_size(); int64_t num_thread = at::get_num_threads(); int64_t attn_size = q_split_size * kv_seq_len * sizeof(int32_t) * num_thread; bool use_one_parallel_loop = (batchSize * num_head > num_thread) && (attn_size > 1.5 * l2_cache_size); - if (use_one_parallel_loop) { - if (!attn_mask.has_value()) { - if (q_split_size == 256) { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else if (q_split_size == 64) { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } + if (!attn_mask.has_value()) { + if (q_split_size == 256) { + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); + } else if (q_split_size == 64) { + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); } else { - AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { - if (q_split_size == 256) { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else if (q_split_size == 64) { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else { - sdpa_int8_kernel_one_loop_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } - }); + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_zp, q_scale, + k_zp, k_scale, + v_zp, v_scale, + a_zp, a_scale, + o_zp, o_scale); } } else { - if (!attn_mask.has_value()) { + AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { if (q_split_size == 256) { - sdpa_int8_kernel_several_loops_impl( + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -1740,7 +1766,8 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else if (q_split_size == 64) { - sdpa_int8_kernel_several_loops_impl( + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -1749,7 +1776,8 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } else { - sdpa_int8_kernel_several_loops_impl( + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, 
output, query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -1758,38 +1786,7 @@ void sdpa_int8_fused_kernel( a_zp, a_scale, o_zp, o_scale); } - } else { - AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { - if (q_split_size == 256) { - sdpa_int8_kernel_several_loops_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else if (q_split_size == 64) { - sdpa_int8_kernel_several_loops_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } else { - sdpa_int8_kernel_several_loops_impl( - output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); - } - }); - } + }); } } #endif // CPU_CAPABILITY_AVX512 @@ -1888,6 +1885,7 @@ at::Tensor _scaled_dot_product_int8_cpu( o_zp, o_scale); return output.transpose(1, 2); } else { + #endif // CPU_CAPABILITY_AVX512 return sdpa_int8_math_kernel(query, key, value, dropout_p, is_causal, attn_mask, scale, q_zp, q_scale, @@ -1895,15 +1893,8 @@ at::Tensor _scaled_dot_product_int8_cpu( v_zp, v_scale, a_zp, a_scale, o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2); + #ifdef CPU_CAPABILITY_AVX512 } - #else - return sdpa_int8_math_kernel(query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2); #endif // CPU_CAPABILITY_AVX512 } diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py index bcf59ad35b..e805da1327 100644 --- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py +++ b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py @@ -1,4 +1,5 @@ import functools +import itertools import torch from torch._dynamo.utils import counters @@ -96,128 +97,47 @@ def int8_sdpa(match: Match, *args, **kwargs): return int8_sdpa -def _get_int8_sdpa_q_pattern(is_batch_size_1: bool, has_convert: bool): - int8_sdpa_q_basic_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - CallFunction( - aten.permute.default, - KeywordArg("query"), - Arg(), - ), - KeywordArg("q_scale"), - KeywordArg("q_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_q_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_q_basic_pattern, - Arg(), - ) - int8_sdpa_q_basic_pattern = CallFunction( - aten.expand.default, - int8_sdpa_q_basic_pattern, +def _get_int8_sdpa_qkv_pattern( + is_batch_size_1: bool, has_convert: bool, input_name: str +): + assert input_name in ["query", "key", "value"] + int8_sdpa_qkv_pattern_before_dequant = CallFunction( + aten.permute.default, + KeywordArg(input_name), Arg(), ) - if is_batch_size_1: - # pattern is different for bs=1 - return CallFunction( - aten.reshape.default, - int8_sdpa_q_basic_pattern, - Arg(), - ) - else: - return CallFunction( - aten.reshape.default, - CallFunction( - aten.clone.default, - int8_sdpa_q_basic_pattern, - memory_format=Arg(), - ), - Arg(), - ) - - -def _get_int8_sdpa_k_pattern(is_batch_size_1: bool, has_convert: bool): - int8_sdpa_k_basic_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - CallFunction( + if input_name == "key": + # do 
transpose + int8_sdpa_qkv_pattern_before_dequant = CallFunction( aten.permute.default, - CallFunction( - aten.permute.default, - KeywordArg("key"), - Arg(), - ), - Arg(), - ), - KeywordArg("k_scale"), - KeywordArg("k_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_k_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_k_basic_pattern, + int8_sdpa_qkv_pattern_before_dequant, Arg(), ) - int8_sdpa_k_basic_pattern = CallFunction( - aten.expand.default, - int8_sdpa_k_basic_pattern, - Arg(), - ) - if is_batch_size_1: - # pattern is different for bs=1 - return CallFunction( - aten.reshape.default, - int8_sdpa_k_basic_pattern, - Arg(), - ) - else: - return CallFunction( - aten.reshape.default, - CallFunction( - aten.clone.default, - int8_sdpa_k_basic_pattern, - memory_format=Arg(), - ), - Arg(), - ) - - -def _get_int8_sdpa_v_pattern(is_batch_size_1: bool, has_convert: bool): - int8_sdpa_v_basic_pattern = CallFunction( + int8_sdpa_qkv_basic_pattern = CallFunction( torch.ops.quantized_decomposed.dequantize_per_tensor.default, - CallFunction( - aten.permute.default, - KeywordArg("value"), - Arg(), - ), - KeywordArg("v_scale"), - KeywordArg("v_zp"), + int8_sdpa_qkv_pattern_before_dequant, + KeywordArg(input_name[0] + "_scale"), + KeywordArg(input_name[0] + "_zp"), Arg(), Arg(), Arg(), ) if has_convert: - int8_sdpa_v_basic_pattern = CallFunction( + int8_sdpa_qkv_basic_pattern = CallFunction( torch.ops.prims.convert_element_type.default, - int8_sdpa_v_basic_pattern, + int8_sdpa_qkv_basic_pattern, Arg(), ) - int8_sdpa_v_basic_pattern = CallFunction( + int8_sdpa_qkv_basic_pattern = CallFunction( aten.expand.default, - int8_sdpa_v_basic_pattern, + int8_sdpa_qkv_basic_pattern, Arg(), ) if is_batch_size_1: # pattern is different for bs=1 return CallFunction( aten.reshape.default, - int8_sdpa_v_basic_pattern, + int8_sdpa_qkv_basic_pattern, Arg(), ) else: @@ -225,7 +145,7 @@ def _get_int8_sdpa_v_pattern(is_batch_size_1: bool, has_convert: bool): aten.reshape.default, CallFunction( aten.clone.default, - int8_sdpa_v_basic_pattern, + int8_sdpa_qkv_basic_pattern, memory_format=Arg(), ), Arg(), @@ -235,8 +155,12 @@ def _get_int8_sdpa_v_pattern(is_batch_size_1: bool, has_convert: bool): def _get_int8_sdpa_score_pattern( has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool ): - int8_sdpa_q_pattern = _get_int8_sdpa_q_pattern(is_batch_size_1, has_convert) - int8_sdpa_k_pattern = _get_int8_sdpa_k_pattern(is_batch_size_1, has_convert) + int8_sdpa_q_pattern = _get_int8_sdpa_qkv_pattern( + is_batch_size_1, has_convert, "query" + ) + int8_sdpa_k_pattern = _get_int8_sdpa_qkv_pattern( + is_batch_size_1, has_convert, "key" + ) int8_sdpa_score_basic_pattern = CallFunction( aten.reshape.default, CallFunction( @@ -385,10 +309,17 @@ def _get_int8_sdpa_attn_pattern( ) +# Parameters to generate various patterns: +# has_mask: if SDPA has attention mask +# is_batch_size_1: if the batch size is 1 +# is_reduced_type: if autocast is enabled +# has_convert: convert type if dequant out dtype is assigned def _get_int8_sdpa_final_pattern( has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool ): - int8_sdpa_v_pattern = _get_int8_sdpa_v_pattern(is_batch_size_1, has_convert) + int8_sdpa_v_pattern = _get_int8_sdpa_qkv_pattern( + is_batch_size_1, has_convert, "value" + ) int8_sdpa_attn_pattern = _get_int8_sdpa_attn_pattern( has_mask, is_batch_size_1, is_reduced_type, has_convert ) @@ -419,160 +350,18 @@ def _get_int8_sdpa_final_pattern( ) 
-def _register_int8_sdpa_fp32_lowering(): - # dtype = float32, without attention mask, batch size > 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, - is_batch_size_1=False, - is_reduced_type=False, - has_convert=True, - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, - is_batch_size_1=False, - is_reduced_type=False, - has_convert=False, - ) - ) - - -def _register_int8_sdpa_fp32_mask_lowering(): - # dtype = float32, with attention mask, batch size > 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, - is_batch_size_1=False, - is_reduced_type=False, - has_convert=True, - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, - is_batch_size_1=False, - is_reduced_type=False, - has_convert=False, - ) - ) - - -def _register_int8_sdpa_fp32_bs1_lowering(): - # dtype = float32, without attention mask, batch size == 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, - is_batch_size_1=True, - is_reduced_type=False, - has_convert=True, - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, - is_batch_size_1=True, - is_reduced_type=False, - has_convert=False, - ) - ) - - -def _register_int8_sdpa_fp32_mask_bs1_lowering(): - # dtype = float32, with attention mask, batch size == 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=True, is_reduced_type=False, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, - is_batch_size_1=True, - is_reduced_type=False, - has_convert=False, - ) - ) - - -def _register_int8_sdpa_bf16_lowering(): - # dtype = bfloat16, without attention mask, batch size > 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, - is_batch_size_1=False, - is_reduced_type=True, - has_convert=True, - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, - is_batch_size_1=False, - is_reduced_type=True, - has_convert=False, - ) - ) - - -def _register_int8_sdpa_bf16_mask_lowering(): - # dtype = bfloat16, with attention mask, batch size > 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=False, is_reduced_type=True, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, - is_batch_size_1=False, - is_reduced_type=True, - has_convert=False, - ) - ) - - -def _register_int8_sdpa_bf16_bs1_lowering(): - # dtype = bfloat16, without attention mask, batch size == 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, is_batch_size_1=True, is_reduced_type=True, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=False, - is_batch_size_1=True, - is_reduced_type=True, - has_convert=False, - ) - ) - - -def _register_int8_sdpa_bf16_mask_bs1_lowering(): - # dtype = bfloat16, with attention mask, batch size == 1 - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=True, is_reduced_type=True, has_convert=True - ) - ) - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=True, is_batch_size_1=True, is_reduced_type=True, has_convert=False - ) - ) - - def _register_int8_sdpa_lowerings(): - _register_int8_sdpa_fp32_lowering() - _register_int8_sdpa_fp32_mask_lowering() - _register_int8_sdpa_fp32_bs1_lowering() - 
_register_int8_sdpa_fp32_mask_bs1_lowering() - _register_int8_sdpa_bf16_lowering() - _register_int8_sdpa_bf16_mask_lowering() - _register_int8_sdpa_bf16_bs1_lowering() - _register_int8_sdpa_bf16_mask_bs1_lowering() + for has_mask, is_batch_size_1, is_reduced_type, has_convert in itertools.product( + [True, False], [True, False], [True, False], [True, False] + ): + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=has_mask, + is_batch_size_1=is_batch_size_1, + is_reduced_type=is_reduced_type, + has_convert=has_convert, + ) + ) @functools.lru_cache(None) From dcf2a5525cb5613ce46a25e368aad58661bf439a Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Fri, 18 Apr 2025 05:24:26 +0000 Subject: [PATCH 34/36] change param orders --- .../inductor/test_int8_sdpa_fusion.py | 12 +- test/test_ops.py | 40 ++-- torchao/csrc/cpu/int8_sdpa.cpp | 180 +++++++++--------- torchao/ops.py | 42 ++-- .../inductor/fx_passes/int8_sdpa_fusion.py | 20 +- 5 files changed, 148 insertions(+), 146 deletions(-) diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py index a7a4c00048..9596e71a7a 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -11,6 +11,7 @@ from torch.testing._internal.inductor_utils import HAS_CPU from torch.utils.cpp_extension import IS_WINDOWS +import torchao from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 @@ -147,12 +148,13 @@ def _check_common( @pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet") @config.patch({"freezing": True}) def _test_sdpa_int8_rewriter(self): - import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq - from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e - from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( + from torch.export import export_for_training + + import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq + from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( X86InductorQuantizer, ) - from torch.export import export_for_training # pattern is different for bs=1 torch.manual_seed(1234) @@ -195,7 +197,7 @@ def _test_sdpa_int8_rewriter(self): prepare_model = prepare_pt2e(export_model, quantizer) prepare_model(*inputs) convert_model = convert_pt2e(prepare_model) - torch.ao.quantization.move_exported_model_to_eval(convert_model) + torchao.quantization.pt2e.move_exported_model_to_eval(convert_model) self._check_common( convert_model, args1=inputs, check_train=False, atol=1.0 ) diff --git a/test/test_ops.py b/test/test_ops.py index 7cd5c2d00a..5025b8a19b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -123,16 +123,16 @@ def _scaled_dot_product_int8_op_ref( attn_mask=None, dropout_p=0, is_causal=False, - q_zp=0, q_scale=1.0, - k_zp=0, + q_zp=0, k_scale=1.0, - v_zp=0, + k_zp=0, v_scale=1.0, - a_zp=0, + v_zp=0, a_scale=1.0, - o_zp=0, + a_zp=0, o_scale=1.0, + o_zp=0, ): q = (q.to(torch.float) - q_zp) * q_scale k = (k.to(torch.float) - k_zp) * k_scale @@ -168,16 +168,16 @@ def test_scaled_dot_product_int8_op( ): torch.manual_seed(1234) device = "cpu" - q_zp = int(127) q_scale = float(1.7907238006591797) - k_zp = int(125) + q_zp = int(127) k_scale = float(1.8039721250534058) - v_zp = int(127) + k_zp = int(125) v_scale = float(1.839004635810852) - a_zp = 
int(120) + v_zp = int(127) a_scale = float(0.003919653594493866) - o_zp = int(128) + a_zp = int(120) o_scale = float(1.8191684484481812) + o_zp = int(128) q_shape = [batch_size, q_seq_len, n_head, head_dim] kv_shape = [batch_size, kv_seq_len, n_head, head_dim] mask_shape = [batch_size, 1, 1, kv_seq_len] @@ -212,16 +212,16 @@ def test_scaled_dot_product_int8_op( attn_mask=attn_mask, dropout_p=0.0, is_causal=False, - q_zp=q_zp, q_scale=q_scale, - k_zp=k_zp, + q_zp=q_zp, k_scale=k_scale, - v_zp=v_zp, + k_zp=k_zp, v_scale=v_scale, - a_zp=a_zp, + v_zp=v_zp, a_scale=a_scale, - o_zp=o_zp, + a_zp=a_zp, o_scale=o_scale, + o_zp=o_zp, ) actual = torch.ops.torchao.scaled_dot_product_int8( q, @@ -230,16 +230,16 @@ def test_scaled_dot_product_int8_op( attn_mask=attn_mask_2, dropout_p=0.0, is_causal=False, - q_zp=q_zp, q_scale=q_scale, - k_zp=k_zp, + q_zp=q_zp, k_scale=k_scale, - v_zp=v_zp, + k_zp=k_zp, v_scale=v_scale, - a_zp=a_zp, + v_zp=v_zp, a_scale=a_scale, - o_zp=o_zp, + a_zp=a_zp, o_scale=o_scale, + o_zp=o_zp, ) self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index 1f0ea758fe..9a41a28d9c 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -739,16 +739,16 @@ sdpa_int8_fused_kernel_impl( bool is_causal, std::optional attention_mask, double scale, - int32_t q_zp, float q_scale, - int32_t k_zp, + int32_t q_zp, float k_scale, - int32_t v_zp, + int32_t k_zp, float v_scale, - int32_t a_zp, + int32_t v_zp, float a_scale, - int32_t o_zp, - float o_scale) { + int32_t a_zp, + float o_scale, + int32_t o_zp) { // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) @@ -1167,16 +1167,16 @@ sdpa_int8_fused_kernel_impl( bool is_causal, std::optional attention_mask, double scale, - int32_t q_zp, float q_scale, - int32_t k_zp, + int32_t q_zp, float k_scale, - int32_t v_zp, + int32_t k_zp, float v_scale, - int32_t a_zp, + int32_t v_zp, float a_scale, - int32_t o_zp, - float o_scale) { + int32_t a_zp, + float o_scale, + int32_t o_zp) { // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) @@ -1634,36 +1634,36 @@ sdpa_int8_fused_kernel_impl( bool is_causal, std::optional attn_mask, double scale, - int32_t q_zp, float q_scale, - int32_t k_zp, + int32_t q_zp, float k_scale, - int32_t v_zp, + int32_t k_zp, float v_scale, - int32_t a_zp, + int32_t v_zp, float a_scale, - int32_t o_zp, - float o_scale) { + int32_t a_zp, + float o_scale, + int32_t o_zp) { if (use_one_parallel_loop) { sdpa_int8_fused_kernel_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); } else { sdpa_int8_fused_kernel_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); } } @@ -1692,16 +1692,16 @@ void sdpa_int8_fused_kernel( bool is_causal, std::optional attn_mask, double scale, - long q_zp, - double q_scale, - long k_zp, - double k_scale, - long v_zp, - double v_scale, - long a_zp, - double a_scale, - long o_zp, - double o_scale) { + float 
q_scale, + int32_t q_zp, + float k_scale, + int32_t k_zp, + float v_scale, + int32_t v_zp, + float a_scale, + int32_t a_zp, + float o_scale, + int32_t o_zp) { TORCH_CHECK(query.scalar_type() == c10::kByte); int64_t batchSize = query.size(0); int64_t num_head = query.size(1); @@ -1727,31 +1727,31 @@ void sdpa_int8_fused_kernel( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); } else if (q_split_size == 64) { sdpa_int8_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); } else { sdpa_int8_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); } } else { AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { @@ -1760,31 +1760,31 @@ void sdpa_int8_fused_kernel( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); } else if (q_split_size == 64) { sdpa_int8_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); } else { sdpa_int8_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); } }); } @@ -1799,16 +1799,16 @@ at::Tensor sdpa_int8_math_kernel( bool is_causal, std::optional attn_mask, double scale, - int32_t q_zp, float q_scale, - int32_t k_zp, + int32_t q_zp, float k_scale, - int32_t v_zp, + int32_t k_zp, float v_scale, - int32_t a_zp, + int32_t v_zp, float a_scale, - int32_t o_zp, - float o_scale) { + int32_t a_zp, + float o_scale, + int32_t o_zp) { // dequant q/k/v auto q = (query.to(at::kFloat) - q_zp) * q_scale; auto k = (key.to(at::kFloat) - k_zp) * k_scale; @@ -1842,16 +1842,16 @@ at::Tensor _scaled_dot_product_int8_cpu( double dropout_p, bool is_causal, double scale, - int64_t q_zp, double q_scale, - int64_t k_zp, + int64_t q_zp, double k_scale, - int64_t v_zp, + int64_t k_zp, double v_scale, - int64_t a_zp, + int64_t v_zp, double a_scale, - int64_t o_zp, - double o_scale) { + int64_t a_zp, + double o_scale, + int64_t o_zp) { const auto dtype = query.scalar_type(); TORCH_CHECK(!query.is_nested() && !key.is_nested() && !value.is_nested(), "_scaled_dot_product_int8_cpu: Only accept plain inputs"); @@ -1878,21 +1878,21 @@ at::Tensor _scaled_dot_product_int8_cpu( at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); sdpa_int8_fused_kernel(output, query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - 
v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); return output.transpose(1, 2); } else { #endif // CPU_CAPABILITY_AVX512 return sdpa_int8_math_kernel(query, key, value, dropout_p, is_causal, attn_mask, scale, - q_zp, q_scale, - k_zp, k_scale, - v_zp, v_scale, - a_zp, a_scale, - o_zp, o_scale).transpose(1, 2).contiguous().transpose(1, 2); + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp).transpose(1, 2).contiguous().transpose(1, 2); #ifdef CPU_CAPABILITY_AVX512 } #endif // CPU_CAPABILITY_AVX512 diff --git a/torchao/ops.py b/torchao/ops.py index 2c1b2c5368..2f8b4ae645 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -56,7 +56,7 @@ tags=[torch._C.Tag.needs_fixed_stride_order], ) lib.define( - "scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, int q_zp=0, float q_scale=1.0, int k_zp=0, float k_scale=1.0, int v_zp=0, float v_scale=1.0, int a_zp=0, float a_scale=1.0, int o_zp=0, float o_scale=1.0) -> Tensor" + "scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, float q_scale=1.0, int q_zp=0, float k_scale=1.0, int k_zp=0, float v_scale=1.0, int v_zp=0, float a_scale=1.0, int a_zp=0, float o_scale=1.0, int o_zp=0) -> Tensor" ) @@ -170,16 +170,16 @@ def scaled_dot_product_int8( dropout_p: float = 0.0, is_causal: bool = False, scale: float = 0.0, - q_zp: int = 0, q_scale: float = 1.0, - k_zp: int = 0, + q_zp: int = 0, k_scale: float = 1.0, - v_zp: int = 0, + k_zp: int = 0, v_scale: float = 1.0, - a_zp: int = 0, + v_zp: int = 0, a_scale: float = 1.0, - o_zp: int = 0, + a_zp: int = 0, o_scale: float = 1.0, + o_zp: int = 0, ) -> Tensor: """ Quantized SDPA with uint8 inputs and outputs. 
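Illustrative call with the reordered arguments (a minimal sketch: the tensor sizes and the scale/zero-point values below are placeholders, and it assumes the torchao CPU extension is built so the custom op is registered):

    import torch
    import torchao  # importing torchao registers torch.ops.torchao.scaled_dot_product_int8

    B, H, S, D = 1, 16, 64, 64  # batch, heads, seq_len, head_dim (placeholders)
    q = torch.randint(0, 256, (B, H, S, D), dtype=torch.uint8)
    k = torch.randint(0, 256, (B, H, S, D), dtype=torch.uint8)
    v = torch.randint(0, 256, (B, H, S, D), dtype=torch.uint8)

    out = torch.ops.torchao.scaled_dot_product_int8(
        q, k, v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        q_scale=1.8, q_zp=127,   # after this patch, each scale precedes its zero point
        k_scale=1.8, k_zp=125,
        v_scale=1.8, v_zp=127,
        a_scale=0.004, a_zp=120,
        o_scale=1.8, o_zp=128,
    )
    assert out.shape == q.shape  # output keeps the query layout and uint8 dtype
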
@@ -192,16 +192,16 @@ def scaled_dot_product_int8( dropout_p: dropout probability, is_causal: causal flag, scale: scaling factor applied prior to softmax, - q_zp: zero point for query from linear quantization, q_scale: scale for query from linear quantization, - k_zp: zero point of key from linear quantization, + q_zp: zero point for query from linear quantization, k_scale: scale for key from linear quantization, - v_zp: zero point of value from linear quantization, + k_zp: zero point of key from linear quantization, v_scale: zero point for value from linear quantization, - a_zp: zero point for attention from softmax quantization, + v_zp: zero point of value from linear quantization, a_scale: scale for attention from softmax quantization, - o_zp: zero point for output from linear quantization, + a_zp: zero point for attention from softmax quantization, o_scale: scale for output from linear quantization, + o_zp: zero point for output from linear quantization, Returns output of quantized SDPA @@ -214,16 +214,16 @@ def scaled_dot_product_int8( dropout_p, is_causal, scale, - q_zp, q_scale, - k_zp, + q_zp, k_scale, - v_zp, + k_zp, v_scale, - a_zp, + v_zp, a_scale, - o_zp, + a_zp, o_scale, + o_zp, ) @@ -236,16 +236,16 @@ def _( dropout_p: float = 0.0, is_causal: bool = False, scale: float = 0.0, - q_zp: int = 0, q_scale: float = 1.0, - k_zp: int = 0, + q_zp: int = 0, k_scale: float = 1.0, - v_zp: int = 0, + k_zp: int = 0, v_scale: float = 1.0, - a_zp: int = 0, + v_zp: int = 0, a_scale: float = 1.0, - o_zp: int = 0, + a_zp: int = 0, o_scale: float = 1.0, + o_zp: int = 0, ) -> Tensor: return query diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py index e805da1327..a8f181f2db 100644 --- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py +++ b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py @@ -54,16 +54,16 @@ def int8_sdpa(match: Match, *args, **kwargs): value = kwargs["value"] inv_scale = kwargs["inv_scale"] attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None - q_zp = kwargs["q_zp"] q_scale = kwargs["q_scale"] - k_zp = kwargs["k_zp"] + q_zp = kwargs["q_zp"] k_scale = kwargs["k_scale"] - v_zp = kwargs["v_zp"] + k_zp = kwargs["k_zp"] v_scale = kwargs["v_scale"] - a_zp = kwargs["a_zp"] + v_zp = kwargs["v_zp"] a_scale = kwargs["a_scale"] - o_zp = kwargs["o_zp"] + a_zp = kwargs["a_zp"] o_scale = kwargs["o_scale"] + o_zp = kwargs["o_zp"] counters["inductor"]["int8_fuse_attention"] += 1 counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes) @@ -78,16 +78,16 @@ def int8_sdpa(match: Match, *args, **kwargs): 0.0, # dropout False, # is_causal 1.0 / inv_scale, # scale - q_zp, q_scale, - k_zp, + q_zp, k_scale, - v_zp, + k_zp, v_scale, - a_zp, + v_zp, a_scale, - o_zp, + a_zp, o_scale, + o_zp, ) trans_output = L[aten.permute.default](output, [0, 2, 1, 3]) return L[aten.clone.default]( From 27d6fc675178618fda6b1eaa8f14288e3490c895 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Tue, 22 Apr 2025 01:39:27 +0000 Subject: [PATCH 35/36] fix internal build errors --- torchao/csrc/cpu/int8_sdpa.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp index 9a41a28d9c..36cd24ab5e 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -1,14 +1,14 @@ -#pragma once #include #include #include -#include #include #include #include #include - #include +#include +#include +#include #include #ifndef 
AT_PER_OPERATOR_HEADERS @@ -17,12 +17,10 @@ #include #endif -#include #include #include #include #include -#include namespace torchao { From 2dc7d78fb315ef88e8c14412d587cb3502d5eee3 Mon Sep 17 00:00:00 2001 From: Valentine233 Date: Fri, 25 Apr 2025 01:47:46 +0000 Subject: [PATCH 36/36] ruff format --- setup.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 55ec3644d2..9482993d43 100644 --- a/setup.py +++ b/setup.py @@ -95,9 +95,9 @@ def __init__(self): default=(self._is_arm64() and self._is_macos()), ) if self.build_cpu_aarch64: - assert ( - self._is_arm64() - ), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine" + assert self._is_arm64(), ( + "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine" + ) # TORCHAO_BUILD_KLEIDIAI is disabled by default for now because # 1) It increases the build time @@ -106,9 +106,9 @@ def __init__(self): "TORCHAO_BUILD_KLEIDIAI", default=False ) if self.build_kleidi_ai: - assert ( - self.build_cpu_aarch64 - ), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set" + assert self.build_cpu_aarch64, ( + "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set" + ) # TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default. self.build_experimental_mps = self._os_bool_var( @@ -117,9 +117,9 @@ def __init__(self): if self.build_experimental_mps: assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS" assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64" - assert ( - torch.mps.is_available() - ), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available" + assert torch.mps.is_available(), ( + "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available" + ) def _is_arm64(self) -> bool: return platform.machine().startswith("arm64")
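
Throughout int8_sdpa.cpp the u8u8s32 GEMMs are followed by "dequant compensation" (the sum_a_ptr/sum_b_ptr arguments, alpha = scale_a * scale_b, beta = zp_a * zp_b * K). A minimal sketch of the identity those terms implement, with arbitrary sizes and quantization parameters and an int64 matmul standing in for the brgemm int32 accumulation:

    import torch

    M, K, N = 4, 64, 32                      # arbitrary GEMM sizes
    a_scale, a_zp = 0.02, 127                # quantization params for A (e.g. query)
    b_scale, b_zp = 0.03, 125                # quantization params for B (e.g. key^T)
    a = torch.randint(0, 256, (M, K), dtype=torch.uint8)
    b = torch.randint(0, 256, (K, N), dtype=torch.uint8)

    # Reference: dequantize first, then matmul in float.
    ref = ((a.float() - a_zp) * a_scale) @ ((b.float() - b_zp) * b_scale)

    # Fused path: integer matmul plus compensation, mirroring the kernel's
    # alpha = scale_a * scale_b and beta = zp_a * zp_b * K, with the row sums of A
    # and the column sums of B in the role of sum_a_ptr / sum_b_ptr.
    acc = a.long() @ b.long()                               # stand-in for the s32 accumulator
    sum_a = a.long().sum(dim=1, keepdim=True) * b_zp        # [M, 1]
    sum_b = b.long().sum(dim=0, keepdim=True) * a_zp        # [1, N]
    fused = a_scale * b_scale * (acc - sum_a - sum_b + a_zp * b_zp * K).float()

    assert torch.allclose(ref, fused, atol=1e-3)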