diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
index f3b5c1be77..7adf48998e 100644
--- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
+++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
@@ -726,6 +726,48 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values_with_lut(
   unpacked7 = vqtbl1q_s8(lut, idx7);
 }
+
+TORCHAO_ALWAYS_INLINE inline void lookup_and_store_16_fp32_values(
+    float* out,
+    const uint8x16_t& idx,
+    const int8x16x4_t& lut) {
+  const int8x16_t s_idx = vreinterpretq_s8_u8(idx);
+  int8x16_t b0 = vqtbl1q_s8(lut.val[0], s_idx);
+  int8x16_t b1 = vqtbl1q_s8(lut.val[1], s_idx);
+  int8x16_t b2 = vqtbl1q_s8(lut.val[2], s_idx);
+  int8x16_t b3 = vqtbl1q_s8(lut.val[3], s_idx);
+
+  // Interleave the four byte planes so each group of 4 bytes forms one fp32 value.
+  int8x16x4_t result_bytes = {b0, b1, b2, b3};
+  vst4q_s8(reinterpret_cast<int8_t*>(out), result_bytes);
+}
+
+template <int nbit>
+TORCHAO_ALWAYS_INLINE inline void unpack_128_lowbit_values_with_fp32_lut(
+    float* unpacked,
+    const uint8_t* packed,
+    const int8x16x4_t& lut) {
+  uint8x16_t idx0;
+  uint8x16_t idx1;
+  uint8x16_t idx2;
+  uint8x16_t idx3;
+  uint8x16_t idx4;
+  uint8x16_t idx5;
+  uint8x16_t idx6;
+  uint8x16_t idx7;
+  vec_unpack_128_uintx_values<nbit>(
+      idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7, packed);
+  lookup_and_store_16_fp32_values(unpacked + 0, idx0, lut);
+  lookup_and_store_16_fp32_values(unpacked + 16, idx1, lut);
+  lookup_and_store_16_fp32_values(unpacked + 32, idx2, lut);
+  lookup_and_store_16_fp32_values(unpacked + 48, idx3, lut);
+  lookup_and_store_16_fp32_values(unpacked + 64, idx4, lut);
+  lookup_and_store_16_fp32_values(unpacked + 80, idx5, lut);
+  lookup_and_store_16_fp32_values(unpacked + 96, idx6, lut);
+  lookup_and_store_16_fp32_values(unpacked + 112, idx7, lut);
+}
+
 } // namespace bitpacking
 } // namespace torchao
diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/groupwise_lowbit_weight_with_lut.h b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/groupwise_lowbit_weight_with_lut.h
new file mode 100644
index 0000000000..e6a907d8e7
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/groupwise_lowbit_weight_with_lut.h
@@ -0,0 +1,254 @@
+#pragma once
+
+#if defined(__aarch64__) || defined(__ARM_NEON)
+
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+#include <stdexcept> // For std::invalid_argument
+
+#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/kernel_f32.h>
+#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/pack_activation.h>
+#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/pack_weights.h>
+
+namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut {
+
+/**
+ * @brief Calculates the total memory in bytes required for a packed activation buffer.
+ *
+ * This function must be called to determine the correct buffer size to allocate
+ * before calling `pack_activations`. It accounts for any padding needed to
+ * make the 'm' dimension a multiple of the kernel's row-tiling factor (MR).
+ *
+ * @param m The number of rows in the source activation matrix.
+ * @param k The number of columns in the source activation matrix.
+ * @param MR The row-tiling factor of the micro-kernel that will consume this
+ *        packed data (e.g., 4 or 8).
+ * @return The required size of the buffer in bytes.
+ */
+inline size_t packed_activations_size(int m, int k, int MR) {
+  return activation_packing::packed_activations_size(m, k, MR);
+}
+
+/**
+ * @brief Calculates the number of float elements required for a packed activation buffer.
+ *
+ * @param m The number of rows in the source activation matrix.
+ * @param k The number of columns in the source activation matrix. + * @param MR The row-tiling factor of the micro-kernel that will consume this + * packed data (e.g., 4 or 8). + * @return The number of float elements required for a packed activation buffer. + */ + inline size_t packed_activations_size_float(int m, int k, int MR) { + return activation_packing::packed_activations_size(m, k, MR)/sizeof(float); +} + +/** + * @brief Packs a row-major activation matrix into a kernel-optimized blocked layout. + * + * This function rearranges the source matrix into a (M/MR, K, MR) format, + * which allows the compute kernel to load activation data for MR rows with a + * single vector instruction, dramatically improving performance. + * + * The destination buffer `packed_activations_out` must be pre-allocated by the + * caller with the size returned by `packed_activations_size()`. + * + * @param packed_activations_out Pointer to the destination buffer. + * @param m The number of rows in the source activation matrix. + * @param k The number of columns in the source activation matrix. + * @param activations_in Pointer to the source activation matrix (float32, row-major). + * @param MR The row-tiling factor of the target micro-kernel. This function + * currently supports MR values of 4. + */ +inline void pack_activations( + void* packed_activations_out, + int m, + int k, + const float* activations_in, + int MR) { + + switch (MR) { + case 4: + activation_packing::pack_activations_for_kernel<4>(packed_activations_out, m, k, activations_in); + break; + default: + throw std::invalid_argument("Unsupported MR value for activation packing. Supported values: [4]."); + } +} + +/** + * @brief Calculates the total size in bytes required for the packed weight buffer. + * + * This function must be called to allocate a sufficiently large buffer before + * calling `pack_weights`. + * + * @param weight_nbit The number of bits per weight (e.g., 2, 3, 4). + * @param n The number of output channels (columns of the weight matrix). + * @param k The number of input channels (rows of the weight matrix). + * @param has_bias Whether the packed buffer should include space for a bias vector. + * @param scale_group_size The number of weights that share a single scale factor. + * @param lut_group_size The number of weights that share a single Look-Up Table (LUT). + * @param NR The column-tiling factor of the micro-kernel (e.g., 16 or 8). + * @param promote_to_4bit_layout If true, the packed weights will be promoted to 4-bit layout. + * @return The required size of the buffer in bytes. 
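+ *
+ * Worked example, following the layout rules in utils.h: for weight_nbit=4,
+ * n=32, k=64, scale_group_size=32, lut_group_size=32, NR=16 and
+ * promote_to_4bit_layout=true, each k-group stores a 64-byte fused-LUT header
+ * plus 32 * 16 / 2 = 256 packed index bytes, so the total is
+ * 2 n-tiles * 2 k-groups * 320 bytes = 1280 bytes.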
+ */ + inline size_t packed_weights_size( + int weight_nbit, + int n, + int k, + bool has_bias, + int scale_group_size, + int lut_group_size, + int NR, bool promote_to_4bit_layout) { + + if (NR == 16) { + switch (weight_nbit) { + case 1: + return torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::weight_packing::packed_weights_size_for_fused_lut_kernel<1, 16>(n, k, has_bias, scale_group_size, lut_group_size, promote_to_4bit_layout); + case 2: + return torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::weight_packing::packed_weights_size_for_fused_lut_kernel<2, 16>(n, k, has_bias, scale_group_size, lut_group_size, promote_to_4bit_layout); + case 3: + return torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::weight_packing::packed_weights_size_for_fused_lut_kernel<3, 16>(n, k, has_bias, scale_group_size, lut_group_size, promote_to_4bit_layout); + case 4: + return torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::weight_packing::packed_weights_size_for_fused_lut_kernel<4, 16>(n, k, has_bias, scale_group_size, lut_group_size, promote_to_4bit_layout); + default: + throw std::invalid_argument("Unsupported weight_nbit. Must be 1, 2, 3, or 4."); + } + } else { + throw std::invalid_argument("Unsupported NR value for weight packing. Supported values: [16]."); + } +} + +/** + * @brief Packs quantized weights, scales, LUTs, and biases into a single + * contiguous buffer optimized for the target micro-kernel. + * + * This function is the primary entry point for preparing weights. It handles + * transposition, bit-packing, metadata layout, and padding. The caller must + * pre-allocate `packed_weights_ptr` with the size returned by `packed_weights_size`. + * + * @param packed_weights_ptr Pointer to the destination buffer. + * @param B_qvals Pointer to the source quantized weights, stored as uint8_t values + * in a (K, N) row-major layout. + * @param weight_scales A vector of all unique scale factors. + * @param weight_luts A vector of all unique Look-Up Tables (LUTs). + * @param weight_nbit The number of bits per weight (e.g., 2, 3, 4). + * @param N The number of output channels (columns of weights). + * @param K The number of input channels (rows of weights). + * @param scale_group_size The grouping factor for scales. + * @param lut_group_size The grouping factor for LUTs. + * @param NR The column-tiling factor for the kernel (e.g., 16). + * @param promote_to_4bit_layout If true, the packed weights will be promoted to 4-bit layout. + */ +inline void pack_weights( + // Output + void* packed_weights_ptr, + // Inputs + const uint8_t* B_qvals, + const std::vector& weight_scales, + const std::vector& weight_luts, + int weight_nbit, + int N, + int K, + int scale_group_size, + int lut_group_size, + int NR, + bool promote_to_4bit_layout) { + + // Dispatcher to call the correct templated implementation. 
+ if (NR == 16) { + switch (weight_nbit) { + case 4: + weight_packing::pack_weights_with_fused_lut<4, 16>( + packed_weights_ptr, B_qvals, weight_scales, weight_luts, + N, K, scale_group_size, lut_group_size, promote_to_4bit_layout); + break; + case 3: + weight_packing::pack_weights_with_fused_lut<3, 16>( + packed_weights_ptr, B_qvals, weight_scales, weight_luts, + N, K, scale_group_size, lut_group_size, promote_to_4bit_layout); + break; + case 2: + weight_packing::pack_weights_with_fused_lut<2, 16>( + packed_weights_ptr, B_qvals, weight_scales, weight_luts, + N, K, scale_group_size, lut_group_size, promote_to_4bit_layout); + break; + case 1: + weight_packing::pack_weights_with_fused_lut<1, 16>( + packed_weights_ptr, B_qvals, weight_scales, weight_luts, + N, K, scale_group_size, lut_group_size, promote_to_4bit_layout); + break; + default: + throw std::invalid_argument("Unsupported weight_nbit for packing. Must be 1, 2, 3, or 4."); + } + } + else { + throw std::invalid_argument("Unsupported NR for weight packing."); + } +} + +/** + * @brief Computes a group-wise low-bit GEMM using an optimized NEON kernel. + * + * This function selects the best available micro-kernel based on the provided + * tile sizes (MR and NR) and dispatches the computation. + * + * @param output Pointer to the output matrix C. + * @param output_m_stride The stride (in elements) between rows of the output matrix. + * @param m Number of rows in A and C. + * @param n Number of columns in B and C. + * @param k Number of columns in A and rows in B. + * @param scale_group_size The grouping factor for scales. + * @param lut_group_size The grouping factor for LUTs. + * @param packed_weights Pointer to the pre-packed weight buffer. + * @param packed_activations Pointer to the pre-packed activation buffer. + * @param biases Pointer to the bias vector. + * @param clamp_min Minimum value for the fused clamp (ReLU) operation. + * @param clamp_max Maximum value for the fused clamp (ReLU6) operation. + * @param has_bias If true, applies the bias. + * @param has_clamp If true, applies the clamping. + * @param weight_nbit The true bit-width of the weights (e.g., 2, 3, 4). + * @param MR The row-tiling factor to use (e.g., 4). Selects the kernel. + * @param NR The column-tiling factor to use (e.g., 16). Selects the kernel. + */ + inline void groupwise_lowbit_lut_kernel( + float* output, + int output_m_stride, + int m, + int n, + int k, + int scale_group_size, + int lut_group_size, + const void* packed_weights, + const void* packed_activations, + const float* biases, + float clamp_min, + float clamp_max, + bool has_bias, + bool has_clamp, + int weight_nbit, + int MR, + int NR) { + +if (MR == 4 && NR == 16) { + kernel::groupwise_lowbit_lut_kernel_4x16( + output, + output_m_stride, + m, n, k, + scale_group_size, + lut_group_size, + packed_weights, + packed_activations, + biases, + clamp_min, clamp_max, + has_bias, has_clamp, weight_nbit); + } + else { + throw std::invalid_argument( + "Unsupported MR/NR combination. Supported values: [MR=4, NR=16]." 
+ ); + } + } +}// namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/kernel_f32.h b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/kernel_f32.h new file mode 100644 index 0000000000..0982417698 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/kernel_f32.h @@ -0,0 +1,273 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear:: + groupwise_lowbit_weight_with_lut::kernel { + + +namespace internal { + +inline void compute_4x16_4bit_promoted( + float32x4_t accum[4][4], + const float* __restrict__ activation, + const uint8_t* __restrict__ weight_indices, + int K, + const uint8x16x4_t& tbl) +{ + constexpr int MR = 4; + constexpr int NR = 16; + assert(K > 0 && "K must be positive"); + namespace utils = torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::utils; + + const uint8x16_t SIXTEEN = vdupq_n_u8(16); + const uint8x16_t THIRTY_TWO = vdupq_n_u8(32); + const uint8x16_t FORTY_EIGHT = vdupq_n_u8(48); + const uint8_t* idx_ptr = weight_indices; + const float* a_ptr = activation; + + for (int k_idx = 0; k_idx < K; ++k_idx) { + + uint8x8_t packed_neon = vld1_u8(idx_ptr); + // Unpack the 8-bit indices into 16 4-bit indices. + uint8x8x2_t interleaved = vzip_u8(vshr_n_u8(packed_neon, 4), vand_u8(packed_neon, vdup_n_u8(0x0F))); + uint8x16_t unpacked_indices_neon = vcombine_u8(interleaved.val[0], interleaved.val[1]); + + uint8x16_t idx_plane0 = unpacked_indices_neon; + uint8x16_t idx_plane1 = vaddq_u8(unpacked_indices_neon, SIXTEEN); + uint8x16_t idx_plane2 = vaddq_u8(unpacked_indices_neon, THIRTY_TWO); + uint8x16_t idx_plane3 = vaddq_u8(unpacked_indices_neon, FORTY_EIGHT); + + uint8x16_t b0 = vqtbl4q_u8(tbl, idx_plane0); + uint8x16_t b1 = vqtbl4q_u8(tbl, idx_plane1); + uint8x16_t b2 = vqtbl4q_u8(tbl, idx_plane2); + uint8x16_t b3 = vqtbl4q_u8(tbl, idx_plane3); + + uint8x16x2_t zip_b01 = vzipq_u8(b0, b1); + uint8x16x2_t zip_b23 = vzipq_u8(b2, b3); + uint16x8x2_t trn_16_0 = vtrnq_u16(vreinterpretq_u16_u8(zip_b01.val[0]), vreinterpretq_u16_u8(zip_b23.val[0])); + uint16x8x2_t trn_16_1 = vtrnq_u16(vreinterpretq_u16_u8(zip_b01.val[1]), vreinterpretq_u16_u8(zip_b23.val[1])); + float32x4x2_t final_zip_0 = vzipq_f32(vreinterpretq_f32_u16(trn_16_0.val[0]), vreinterpretq_f32_u16(trn_16_0.val[1])); + float32x4x2_t final_zip_1 = vzipq_f32(vreinterpretq_f32_u16(trn_16_1.val[0]), vreinterpretq_f32_u16(trn_16_1.val[1])); + + float32x4_t w0_3 = final_zip_0.val[0]; + float32x4_t w4_7 = final_zip_0.val[1]; + float32x4_t w8_11 = final_zip_1.val[0]; + float32x4_t w12_15 = final_zip_1.val[1]; + + float32x4_t a_col = vld1q_f32(a_ptr); + + float32x4_t a0 = vdupq_laneq_f32(a_col, 0); + float32x4_t a1 = vdupq_laneq_f32(a_col, 1); + float32x4_t a2 = vdupq_laneq_f32(a_col, 2); + float32x4_t a3 = vdupq_laneq_f32(a_col, 3); + + accum[0][0] = vfmaq_f32(accum[0][0], w0_3, a0); + accum[0][1] = vfmaq_f32(accum[0][1], w4_7, a0); + accum[0][2] = vfmaq_f32(accum[0][2], w8_11, a0); + accum[0][3] = vfmaq_f32(accum[0][3], w12_15, a0); + + 
+    // Rows 1-3 of the tile reuse the same 16 dequantized weights
+    // (w0_3..w12_15), each scaled by that row's broadcast activation lane.
+    accum[1][0] = vfmaq_f32(accum[1][0], w0_3, a1);
+    accum[1][1] = vfmaq_f32(accum[1][1], w4_7, a1);
+    accum[1][2] = vfmaq_f32(accum[1][2], w8_11, a1);
+    accum[1][3] = vfmaq_f32(accum[1][3], w12_15, a1);
+
+    accum[2][0] = vfmaq_f32(accum[2][0], w0_3, a2);
+    accum[2][1] = vfmaq_f32(accum[2][1], w4_7, a2);
+    accum[2][2] = vfmaq_f32(accum[2][2], w8_11, a2);
+    accum[2][3] = vfmaq_f32(accum[2][3], w12_15, a2);
+
+    accum[3][0] = vfmaq_f32(accum[3][0], w0_3, a3);
+    accum[3][1] = vfmaq_f32(accum[3][1], w4_7, a3);
+    accum[3][2] = vfmaq_f32(accum[3][2], w8_11, a3);
+    accum[3][3] = vfmaq_f32(accum[3][3], w12_15, a3);
+
+    a_ptr += MR;
+    idx_ptr += (NR / 2);
+  }
+}
+
+template <int WEIGHT_NBIT>
+inline void micro_kernel_lut_4x16(
+    float32x4_t accum[4][4],
+    const float* __restrict__ A,
+    const uint8_t* __restrict__ W,
+    int K_group_size)
+{
+  static_assert(WEIGHT_NBIT >= 1 && WEIGHT_NBIT <= 4, "WEIGHT_NBIT must be 1-4");
+  namespace utils = torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::utils;
+
+  // 1. Get pointers to the packed data components
+  const auto* grp = reinterpret_cast<const utils::FusedLutPackedWeightGroup<WEIGHT_NBIT>*>(W);
+  const uint8_t* indices_ptr = reinterpret_cast<const uint8_t*>(grp + 1);
+
+  // 2. Perform LUT expansion
+  uint8x16x4_t tbl;
+  if constexpr (WEIGHT_NBIT < 4) {
+    const int src_lut_size = 1 << WEIGHT_NBIT;
+    alignas(16) uint8_t expanded_lut_soa[64];
+    for (int plane = 0; plane < 4; ++plane) {
+      const uint8_t* src_plane = reinterpret_cast<const uint8_t*>(&grp->lut_soa_planes[plane]);
+      uint8_t* dst_plane = &expanded_lut_soa[plane * 16];
+      for (int i = 0; i < 16; ++i) {
+        dst_plane[i] = src_plane[i % src_lut_size];
+      }
+    }
+    memcpy(&tbl, expanded_lut_soa, 64);
+  } else { // WEIGHT_NBIT == 4
+    memcpy(&tbl, grp->lut_soa_planes, 64);
+  }
+
+  // 3. Call the pure compute kernel with the prepared LUT
+  compute_4x16_4bit_promoted(
+      accum,
+      A,
+      indices_ptr,
+      K_group_size,
+      tbl);
+}
+
+inline void post_process_and_store_4x16(
+    float* __restrict__ output,
+    int ldc,
+    float32x4_t accum[4][4],
+    const float* __restrict__ bias_ptr,
+    bool has_clamp,
+    float32x4_t clamp_min_vec,
+    float32x4_t clamp_max_vec)
+{
+  constexpr int MR = 4;
+  constexpr int NR = 16;
+  constexpr int NR_VEC = NR / 4;
+
+  for (int m = 0; m < MR; ++m) {
+    float* out_row = output + m * ldc;
+    for (int nb = 0; nb < NR_VEC; ++nb) {
+      float32x4_t res = accum[m][nb];
+
+      if (bias_ptr != nullptr) {
+        float32x4_t bias_vec = vld1q_f32(bias_ptr + nb * 4);
+        res = vaddq_f32(res, bias_vec);
+      }
+
+      if (has_clamp) {
+        res = vmaxq_f32(res, clamp_min_vec);
+        res = vminq_f32(res, clamp_max_vec);
+      }
+      vst1q_f32(out_row + nb * 4, res);
+    }
+  }
+}
+} // namespace internal
+
+/**
+ * @brief Computes a group-wise low-bit GEMM using the 4x16 fused LUT kernel.
+ *
+ * It assumes activations have been pre-packed by `pack_activations(..., MR=4)`
+ * and weights by `pack_weights(..., NR=16)`.
+ */
+inline void groupwise_lowbit_lut_kernel_4x16(
+    float* output,
+    int output_m_stride,
+    int m, int n, int k,
+    int scale_group_size,
+    int lut_group_size,
+    const void* packed_weights,
+    const void* packed_activations, const float* biases,
+    float clamp_min, float clamp_max,
+    bool has_bias, bool has_clamp, int weight_nbit) {
+
+  // --- 1.
Define kernel parameters --- + constexpr int MR = 4; + constexpr int NR = 16; + constexpr bool promote_to_4bit_layout = true; + + const int packing_group_size = std::gcd(scale_group_size, lut_group_size); + + assert(n % NR == 0 && "N must be divisible by tile width NR"); + assert(m % MR == 0 && "M must be divisible by tile height MR"); + assert(k % packing_group_size == 0 && "K must be a multiple of the packing group size"); + + // --- 2. Get the memory layout --- + auto layout = utils::create_fused_lut_layout( + n, k, scale_group_size, lut_group_size, weight_nbit, promote_to_4bit_layout + ); + + // --- 3. Main loop --- + const float32x4_t clamp_min_vec = vdupq_n_f32(clamp_min); + const float32x4_t clamp_max_vec = vdupq_n_f32(clamp_max); + + const int num_groups_per_k_tile = k / packing_group_size; + + for (int m_tile_start = 0; m_tile_start < m; m_tile_start += MR) { + + const size_t activation_tile_size = (size_t)MR * k; + + // Calculate pointer by advancing by the number of tiles to skip. + const auto* current_packed_activations = static_cast(packed_activations) + (m_tile_start / MR) * activation_tile_size; + + for (int n_tile_start = 0; n_tile_start < n; n_tile_start += NR) { + + float32x4_t accumulators[MR][NR / 4] = {{0}}; + + + const auto* weights_for_tile_n = static_cast(packed_weights) + (n_tile_start / NR) * layout.n_tile_stride_bytes; + for (int k_group_idx = 0; k_group_idx < num_groups_per_k_tile; ++k_group_idx) { + // Calculate the starting column index for this group + const int k_group_start = k_group_idx * packing_group_size; + + // A_group_ptr holds the pointer to the relevant slice of the activation data + const float* A_group_ptr = current_packed_activations + k_group_start * MR; + + // W_group_ptr holds the pointer to the relevant packed weight group + const uint8_t* W_group_ptr = weights_for_tile_n + k_group_idx * layout.group_stride_bytes; + + switch (weight_nbit) { + case 4: + internal::micro_kernel_lut_4x16<4>( + accumulators, A_group_ptr, W_group_ptr, packing_group_size); + break; + case 3: + internal::micro_kernel_lut_4x16<3>( + accumulators, A_group_ptr, W_group_ptr, packing_group_size); + break; + case 2: + internal::micro_kernel_lut_4x16<2>( + accumulators, A_group_ptr, W_group_ptr, packing_group_size); + break; + case 1: + internal::micro_kernel_lut_4x16<1>( + accumulators, A_group_ptr, W_group_ptr, packing_group_size); + break; + default: + throw std::invalid_argument("Unsupported weight_nbit in kernel."); + } + } + internal::post_process_and_store_4x16( + output + m_tile_start * output_m_stride + n_tile_start, + output_m_stride, + accumulators, + has_bias ? 
biases + n_tile_start : nullptr, + has_clamp, + clamp_min_vec, + clamp_max_vec + ); + } + } +} + +} // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::kernel + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/pack_activation.h b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/pack_activation.h new file mode 100644 index 0000000000..43b3a80d7d --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/pack_activation.h @@ -0,0 +1,53 @@ +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::activation_packing { + +inline size_t packed_activations_size(int m, int k, int MR) { + const int m_padded = ((m + MR - 1) / MR) * MR; + return (size_t)m_padded * k * sizeof(float); +} + +template +void pack_activations_for_kernel( + // Output + void* packed_activations, + // Inputs + int m, + int k, + const float* activations) { + + // --- 1. Initialization --- + float* packed_ptr = static_cast(packed_activations); + const int m_padded = ((m + MR - 1) / MR) * MR; + + // --- 2. Main Packing Loops --- + for (int m_start = 0; m_start < m_padded; m_start += MR) { + + for (int k_idx = 0; k_idx < k; ++k_idx) { + + for (int m_offset = 0; m_offset < MR; ++m_offset) { + + // --- 3. Handle Padding and Copy Data --- + const int current_m = m_start + m_offset; + + if (current_m < m) { + *packed_ptr = activations[(size_t)current_m * k + k_idx]; + } else { + *packed_ptr = 0.0f; + } + + packed_ptr++; + } + } + } +} + +} // namespace torchao::kernels::cpu::aarch64::linear::activation_packing + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/pack_weights.h b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/pack_weights.h new file mode 100644 index 0000000000..eb1abfd655 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/pack_weights.h @@ -0,0 +1,166 @@ +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear:: + groupwise_lowbit_weight_with_lut::weight_packing { + +namespace internal { + +/** + * @brief Pack 2 4-bit indices into 1 byte. +*/ +void pack_region(uint8_t* dst, const uint8_t* src, size_t count, int nbit) { + if (nbit != 4) { + return; + } + + assert(count % 16 == 0); + + for (size_t i = 0; i < count; i += 16) { + const uint8_t* current_src = src + i; + uint8_t* current_dst = dst + (i / 2); + + // 1. Load: [i0, i1, i2, i3, ...] + uint8x16_t v_src = vld1q_u8(current_src); + + // 2. De-interleave: + // p.val[0] = [i0, i2, i4, ...] (even indices) + // p.val[1] = [i1, i3, i5, ...] (odd indices) + uint8x16x2_t p = vuzpq_u8(v_src, v_src); + + // 3. vshlq_n_u8(p.val[0], 4) results in: [(i0<<4), (i2<<4), (i4<<4), ...] + uint8x16_t high_nibbles = vshlq_n_u8(p.val[0], 4); + + // 4. result_16_bytes = [(i0<<4)|i1, (i2<<4)|i3, ..., (i14<<4)|i15] + uint8x16_t result_16_bytes = vorrq_u8(p.val[1], high_nibbles); + + // 5. Store the 8 packed bytes. 
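+    //    For example, adjacent indices {3, 12} are stored as (3 << 4) | 12 = 0x3C,
+    //    which matches the shift/mask unpacking in compute_4x16_4bit_promoted.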
+ vst1_u8(current_dst, vget_low_u8(result_16_bytes)); + } +} + +namespace utils = torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::utils; + +/** + * @brief A packer-side structure for holding a group's metadata + * before it is transposed into its final, hardware-friendly format. +*/ +template +struct PlainMetadataGroup { + alignas(16 * sizeof(float)) float fused_lut[16]; +}; + +template +inline void transpose_metadata_to_packed_format( + utils::FusedLutPackedWeightGroup* out, + const PlainMetadataGroup* in) +{ + const uint8_t* lut_ptr = reinterpret_cast(in->fused_lut); + uint8x16x4_t soa_lut = vld4q_u8(lut_ptr); + + out->lut_soa_planes[0] = soa_lut.val[0]; + out->lut_soa_planes[1] = soa_lut.val[1]; + out->lut_soa_planes[2] = soa_lut.val[2]; + out->lut_soa_planes[3] = soa_lut.val[3]; +} +} // namespace internal + +/** + * @brief Calculates the total size by delegating to the shared layout factory. + */ + template + size_t packed_weights_size_for_fused_lut_kernel( + int N, int K, bool has_bias, int scale_group_size, int lut_group_size, + bool promote_to_4bit_layout) { + + // The sizer's only job is to create the layout and return the total size. + utils::FusedLutPackedLayout layout = utils::create_fused_lut_layout( + N, K, scale_group_size, lut_group_size, weight_nbit, promote_to_4bit_layout + ); + + return layout.total_buffer_size; +} + +template +void pack_weights_with_fused_lut( + void* packed_weights_ptr, + const uint8_t* B_qvals, // Expected in (K, N) layout + const std::vector& weight_scales, + const std::vector& weight_luts, + int N, int K, int scale_group_size, int lut_group_size, + bool promote_to_4bit_layout) { + + namespace packing_internal = torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::weight_packing::internal; + namespace utils = torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut::utils; + + // --- 1. Get layout and constant --- + utils::FusedLutPackedLayout layout = utils::create_fused_lut_layout( + N, K, scale_group_size, lut_group_size, weight_nbit, promote_to_4bit_layout); + const int packing_group_size = std::gcd(scale_group_size, lut_group_size); + constexpr int src_lut_size_per_entry = (1 << weight_nbit); + + // --- 2. Allocate temporary buffers --- + packing_internal::PlainMetadataGroup temp_group = {}; + std::vector indices_to_pack(packing_group_size * NR); + + // --- 3. Main packing logic --- + char* out_ptr = static_cast(packed_weights_ptr); + const int N_padded = ((N + NR - 1) / NR) * NR; + + for (int n_start = 0; n_start < N_padded; n_start += NR) { + for (int k_group_start = 0; k_group_start < K; k_group_start += packing_group_size) { + + int32_t scale_idx = k_group_start / scale_group_size; + int32_t lut_idx = k_group_start / lut_group_size; + + float scale = weight_scales.empty() ? 
1.0f : weight_scales[scale_idx]; + const float* lut_src = weight_luts.data() + lut_idx * src_lut_size_per_entry; + + for (int i = 0; i < src_lut_size_per_entry; ++i) { + temp_group.fused_lut[i] = scale * lut_src[i]; + } + + for (int k_offset = 0; k_offset < packing_group_size; ++k_offset) { + for (int nr_idx = 0; nr_idx < NR; ++nr_idx) { + const int current_k = k_group_start + k_offset; + const int current_n = n_start + nr_idx; + if (current_n < N) { + indices_to_pack[k_offset * NR + nr_idx] = B_qvals[current_k * N + current_n]; + } else { + indices_to_pack[k_offset * NR + nr_idx] = 0; // Padding + } + } + } + + auto* header_dst = reinterpret_cast*>(out_ptr); + packing_internal::transpose_metadata_to_packed_format(header_dst, &temp_group); + + auto* indices_dst = out_ptr + layout.header_bytes_per_group; + int effective_bit_width = promote_to_4bit_layout ? 4 : weight_nbit; + packing_internal::pack_region( + reinterpret_cast(indices_dst), + indices_to_pack.data(), + packing_group_size * NR, + effective_bit_width); + + out_ptr += layout.group_stride_bytes; + } + } +} + +} // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight::weight_packing + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/utils.h b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/utils.h new file mode 100644 index 0000000000..82c7974206 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/utils.h @@ -0,0 +1,65 @@ +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear:: + groupwise_lowbit_weight_with_lut::utils { + + +template +struct FusedLutPackedWeightGroup { + uint8x16_t lut_soa_planes[4]; +}; + +struct FusedLutPackedLayout { + // --- Per-group sizes (Strides within a physical group) --- + size_t header_bytes_per_group; + size_t packed_indices_bytes_per_group; + + // --- Strides between physical groups --- + size_t group_stride_bytes; // Stride between k-groups for the same n-tile + size_t n_tile_stride_bytes; // Stride between n-tiles + + // --- Total Size --- + size_t total_buffer_size; +}; + +template +inline FusedLutPackedLayout create_fused_lut_layout( + int N, int K, int scale_group_size, int lut_group_size, + int weight_nbit, bool promote_to_4bit_layout) { + + FusedLutPackedLayout layout; + + int packing_group_size = std::gcd(scale_group_size, lut_group_size); + + if (promote_to_4bit_layout) { + layout.packed_indices_bytes_per_group = (size_t)packing_group_size * NR / 2; + } else { + layout.packed_indices_bytes_per_group = (size_t)packing_group_size * NR * weight_nbit / 8; + } + layout.header_bytes_per_group = sizeof(FusedLutPackedWeightGroup); + + // --- Calculate Strides --- + layout.group_stride_bytes = layout.header_bytes_per_group + layout.packed_indices_bytes_per_group; + + const int num_groups_per_k_tile = K / packing_group_size; + layout.n_tile_stride_bytes = num_groups_per_k_tile * layout.group_stride_bytes; + + // --- Calculate Total Size --- + const int N_padded = ((N + NR - 1) / NR) * NR; + const int num_n_tiles = N_padded / NR; + layout.total_buffer_size = num_n_tiles * layout.n_tile_stride_bytes; + + return layout; +} + +} +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt 
b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index 1fd2828fc5..104b63f335 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -70,6 +70,14 @@ target_link_libraries( dep ) +add_executable(test_lut test_lut.cpp) +target_link_libraries( + test_lut + PRIVATE + GTest::gtest_main + dep +) + add_executable(test_reduction test_reduction.cpp) target_link_libraries( test_reduction @@ -125,6 +133,7 @@ gtest_discover_tests(test_quantization) gtest_discover_tests(test_reduction) gtest_discover_tests(test_bitpacking) gtest_discover_tests(test_linear) +gtest_discover_tests(test_lut) gtest_discover_tests(test_embedding) gtest_discover_tests(test_weight_packing) gtest_discover_tests(test_qmatmul) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 5d28ea01cc..ae1b15ec85 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -58,6 +58,7 @@ ${CMAKE_OUT}/test_quantization ${CMAKE_OUT}/test_reduction ${CMAKE_OUT}/test_bitpacking ${CMAKE_OUT}/test_linear +${CMAKE_OUT}/test_lut ${CMAKE_OUT}/test_embedding ${CMAKE_OUT}/test_weight_packing ${CMAKE_OUT}/test_qmatmul diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp index 7e7ccaea26..0fb59fc1cd 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -869,6 +869,56 @@ void test_bitpacking_128_lowbit_values_with_lut() { } } +template +void test_bitpacking_128_lowbit_values_with_fp32_lut() { + + constexpr int num_values = 128; + + const int packed_bytes = (num_values * nbit + 7) / 8; + auto input = torchao::get_random_lowbit_vector(num_values, nbit); + std::vector packed(packed_bytes, 0); + + uint8x16_t idx0; + uint8x16_t idx1; + uint8x16_t idx2; + uint8x16_t idx3; + uint8x16_t idx4; + uint8x16_t idx5; + uint8x16_t idx6; + uint8x16_t idx7; + + torchao::bitpacking::internal::vec_load_64_uint8_values(idx0, idx1, idx2, idx3, input.data()); + torchao::bitpacking::internal::vec_load_64_uint8_values(idx4, idx5, idx6, idx7, input.data() + 64); + + // generate test cases + auto lut = torchao::get_random_vector(16, -1.0, 1.0); + + // prepare LUT + uint8x16x4_t luts_u8 = vld4q_u8(reinterpret_cast(lut.data())); + + // Now, just reinterpret to the signed type your function needs + int8x16x4_t luts = { + vreinterpretq_s8_u8(luts_u8.val[0]), + vreinterpretq_s8_u8(luts_u8.val[1]), + vreinterpretq_s8_u8(luts_u8.val[2]), + vreinterpretq_s8_u8(luts_u8.val[3]) + }; + + + torchao::bitpacking::vec_pack_128_uintx_values(packed.data(), idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7); + + std::vector unpacked(num_values, 0.0f); + torchao::bitpacking::unpack_128_lowbit_values_with_fp32_lut(unpacked.data(), packed.data(), luts); + + for (int i = 0; i < num_values; ++i) { + uint8_t original_index = input[i]; + float expected_value = lut[original_index]; + float actual_value = unpacked[i]; + EXPECT_FLOAT_EQ(actual_value, expected_value) + << "Mismatch at index " << i << " for nbit=" << nbit; + } +} + #define TEST_BITPACKING_32_LOWBIT_VALUES(nbit) \ TEST(test_bitpacking_32_lowbit_values_##nbit, PackUnpackAreSame) { \ test_bitpacking_32_lowbit_values(); \ @@ -889,6 +939,11 @@ void 
test_bitpacking_128_lowbit_values_with_lut() { test_bitpacking_128_lowbit_values_with_lut(); \ } +#define TEST_BITPACKING_128_LOWBIT_VALUES_WITH_FP32_LUT(nbit) \ +TEST(test_bitpacking_128_lowbit_values_with_fp32_lut_##nbit, PackUnpackAreSame) { \ + test_bitpacking_128_lowbit_values_with_fp32_lut(); \ +} + TEST_BITPACKING_32_LOWBIT_VALUES(1); TEST_BITPACKING_32_LOWBIT_VALUES(2); TEST_BITPACKING_32_LOWBIT_VALUES(3); @@ -916,6 +971,11 @@ TEST_BITPACKING_128_LOWBIT_VALUES(6); TEST_BITPACKING_128_LOWBIT_VALUES(7); TEST_BITPACKING_128_LOWBIT_VALUES(8); +TEST_BITPACKING_128_LOWBIT_VALUES_WITH_FP32_LUT(1); +TEST_BITPACKING_128_LOWBIT_VALUES_WITH_FP32_LUT(2); +TEST_BITPACKING_128_LOWBIT_VALUES_WITH_FP32_LUT(3); +TEST_BITPACKING_128_LOWBIT_VALUES_WITH_FP32_LUT(4); + TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(1); TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(2); TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(3); diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp new file mode 100644 index 0000000000..a1bc5b95b4 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp @@ -0,0 +1,198 @@ +#include +#include + +#include +#include + +// Use the kernel API +using namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut; + +template +void test_groupwise_lowbit_lut_kernel( + int m, + int k, + int n, + int scale_group_size, + int lut_group_size, + bool has_scales, + bool has_bias, + bool has_clamp, + int weight_nbit, + bool promote_to_4bit_layout) { + using namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut; + ASSERT_EQ(m % MR, 0) << "M must be a multiple of MR"; + ASSERT_EQ(n % NR, 0) << "N must be a multiple of NR"; + ASSERT_EQ(k % scale_group_size, 0) << "K must be a multiple of scale_group_size"; + ASSERT_EQ(k % lut_group_size, 0) << "K must be a multiple of lut_group_size"; + + // 1. Generate test case + auto test_case = torchao::groupwise_lowbit_weight_lut_test_case::generate_with_decoupled_grouping( + m, k, n, + /*scale_group_size=*/scale_group_size, + /*lut_group_size=*/lut_group_size, + /*weight_nbit=*/weight_nbit, + /*has_scales=*/has_scales, + has_bias, has_clamp); + + // 2. Pack Activations + const auto& source_activations = test_case.activations; + std::vector packed_activations_buffer(packed_activations_size_float(m, k, MR)); + pack_activations(packed_activations_buffer.data(), m, k, source_activations.data(), MR); + + // 3. Pack Weights + std::vector packed_weights(packed_weights_size(4, n, k, has_bias, scale_group_size, lut_group_size, NR, promote_to_4bit_layout)); + + pack_weights( + packed_weights.data(), + test_case.weight_qval_indices.data(), + test_case.weight_scales, + test_case.weight_luts, + test_case.weight_nbit, + n, + k, + test_case.scale_group_size, + test_case.lut_group_size, + NR, promote_to_4bit_layout); + + // 4. Run the kernel + std::vector output(m * n); + groupwise_lowbit_lut_kernel( + output.data(), + n, + m, n, k, + scale_group_size, lut_group_size, + packed_weights.data(), + packed_activations_buffer.data(), + test_case.bias.data(), + test_case.clamp_min, + test_case.clamp_max, + has_bias, + has_clamp, + weight_nbit, + MR, NR); + + // 5. 
Compare results + constexpr float kTol = 1e-4; + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol) + << "Mismatch at index " << i; + } +} + +TEST(test_groupwise_lowbit_lut_kernel, tile_4x16_aligned_scale_lut_group_size) { + // MR and NR are fixed for current kernel. + constexpr int MR = 4; // Micro-kernel Row height + constexpr int NR = 16; // Micro-kernel Register width + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/32, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/4, true); + + // With bias + test_groupwise_lowbit_lut_kernel( + /*m=*/4, /*k=*/32, /*n=*/16,/*scale_group_size=*/32, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/true, /*has_clamp=*/false, /*weight_nbit=*/4, true); + + // With clamp + test_groupwise_lowbit_lut_kernel( + /*m=*/12, /*k=*/64, /*n=*/16, /*scale_group_size=*/32, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/true, /*weight_nbit=*/4, true); + + // With bias and clamp + test_groupwise_lowbit_lut_kernel( + /*m=*/4, /*k=*/128, /*n=*/48, + /*scale_group_size=*/32, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/true, /*has_clamp=*/true, /*weight_nbit=*/4, true); + + // With scales + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, + /*scale_group_size=*/32, + /*lut_group_size=*/32, + /*has_scales=*/true, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/4, true); + + // With scales clamp, and bias + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, + /*scale_group_size=*/32, + /*lut_group_size=*/32, + /*has_scales=*/true, + /*has_bias=*/true, /*has_clamp=*/true, /*weight_nbit=*/4, true); + +} + + +TEST(test_groupwise_lowbit_lut_kernel, tile_4x16_misaligned_scale_lut_group_size) { + constexpr int MR = 4; + constexpr int NR = 16; + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/32, + /*lut_group_size=*/16, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/4, true); + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/32, + /*lut_group_size=*/64, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/4, true); + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/64, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/4, true); + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/16, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/4, true); + +} + + +TEST(test_groupwise_lowbit_lut_kernel, lower_indice_bit) { + // MR and NR are fixed for current kernel. 
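+  // Only the 4x16 tile is implemented; other MR/NR combinations make
+  // groupwise_lowbit_lut_kernel throw std::invalid_argument.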
+ constexpr int MR = 4; // Micro-kernel Row height + constexpr int NR = 16; // Micro-kernel Register width + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/32, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/1, true); + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/32, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/2, true); + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/32, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/3, true); + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/32, + /*lut_group_size=*/32, /*has_scales=*/false, + /*has_bias=*/false, /*has_clamp=*/false, /*weight_nbit=*/3, true); + + + // Standard case + test_groupwise_lowbit_lut_kernel( + /*m=*/8, /*k=*/64, /*n=*/32, /*scale_group_size=*/32, + /*lut_group_size=*/32, /*has_scales=*/true, + /*has_bias=*/true, /*has_clamp=*/true, /*weight_nbit=*/3, true); + + } diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index aeb9042210..1db9283c66 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -306,6 +306,7 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { float res = 0.0; for (int k_idx = 0; k_idx < k; k_idx++) { int activation_idx = m_idx * k + k_idx; + // int weight_idx = n_idx * k + k_idx; int weight_idx = n_idx * k + k_idx; int weight_group_idx = weight_idx / weight_group_size; @@ -617,10 +618,6 @@ struct groupwise_lowbit_weight_lut_test_case { weight_scales(weight_scales_) {} - //-------------------------------------------------------------------------- - // Generator Functions (Factories) - //-------------------------------------------------------------------------- - private: /** * @brief The private "master" generator that provides maximum flexibility. @@ -636,15 +633,12 @@ struct groupwise_lowbit_weight_lut_test_case { int weight_nbit, bool has_scales, bool has_bias, bool has_clamp) { - // --- 0. Validation and Setup --- const int total_weights = n * k; - // Frequencies are controlled by their group sizes. - assert(total_weights % scale_group_size == 0); - assert(total_weights % lut_group_size == 0); + assert(k % scale_group_size == 0); + assert(k % lut_group_size == 0); - // The number of unique scales/LUTs is derived directly from their group size. 
- const int num_scales = total_weights / scale_group_size; - const int num_luts = total_weights / lut_group_size; + const int num_scales = k / scale_group_size; + const int num_luts = k / lut_group_size; const int lut_size = 1 << weight_nbit; std::mt19937 gen(std::random_device{}()); @@ -683,15 +677,16 @@ struct groupwise_lowbit_weight_lut_test_case { float res = 0.0f; for (int k_idx = 0; k_idx < k; ++k_idx) { float activation_val = activations[m_idx * k + k_idx]; - int weight_idx = n_idx * k + k_idx; - uint8_t qval_idx = weight_qval_indices[weight_idx]; - int32_t scale_idx = weight_idx / scale_group_size; - int32_t lut_idx = weight_idx / lut_group_size; + int weight_qval_1d_idx = k_idx * n + n_idx; + uint8_t qval_idx = weight_qval_indices[weight_qval_1d_idx]; + + int32_t scale_idx = k_idx / scale_group_size; + int32_t lut_idx = k_idx / lut_group_size; - // Dequantize: scale * LUT_value float scale = weight_scales[scale_idx]; float lut_val = weight_luts[lut_idx * lut_size + qval_idx]; + res += activation_val * (scale * lut_val); } res += bias_vec[n_idx];
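
A minimal end-to-end usage sketch of the API introduced above. Shapes, group sizes, and the placeholder activation/index/scale/LUT buffers are arbitrary, the include path follows the umbrella header added in this diff, and the call order mirrors test_lut.cpp.

#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight_with_lut/groupwise_lowbit_weight_with_lut.h>

#include <cstdint>
#include <vector>

namespace lut = torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_with_lut;

void run_groupwise_lowbit_lut_gemm_example() {
  constexpr int MR = 4, NR = 16;
  const int m = 8, k = 64, n = 32;
  const int scale_group_size = 32, lut_group_size = 32;
  const int weight_nbit = 4;
  const bool promote_to_4bit_layout = true;

  // Caller-provided buffers (placeholders): fp32 activations, quantized weight
  // indices in (K, N) row-major order, one scale per scale group, one
  // 2^weight_nbit-entry LUT per LUT group, and a bias of length n.
  std::vector<float> activations(m * k, 1.0f);
  std::vector<uint8_t> weight_indices(k * n, 0);
  std::vector<float> scales(k / scale_group_size, 1.0f);
  std::vector<float> luts((k / lut_group_size) * (1 << weight_nbit), 0.0f);
  std::vector<float> bias(n, 0.0f);

  // 1. Pack activations into the (M/MR, K, MR) blocked layout.
  std::vector<float> packed_act(lut::packed_activations_size_float(m, k, MR));
  lut::pack_activations(packed_act.data(), m, k, activations.data(), MR);

  // 2. Pack weights, scales, and LUTs into the fused-LUT layout.
  std::vector<char> packed_w(lut::packed_weights_size(
      weight_nbit, n, k, /*has_bias=*/true, scale_group_size, lut_group_size,
      NR, promote_to_4bit_layout));
  lut::pack_weights(
      packed_w.data(), weight_indices.data(), scales, luts, weight_nbit,
      n, k, scale_group_size, lut_group_size, NR, promote_to_4bit_layout);

  // 3. Run the 4x16 kernel; the output is row-major with stride n.
  std::vector<float> output(m * n);
  lut::groupwise_lowbit_lut_kernel(
      output.data(), /*output_m_stride=*/n, m, n, k,
      scale_group_size, lut_group_size,
      packed_w.data(), packed_act.data(), bias.data(),
      /*clamp_min=*/0.0f, /*clamp_max=*/6.0f,
      /*has_bias=*/true, /*has_clamp=*/false,
      weight_nbit, MR, NR);
}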