From ff86b063d4bde4625b7c623b40e27e58993c5234 Mon Sep 17 00:00:00 2001 From: justinl66 Date: Tue, 21 Apr 2026 21:44:39 -0700 Subject: [PATCH 1/4] Add INT4 and FP16 GPU matmul op --- cactus/CMakeLists.txt | 11 ++- cactus/graph/graph_ops_nn.cpp | 11 +++ cactus/kernel/kernel.h | 14 ++++ cactus/kernel/kernel_matmul.cpp | 4 + cactus/kernel/kernel_mps.mm | 142 ++++++++++++++++++++++++++++++++ tests/CMakeLists.txt | 4 + tests/test_kernel.cpp | 132 +++++++++++++++++++++++++++++ 7 files changed, 316 insertions(+), 2 deletions(-) create mode 100644 cactus/kernel/kernel_mps.mm diff --git a/cactus/CMakeLists.txt b/cactus/CMakeLists.txt index 59784eeb6..8a1d4bce7 100644 --- a/cactus/CMakeLists.txt +++ b/cactus/CMakeLists.txt @@ -53,6 +53,8 @@ if(APPLE) set(CACTUS_CURL_LIBRARY_MACOS "${CACTUS_CURL_ROOT}/macos/libcurl.a") cactus_require_vendored_curl("${CACTUS_CURL_INCLUDE_DIR}" "${CACTUS_CURL_LIBRARY_MACOS}" "macOS") find_library(COREML_FRAMEWORK CoreML REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(MPS_FRAMEWORK MetalPerformanceShaders REQUIRED) find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED) find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) find_library(SECURITY_FRAMEWORK Security REQUIRED) @@ -80,10 +82,12 @@ list(REMOVE_ITEM KERNEL_SOURCES "${I8MM_SOURCE}") if(APPLE) set(NPU_SOURCES "npu/npu_ane.mm") - set_source_files_properties(npu/npu_ane.mm PROPERTIES COMPILE_FLAGS "-fobjc-arc") - message(STATUS "Apple platform detected - NPU/ANE acceleration enabled") + set(MPS_SOURCES "kernel/kernel_mps.mm") + set_source_files_properties(npu/npu_ane.mm kernel/kernel_mps.mm PROPERTIES COMPILE_FLAGS "-fobjc-arc") + message(STATUS "Apple platform detected - NPU/ANE and Metal/MPS acceleration enabled") else() set(NPU_SOURCES "npu/npu.cpp") + set(MPS_SOURCES "") endif() set(COMMON_SOURCES @@ -93,6 +97,7 @@ set(COMMON_SOURCES ${FFI_SOURCES} ${MODEL_SOURCES} ${NPU_SOURCES} + ${MPS_SOURCES} ${TELEMETRY_SOURCES} ) @@ -123,6 +128,8 @@ function(configure_cactus_target target_name) if(APPLE) target_link_libraries(${target_name} PUBLIC ${COREML_FRAMEWORK} + ${METAL_FRAMEWORK} + ${MPS_FRAMEWORK} ${FOUNDATION_FRAMEWORK} ${ACCELERATE_FRAMEWORK} ${SECURITY_FRAMEWORK} diff --git a/cactus/graph/graph_ops_nn.cpp b/cactus/graph/graph_ops_nn.cpp index f840c1c14..9a2697ebc 100644 --- a/cactus/graph/graph_ops_nn.cpp +++ b/cactus/graph/graph_ops_nn.cpp @@ -230,6 +230,17 @@ void compute_matmul_node(GraphNode& node, const std::vector= MPS_INT4_M_THRESHOLD && K >= MPS_INT4_K_THRESHOLD && N >= MPS_INT4_N_THRESHOLD && + cactus_mps_available()) { + cactus_matmul_int4_mps(lhs_buffer.data_as<__fp16>(), rhs, rhs_scales, + output, M, K, N, rhs_buffer.group_size); + return; + } +#endif + const int8_t* lhs_int8; const float* lhs_scales; diff --git a/cactus/kernel/kernel.h b/cactus/kernel/kernel.h index 2c0fdfc31..49a91907e 100644 --- a/cactus/kernel/kernel.h +++ b/cactus/kernel/kernel.h @@ -84,6 +84,20 @@ void cactus_matmul_integer(Precision precision, void cactus_matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c, size_t M, size_t K, size_t N); +#ifdef __APPLE__ +constexpr size_t MPS_F16_M_THRESHOLD = 128; +constexpr size_t MPS_F16_K_THRESHOLD = 2048; +constexpr size_t MPS_F16_N_THRESHOLD = 2048; +constexpr size_t MPS_INT4_M_THRESHOLD = 256; +constexpr size_t MPS_INT4_K_THRESHOLD = 1024; +constexpr size_t MPS_INT4_N_THRESHOLD = 1024; +bool cactus_mps_available(); +void cactus_matmul_f16_mps(const __fp16* A, const __fp16* B_T, __fp16* C, + size_t M, size_t K, size_t N); +void cactus_matmul_int4_mps(const __fp16* A, const int8_t* B_packed, const __fp16* B_scales, + __fp16* C, size_t M, size_t K, size_t N, size_t group_size); +#endif + void cactus_transpose_2d_f16(const __fp16* source, __fp16* destination, size_t num_rows, size_t num_cols, size_t start_row, size_t end_row); void cactus_transpose_f16(const __fp16* source, __fp16* destination, const size_t* shape, diff --git a/cactus/kernel/kernel_matmul.cpp b/cactus/kernel/kernel_matmul.cpp index 58b57587e..11d2122da 100644 --- a/cactus/kernel/kernel_matmul.cpp +++ b/cactus/kernel/kernel_matmul.cpp @@ -168,6 +168,10 @@ void cactus_matmul_f16( ) { #ifdef __APPLE__ + if (M >= MPS_F16_M_THRESHOLD && K >= MPS_F16_K_THRESHOLD && N >= MPS_F16_N_THRESHOLD && cactus_mps_available()) { + cactus_matmul_f16_mps(a, b_transposed, c, M, K, N); + return; + } if (K >= ACCELERATE_K_THRESHOLD && M >= ACCELERATE_M_THRESHOLD) { const size_t a_len = M * K; const size_t b_len = N * K; diff --git a/cactus/kernel/kernel_mps.mm b/cactus/kernel/kernel_mps.mm new file mode 100644 index 000000000..76b2ddb91 --- /dev/null +++ b/cactus/kernel/kernel_mps.mm @@ -0,0 +1,142 @@ +#ifdef __APPLE__ + +#import +#import +#include "kernel.h" +#include + +static NSString* const kCactusMSL = @R"( +#include +using namespace metal; + +kernel void cactus_dequant_int4( + device const uchar* B_packed [[buffer(0)]], + device const half* B_scales [[buffer(1)]], + device half* B_out [[buffer(2)]], + constant uint& K [[buffer(3)]], + constant uint& group_size [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) +{ + uint k = gid.x; + uint n = gid.y; + uint num_groups = K / group_size; + uint n_block = n >> 2; + uint c = n & 3; + uint g = k / group_size; + uint k_local = k - g * group_size; + uint k_super = k_local >> 3; + uint k_in_slab = k_local & 7; + uint byte_in_group = k_super * 16 + c * 4 + (k_in_slab & 3); + uint byte_offset = (n_block * K + g * group_size) * 2 + byte_in_group; + uchar b = B_packed[byte_offset]; + int nibble = (k_in_slab < 4) ? int(b & 0xF) : int(b >> 4); + if (nibble >= 8) nibble -= 16; + half scale = B_scales[(n_block * num_groups + g) * 4 + c]; + B_out[n * K + k] = half(nibble) * scale; +} +)"; + +static id g_device = nil; +static id g_queue = nil; +static id g_dequant_pso = nil; +static dispatch_once_t g_once; + +static void cactus_mps_init() { + dispatch_once(&g_once, ^{ + g_device = MTLCreateSystemDefaultDevice(); + if (!g_device) return; + g_queue = [g_device newCommandQueue]; + NSError* err = nil; + id lib = [g_device newLibraryWithSource:kCactusMSL options:nil error:&err]; + if (!lib) return; + id fn = [lib newFunctionWithName:@"cactus_dequant_int4"]; + if (!fn) return; + g_dequant_pso = [g_device newComputePipelineStateWithFunction:fn error:&err]; + }); +} + +bool cactus_mps_available() { + cactus_mps_init(); + return g_device != nil && g_queue != nil; +} + +void cactus_matmul_f16_mps(const __fp16* A, const __fp16* B_T, __fp16* C, + size_t M, size_t K, size_t N) { + cactus_mps_init(); + if (!g_device || !g_queue) return; + + @autoreleasepool { + const size_t fp16 = sizeof(__fp16); + id bufA = [g_device newBufferWithBytes:A length:M*K*fp16 options:MTLResourceStorageModeShared]; + id bufB = [g_device newBufferWithBytes:B_T length:N*K*fp16 options:MTLResourceStorageModeShared]; + id bufC = [g_device newBufferWithLength:M*N*fp16 options:MTLResourceStorageModeShared]; + + MPSMatrixDescriptor* dA = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:K rowBytes:K*fp16 dataType:MPSDataTypeFloat16]; + MPSMatrixDescriptor* dB = [MPSMatrixDescriptor matrixDescriptorWithRows:N columns:K rowBytes:K*fp16 dataType:MPSDataTypeFloat16]; + MPSMatrixDescriptor* dC = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:N rowBytes:N*fp16 dataType:MPSDataTypeFloat16]; + + MPSMatrix* mA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:dA]; + MPSMatrix* mB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:dB]; + MPSMatrix* mC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:dC]; + + MPSMatrixMultiplication* mm = [[MPSMatrixMultiplication alloc] initWithDevice:g_device + transposeLeft:NO transposeRight:YES resultRows:M resultColumns:N interiorColumns:K alpha:1.0 beta:0.0]; + + id cmd = [g_queue commandBuffer]; + [mm encodeToCommandBuffer:cmd leftMatrix:mA rightMatrix:mB resultMatrix:mC]; + [cmd commit]; + [cmd waitUntilCompleted]; + + memcpy(C, [bufC contents], M*N*fp16); + } +} + +void cactus_matmul_int4_mps(const __fp16* A, const int8_t* B_packed, const __fp16* B_scales, + __fp16* C, size_t M, size_t K, size_t N, size_t group_size) { + cactus_mps_init(); + if (!g_device || !g_queue || !g_dequant_pso) return; + if (N % 4 != 0 || K % group_size != 0) return; + + @autoreleasepool { + const size_t fp16 = sizeof(__fp16); + const size_t packed_bytes = (N / 4) * K * 2; + const size_t num_groups = K / group_size; + const size_t scales_bytes = (N / 4) * num_groups * 4 * fp16; + + id bufA = [g_device newBufferWithBytes:A length:M*K*fp16 options:MTLResourceStorageModeShared]; + id bufBp = [g_device newBufferWithBytes:B_packed length:packed_bytes options:MTLResourceStorageModeShared]; + id bufBs = [g_device newBufferWithBytes:B_scales length:scales_bytes options:MTLResourceStorageModeShared]; + id bufBd = [g_device newBufferWithLength:N*K*fp16 options:MTLResourceStorageModeShared]; + id bufC = [g_device newBufferWithLength:M*N*fp16 options:MTLResourceStorageModeShared]; + + id cmd = [g_queue commandBuffer]; + + id enc = [cmd computeCommandEncoder]; + [enc setComputePipelineState:g_dequant_pso]; + [enc setBuffer:bufBp offset:0 atIndex:0]; + [enc setBuffer:bufBs offset:0 atIndex:1]; + [enc setBuffer:bufBd offset:0 atIndex:2]; + uint32_t Ku = (uint32_t)K; + uint32_t Gu = (uint32_t)group_size; + [enc setBytes:&Ku length:sizeof(Ku) atIndex:3]; + [enc setBytes:&Gu length:sizeof(Gu) atIndex:4]; + [enc dispatchThreads:MTLSizeMake(K, N, 1) threadsPerThreadgroup:MTLSizeMake(32, 8, 1)]; + [enc endEncoding]; + + MPSMatrixDescriptor* dA = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:K rowBytes:K*fp16 dataType:MPSDataTypeFloat16]; + MPSMatrixDescriptor* dB = [MPSMatrixDescriptor matrixDescriptorWithRows:N columns:K rowBytes:K*fp16 dataType:MPSDataTypeFloat16]; + MPSMatrixDescriptor* dC = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:N rowBytes:N*fp16 dataType:MPSDataTypeFloat16]; + MPSMatrix* mA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:dA]; + MPSMatrix* mB = [[MPSMatrix alloc] initWithBuffer:bufBd descriptor:dB]; + MPSMatrix* mC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:dC]; + MPSMatrixMultiplication* mm = [[MPSMatrixMultiplication alloc] initWithDevice:g_device + transposeLeft:NO transposeRight:YES resultRows:M resultColumns:N interiorColumns:K alpha:1.0 beta:0.0]; + [mm encodeToCommandBuffer:cmd leftMatrix:mA rightMatrix:mB resultMatrix:mC]; + + [cmd commit]; + [cmd waitUntilCompleted]; + memcpy(C, [bufC contents], M*N*fp16); + } +} + +#endif diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 23b9ae586..3e37d0334 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -46,6 +46,8 @@ if(NOT APPLE) endif() else() find_library(COREML_FRAMEWORK CoreML REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(MPS_FRAMEWORK MetalPerformanceShaders REQUIRED) find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED) find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) find_library(SECURITY_FRAMEWORK Security REQUIRED) @@ -95,6 +97,8 @@ foreach(TEST_FILE ${TEST_SOURCES}) target_link_libraries(${TEST_NAME} PRIVATE ${ACCELERATE_FRAMEWORK} ${COREML_FRAMEWORK} + ${METAL_FRAMEWORK} + ${MPS_FRAMEWORK} ${FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK} ${SYSTEMCONFIGURATION_FRAMEWORK} diff --git a/tests/test_kernel.cpp b/tests/test_kernel.cpp index 8b05b3652..29231b71a 100644 --- a/tests/test_kernel.cpp +++ b/tests/test_kernel.cpp @@ -491,6 +491,134 @@ bool test_fast_tanh_f32x4_correctness() { return true; } +#ifdef __APPLE__ +bool test_mps_matmul_f16_correctness() { + if (!cactus_mps_available()) return true; + const size_t M = 32, K = 512, N = 256; + + std::mt19937 gen(7); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + std::vector<__fp16> A(M * K), BT(N * K), C_mps(M * N), C_ref(M * N); + for (size_t i = 0; i < M * K; ++i) { + A[i] = static_cast<__fp16>(dis(gen)); + } + for (size_t i = 0; i < N * K; ++i) { + BT[i] = static_cast<__fp16>(dis(gen)); + } + + cactus_matmul_f16_mps(A.data(), BT.data(), C_mps.data(), M, K, N); + + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float acc = 0.0f; + for (size_t k = 0; k < K; ++k) { + acc += static_cast(A[m * K + k]) * static_cast(BT[n * K + k]); + } + C_ref[m * N + n] = static_cast<__fp16>(acc); + } + } + + float max_err = 0.0f; + for (size_t i = 0; i < M * N; ++i) { + float err = std::abs(static_cast(C_mps[i]) - static_cast(C_ref[i])); + if (err > max_err) max_err = err; + } + + std::cout << " MPS FP16 matmul max abs error: " << max_err << std::endl; + return max_err < 0.5f; +} + +bool test_mps_matmul_int4_correctness() { + if (!cactus_mps_available()) return true; + const size_t M = 32, K = 512, N = 256, group_size = 32; + const size_t num_groups = K / group_size; + const size_t BS = 4; + + std::mt19937 gen(11); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + std::vector B_fp32(N * K); + for (size_t i = 0; i < N * K; ++i) { + B_fp32[i] = dis(gen); + } + + std::vector B_raw(N * K); + std::vector B_scales(N * num_groups); + for (size_t n = 0; n < N; ++n) { + for (size_t g = 0; g < num_groups; ++g) { + float max_abs = 0.0f; + for (size_t k = 0; k < group_size; ++k) { + float val = std::abs(B_fp32[n * K + g * group_size + k]); + if (val > max_abs) max_abs = val; + } + float scale = std::max(max_abs / 7.0f, 1e-10f); + B_scales[n * num_groups + g] = scale; + for (size_t k = 0; k < group_size; ++k) { + int32_t q = static_cast(std::round(B_fp32[n * K + g * group_size + k] / scale)); + B_raw[n * K + g * group_size + k] = static_cast(std::clamp(q, -8, 7)); + } + } + } + + std::vector B_inter(N * K); + for (size_t nb = 0; nb < N / BS; ++nb) { + for (size_t kb = 0; kb < K / BS; ++kb) { + for (size_t ni = 0; ni < BS; ++ni) { + for (size_t ki = 0; ki < BS; ++ki) { + B_inter[(nb * (K / BS) + kb) * BS * BS + ni * BS + ki] = + B_raw[(nb * BS + ni) * K + kb * BS + ki]; + } + } + } + } + + std::vector B_packed(N * K / 2); + for (size_t i = 0; i < N * K; i += 32) { + for (size_t j = 0; j < 16; ++j) { + uint8_t lo = static_cast(B_inter[i + j] & 0x0F); + uint8_t hi = static_cast((B_inter[i + 16 + j] & 0x0F) << 4); + B_packed[i / 2 + j] = lo | hi; + } + } + + std::vector<__fp16> B_scales_inter(N * num_groups); + for (size_t nb = 0; nb < N / BS; ++nb) { + for (size_t g = 0; g < num_groups; ++g) { + for (size_t ni = 0; ni < BS; ++ni) { + B_scales_inter[(nb * num_groups + g) * BS + ni] = + static_cast<__fp16>(B_scales[(nb * BS + ni) * num_groups + g]); + } + } + } + + std::vector<__fp16> A(M * K); + for (size_t i = 0; i < M * K; ++i) { + A[i] = static_cast<__fp16>(dis(gen)); + } + + std::vector<__fp16> C_mps(M * N); + cactus_matmul_int4_mps(A.data(), reinterpret_cast(B_packed.data()), + B_scales_inter.data(), C_mps.data(), M, K, N, group_size); + + float max_err = 0.0f; + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float acc = 0.0f; + for (size_t k = 0; k < K; ++k) { + acc += static_cast(A[m * K + k]) * static_cast(B_raw[n * K + k]) * + B_scales[n * num_groups + (k / group_size)]; + } + float err = std::abs(static_cast(C_mps[m * N + n]) - acc); + if (err > max_err) max_err = err; + } + } + + std::cout << " MPS INT4 matmul max abs error: " << max_err << std::endl; + return max_err < 0.1f; +} +#endif + int main() { TestUtils::TestRunner runner("Kernel Backend Tests"); @@ -507,6 +635,10 @@ int main() { runner.run_test("Kernel INT4 MatMul Correctness", test_int4_matmul_correctness()); runner.run_test("Kernel STFT Complex Correctness", test_stft_kernel_correctness()); runner.run_test("Kernel Fast Tanh Correctness", test_fast_tanh_f32x4_correctness()); +#ifdef __APPLE__ + runner.run_test("Kernel MPS FP16 MatMul Correctness", test_mps_matmul_f16_correctness()); + runner.run_test("Kernel MPS INT4 MatMul Correctness", test_mps_matmul_int4_correctness()); +#endif runner.print_summary(); return runner.all_passed() ? 0 : 1; From 451671698b00fcfbd7798a782d3145b50554beea Mon Sep 17 00:00:00 2001 From: justinl66 Date: Tue, 21 Apr 2026 22:39:44 -0700 Subject: [PATCH 2/4] change fp16 threshold --- cactus/kernel/kernel.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cactus/kernel/kernel.h b/cactus/kernel/kernel.h index 49a91907e..ecd9f6385 100644 --- a/cactus/kernel/kernel.h +++ b/cactus/kernel/kernel.h @@ -86,8 +86,8 @@ void cactus_matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c, #ifdef __APPLE__ constexpr size_t MPS_F16_M_THRESHOLD = 128; -constexpr size_t MPS_F16_K_THRESHOLD = 2048; -constexpr size_t MPS_F16_N_THRESHOLD = 2048; +constexpr size_t MPS_F16_K_THRESHOLD = 1024; +constexpr size_t MPS_F16_N_THRESHOLD = 1024; constexpr size_t MPS_INT4_M_THRESHOLD = 256; constexpr size_t MPS_INT4_K_THRESHOLD = 1024; constexpr size_t MPS_INT4_N_THRESHOLD = 1024; From ec0f8934cfd39241741ab2fedbe95c554f037551 Mon Sep 17 00:00:00 2001 From: justinl66 Date: Sun, 26 Apr 2026 00:06:07 -0700 Subject: [PATCH 3/4] add gpu gemv and async buffer runtime --- cactus/CMakeLists.txt | 2 + cactus/graph/graph.h | 17 +- cactus/graph/graph_core.cpp | 19 + cactus/graph/graph_execute.cpp | 5 + cactus/graph/graph_ops_nn.cpp | 46 ++- cactus/kernel/kernel.h | 19 +- cactus/kernel/kernel_attention.cpp | 11 + cactus/kernel/kernel_matmul.cpp | 3 +- cactus/kernel/kernel_mps.mm | 643 ++++++++++++++++++++++++++--- tests/test_kernel.cpp | 79 ++++ 10 files changed, 776 insertions(+), 68 deletions(-) diff --git a/cactus/CMakeLists.txt b/cactus/CMakeLists.txt index 8a1d4bce7..a19c2a593 100644 --- a/cactus/CMakeLists.txt +++ b/cactus/CMakeLists.txt @@ -55,6 +55,7 @@ if(APPLE) find_library(COREML_FRAMEWORK CoreML REQUIRED) find_library(METAL_FRAMEWORK Metal REQUIRED) find_library(MPS_FRAMEWORK MetalPerformanceShaders REQUIRED) + find_library(MPSGRAPH_FRAMEWORK MetalPerformanceShadersGraph REQUIRED) find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED) find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) find_library(SECURITY_FRAMEWORK Security REQUIRED) @@ -130,6 +131,7 @@ function(configure_cactus_target target_name) ${COREML_FRAMEWORK} ${METAL_FRAMEWORK} ${MPS_FRAMEWORK} + ${MPSGRAPH_FRAMEWORK} ${FOUNDATION_FRAMEWORK} ${ACCELERATE_FRAMEWORK} ${SECURITY_FRAMEWORK} diff --git a/cactus/graph/graph.h b/cactus/graph/graph.h index 9ad6fe389..00da94c15 100644 --- a/cactus/graph/graph.h +++ b/cactus/graph/graph.h @@ -242,7 +242,9 @@ struct BufferDesc { std::unique_ptr owned_scales; bool is_interleaved = false; - size_t original_N = 0; + size_t original_N = 0; + + bool pending_gpu_write = false; void* activation_scales_data = nullptr; std::unique_ptr owned_activation_scales; @@ -261,11 +263,20 @@ struct BufferDesc { void* get_data(); const void* get_data() const; + void flush_if_pending(); + void flush_if_pending() const; + + template + T* data_as() { flush_if_pending(); return static_cast(get_data()); } + + template + const T* data_as() const { flush_if_pending(); return static_cast(get_data()); } + template - T* data_as() { return static_cast(get_data()); } + T* data_ptr_raw() { return static_cast(get_data()); } template - const T* data_as() const { return static_cast(get_data()); } + const T* data_ptr_raw() const { return static_cast(get_data()); } const __fp16* scales_as_fp16() const { return reinterpret_cast(scales_data); diff --git a/cactus/graph/graph_core.cpp b/cactus/graph/graph_core.cpp index f418da72b..7695da0e1 100644 --- a/cactus/graph/graph_core.cpp +++ b/cactus/graph/graph_core.cpp @@ -1,4 +1,5 @@ #include "graph.h" +#include "../kernel/kernel.h" #include #include #include @@ -155,6 +156,24 @@ const void* BufferDesc::get_data() const { return data.get(); } +void BufferDesc::flush_if_pending() { +#ifdef __APPLE__ + if (pending_gpu_write) { + cactus_mps_synchronize(); + pending_gpu_write = false; + } +#endif +} + +void BufferDesc::flush_if_pending() const { +#ifdef __APPLE__ + if (pending_gpu_write) { + cactus_mps_synchronize(); + const_cast(this)->pending_gpu_write = false; + } +#endif +} + void BufferDesc::allocate() { if (!data && !external_data && !pooled_data) { data = std::make_unique(byte_size); diff --git a/cactus/graph/graph_execute.cpp b/cactus/graph/graph_execute.cpp index dd75389fb..827dc2519 100644 --- a/cactus/graph/graph_execute.cpp +++ b/cactus/graph/graph_execute.cpp @@ -1,4 +1,5 @@ #include "graph.h" +#include "../kernel/kernel.h" #include "../kernel/kernel_utils.h" #include #include @@ -700,6 +701,10 @@ void CactusGraph::execute(const std::string& profile_file) { } } +#ifdef __APPLE__ + cactus_mps_flush(); +#endif + if (enable_profiling) { auto total_end = std::chrono::high_resolution_clock::now(); auto total_duration = std::chrono::duration_cast(total_end - total_start); diff --git a/cactus/graph/graph_ops_nn.cpp b/cactus/graph/graph_ops_nn.cpp index 9a2697ebc..f5d23edbd 100644 --- a/cactus/graph/graph_ops_nn.cpp +++ b/cactus/graph/graph_ops_nn.cpp @@ -222,25 +222,40 @@ void compute_matmul_node(GraphNode& node, const std::vector 0) { - const int8_t* rhs = rhs_buffer.data_as(); - const __fp16* rhs_scales = rhs_buffer.scales_as_fp16(); - __fp16* output = node.output_buffer.data_as<__fp16>(); - if (!pretransposed_rhs) { throw std::runtime_error("Group-wise quantized matmul requires pretransposed weights"); } #ifdef __APPLE__ - if (rhs_buffer.precision == Precision::INT4 && + if (cactus_mps_enabled() && + rhs_buffer.precision == Precision::INT4 && lhs_buffer.precision == Precision::FP16 && - M >= MPS_INT4_M_THRESHOLD && K >= MPS_INT4_K_THRESHOLD && N >= MPS_INT4_N_THRESHOLD && cactus_mps_available()) { - cactus_matmul_int4_mps(lhs_buffer.data_as<__fp16>(), rhs, rhs_scales, - output, M, K, N, rhs_buffer.group_size); - return; + if (M == 1 && K >= MPS_GEMV_INT4_K_THRESHOLD && N >= MPS_GEMV_INT4_N_THRESHOLD && N % 4 == 0) { + cactus_gemv_int4_mps(lhs_buffer.data_ptr_raw<__fp16>(), + rhs_buffer.data_ptr_raw(), + rhs_buffer.scales_as_fp16(), + node.output_buffer.data_ptr_raw<__fp16>(), + K, N, rhs_buffer.group_size); + node.output_buffer.pending_gpu_write = true; + return; + } + if (M >= MPS_INT4_M_THRESHOLD && K >= MPS_INT4_K_THRESHOLD && N >= MPS_INT4_N_THRESHOLD) { + cactus_matmul_int4_mps(lhs_buffer.data_ptr_raw<__fp16>(), + rhs_buffer.data_ptr_raw(), + rhs_buffer.scales_as_fp16(), + node.output_buffer.data_ptr_raw<__fp16>(), + M, K, N, rhs_buffer.group_size); + node.output_buffer.pending_gpu_write = true; + return; + } } #endif + const int8_t* rhs = rhs_buffer.data_as(); + const __fp16* rhs_scales = rhs_buffer.scales_as_fp16(); + __fp16* output = node.output_buffer.data_as<__fp16>(); + const int8_t* lhs_int8; const float* lhs_scales; @@ -267,6 +282,19 @@ void compute_matmul_node(GraphNode& node, const std::vector(lhs_buffer.precision)) + ")"); } +#ifdef __APPLE__ + if (cactus_mps_enabled() && pretransposed_rhs && + M >= MPS_F16_M_THRESHOLD && K >= MPS_F16_K_THRESHOLD && N >= MPS_F16_N_THRESHOLD && + cactus_mps_available()) { + cactus_matmul_f16_mps(lhs_buffer.data_ptr_raw<__fp16>(), + rhs_buffer.data_ptr_raw<__fp16>(), + node.output_buffer.data_ptr_raw<__fp16>(), + M, K, N); + node.output_buffer.pending_gpu_write = true; + return; + } +#endif + const __fp16* lhs = lhs_buffer.data_as<__fp16>(); const __fp16* rhs = rhs_buffer.data_as<__fp16>(); __fp16* output = node.output_buffer.data_as<__fp16>(); diff --git a/cactus/kernel/kernel.h b/cactus/kernel/kernel.h index ecd9f6385..217bce175 100644 --- a/cactus/kernel/kernel.h +++ b/cactus/kernel/kernel.h @@ -90,12 +90,29 @@ constexpr size_t MPS_F16_K_THRESHOLD = 1024; constexpr size_t MPS_F16_N_THRESHOLD = 1024; constexpr size_t MPS_INT4_M_THRESHOLD = 256; constexpr size_t MPS_INT4_K_THRESHOLD = 1024; -constexpr size_t MPS_INT4_N_THRESHOLD = 1024; +constexpr size_t MPS_INT4_N_THRESHOLD = 256; +constexpr size_t MPS_GEMV_INT4_K_THRESHOLD = 1024; +constexpr size_t MPS_GEMV_INT4_N_THRESHOLD = 32768; +constexpr size_t MPS_ATTN_SEQ_THRESHOLD = 1024; bool cactus_mps_available(); +void cactus_mps_set_enabled(bool enabled); +bool cactus_mps_enabled(); +void cactus_mps_flush(); +void cactus_mps_synchronize(); void cactus_matmul_f16_mps(const __fp16* A, const __fp16* B_T, __fp16* C, size_t M, size_t K, size_t N); void cactus_matmul_int4_mps(const __fp16* A, const int8_t* B_packed, const __fp16* B_scales, __fp16* C, size_t M, size_t K, size_t N, size_t group_size); +void cactus_gemv_int4_mps(const __fp16* A, const int8_t* B_packed, const __fp16* B_scales, + __fp16* C, size_t K, size_t N, size_t group_size); +void cactus_attention_f16_mps(const __fp16* Q, const __fp16* K, const __fp16* V, __fp16* O, + size_t seq_len, size_t kv_seq_len, + size_t num_q_heads, size_t num_kv_heads, + size_t head_dim, float scale, size_t position_offset); +void cactus_attention_f16_mpsgraph(const __fp16* Q, const __fp16* K, const __fp16* V, __fp16* O, + size_t seq_len, size_t kv_seq_len, + size_t num_q_heads, size_t num_kv_heads, + size_t head_dim, float scale, size_t position_offset); #endif void cactus_transpose_2d_f16(const __fp16* source, __fp16* destination, diff --git a/cactus/kernel/kernel_attention.cpp b/cactus/kernel/kernel_attention.cpp index ad4dd13a4..0f53f10e3 100644 --- a/cactus/kernel/kernel_attention.cpp +++ b/cactus/kernel/kernel_attention.cpp @@ -394,6 +394,17 @@ void cactus_attention_f16( } if (mask == nullptr && head_dim % 8 == 0 && v_head_dim % 8 == 0 && logit_cap == 0.0f) { +#ifdef __APPLE__ + if (cactus_mps_enabled() && cactus_mps_available() && + is_causal && window_size == 0 && batch_size == 1 && + head_dim == v_head_dim && head_dim <= 256 && + seq_len >= MPS_ATTN_SEQ_THRESHOLD) { + cactus_attention_f16_mpsgraph(queries, keys, values, output, + seq_len, kv_seq_len, num_q_heads, num_kv_heads, + head_dim, scale, position_offset); + return; + } +#endif cactus_attention_f16_fast( queries, keys, values, output, batch_size, seq_len, kv_seq_len, diff --git a/cactus/kernel/kernel_matmul.cpp b/cactus/kernel/kernel_matmul.cpp index 11d2122da..ae16e4ff2 100644 --- a/cactus/kernel/kernel_matmul.cpp +++ b/cactus/kernel/kernel_matmul.cpp @@ -168,8 +168,9 @@ void cactus_matmul_f16( ) { #ifdef __APPLE__ - if (M >= MPS_F16_M_THRESHOLD && K >= MPS_F16_K_THRESHOLD && N >= MPS_F16_N_THRESHOLD && cactus_mps_available()) { + if (cactus_mps_enabled() && M >= MPS_F16_M_THRESHOLD && K >= MPS_F16_K_THRESHOLD && N >= MPS_F16_N_THRESHOLD && cactus_mps_available()) { cactus_matmul_f16_mps(a, b_transposed, c, M, K, N); + cactus_mps_synchronize(); return; } if (K >= ACCELERATE_K_THRESHOLD && M >= ACCELERATE_M_THRESHOLD) { diff --git a/cactus/kernel/kernel_mps.mm b/cactus/kernel/kernel_mps.mm index 76b2ddb91..3bb53a91b 100644 --- a/cactus/kernel/kernel_mps.mm +++ b/cactus/kernel/kernel_mps.mm @@ -2,13 +2,60 @@ #import #import +#import #include "kernel.h" #include +#include +#include static NSString* const kCactusMSL = @R"( #include using namespace metal; +kernel void cactus_gemv_int4( + device const half* A [[buffer(0)]], + device const uchar* B_packed [[buffer(1)]], + device const half* B_scales [[buffer(2)]], + device half* C [[buffer(3)]], + constant uint& K [[buffer(4)]], + constant uint& N [[buffer(5)]], + constant uint& group_size [[buffer(6)]], + uint tgpig [[threadgroup_position_in_grid]], + ushort tiisg [[thread_index_in_simdgroup]]) +{ + const uint n_block = tgpig; + const uint n_base = n_block << 2; + if (n_base >= N) return; + const uint num_groups = K / group_size; + + float partials[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + + for (uint k = tiisg; k < K; k += 32) { + float a_val = float(A[k]); + uint g = k / group_size; + uint k_local = k - g * group_size; + uint k_super = k_local >> 3; + uint k_in_slab = k_local & 7; + bool high_nibble = k_in_slab >= 4; + uint byte_offset_base = (n_block * K + g * group_size) * 2 + k_super * 16 + (k_in_slab & 3); + + #pragma unroll + for (uint c = 0; c < 4; ++c) { + uchar b = B_packed[byte_offset_base + c * 4]; + int nibble = high_nibble ? int(b >> 4) : int(b & 0xF); + if (nibble >= 8) nibble -= 16; + float scale = float(B_scales[(n_block * num_groups + g) * 4 + c]); + partials[c] += a_val * float(nibble) * scale; + } + } + + #pragma unroll + for (uint c = 0; c < 4; ++c) { + float total = simd_sum(partials[c]); + if (tiisg == 0 && n_base + c < N) C[n_base + c] = (half)total; + } +} + kernel void cactus_dequant_int4( device const uchar* B_packed [[buffer(0)]], device const half* B_scales [[buffer(1)]], @@ -17,49 +64,333 @@ kernel void cactus_dequant_int4( constant uint& group_size [[buffer(4)]], uint2 gid [[thread_position_in_grid]]) { - uint k = gid.x; - uint n = gid.y; - uint num_groups = K / group_size; - uint n_block = n >> 2; - uint c = n & 3; - uint g = k / group_size; - uint k_local = k - g * group_size; - uint k_super = k_local >> 3; - uint k_in_slab = k_local & 7; - uint byte_in_group = k_super * 16 + c * 4 + (k_in_slab & 3); - uint byte_offset = (n_block * K + g * group_size) * 2 + byte_in_group; - uchar b = B_packed[byte_offset]; - int nibble = (k_in_slab < 4) ? int(b & 0xF) : int(b >> 4); - if (nibble >= 8) nibble -= 16; - half scale = B_scales[(n_block * num_groups + g) * 4 + c]; - B_out[n * K + k] = half(nibble) * scale; + const uint k_slab = gid.x; + const uint n = gid.y; + const uint k_base = k_slab << 3; + + const uint num_groups = K / group_size; + const uint nb = n >> 2; + const uint c = n & 3u; + const uint g = k_base / group_size; + const uint kl = k_base - g * group_size; + const uint slab_in_grp = kl >> 3; + const uint byte_offset = (nb * K + g * group_size) * 2u + slab_in_grp * 16u + c * 4u; + + uchar4 packed = *(device const uchar4*)(B_packed + byte_offset); + half scl = B_scales[(nb * num_groups + g) * 4u + c]; + + half4 lo4, hi4; + #pragma unroll + for (uint i = 0; i < 4; ++i) { + int lo = int(packed[i] & 0xFu); + int hi = int(packed[i] >> 4); + if (lo >= 8) lo -= 16; + if (hi >= 8) hi -= 16; + lo4[i] = half(lo) * scl; + hi4[i] = half(hi) * scl; + } + + device half4* dst = (device half4*)(B_out + n * K + k_base); + dst[0] = lo4; + dst[1] = hi4; +} + + +constant uint A2_Q_BLOCK = 8; +constant uint A2_KV_BLOCK = 16; +constant uint A2_NSG = 2; + +kernel void cactus_flash_attn_f16_v2( + device const half* q_in [[buffer(0)]], + device const half* k_in [[buffer(1)]], + device const half* v_in [[buffer(2)]], + device half* o_out [[buffer(3)]], + constant uint& seq_len [[buffer(4)]], + constant uint& kv_seq_len [[buffer(5)]], + constant uint& nqh [[buffer(6)]], + constant uint& nkh [[buffer(7)]], + constant uint& head_dim [[buffer(8)]], + constant float& scale [[buffer(9)]], + constant uint& position_offset [[buffer(10)]], + threadgroup char* shmem [[threadgroup(0)]], + uint2 tgpig [[threadgroup_position_in_grid]], + ushort sgitg [[simdgroup_index_in_threadgroup]], + ushort tiisg [[thread_index_in_simdgroup]]) +{ + const uint q_block_idx = tgpig.x; + const uint q_head = tgpig.y; + const uint q_start = q_block_idx * A2_Q_BLOCK; + if (q_start >= seq_len) return; + + const uint kv_head = q_head * nkh / nqh; + const uint q_stride = nqh * head_dim; + const uint kv_stride = nkh * head_dim; + + const ushort tid_in_tg = sgitg * 32 + tiisg; + const ushort total_threads = A2_NSG * 32; + + threadgroup half* sQ = (threadgroup half*)shmem; + threadgroup half* sK = sQ + A2_Q_BLOCK * head_dim; + threadgroup half* sV = sK + A2_KV_BLOCK * head_dim; + threadgroup half* sP = sV + A2_KV_BLOCK * head_dim; + threadgroup float* sO = (threadgroup float*)(sP + A2_Q_BLOCK * A2_KV_BLOCK); + threadgroup float* sM = sO + A2_Q_BLOCK * head_dim; + threadgroup float* sL = sM + A2_Q_BLOCK; + threadgroup float* sScratch = sL + A2_Q_BLOCK; + + if (tid_in_tg < A2_Q_BLOCK) { + sM[tid_in_tg] = -INFINITY; + sL[tid_in_tg] = 0.0f; + } + for (uint i = tid_in_tg; i < A2_Q_BLOCK * head_dim; i += total_threads) { + sO[i] = 0.0f; + } + + for (uint j = 0; j < A2_Q_BLOCK; ++j) { + const uint q_pos = q_start + j; + const bool valid = q_pos < seq_len; + for (uint d = tid_in_tg; d < head_dim; d += total_threads) { + sQ[j * head_dim + d] = valid ? q_in[q_pos * q_stride + q_head * head_dim + d] : (half)0; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + const uint d_blocks = head_dim / 8; + const uint kv_end = min(kv_seq_len, position_offset + q_start + A2_Q_BLOCK); + + for (uint kv0 = 0; kv0 < kv_end; kv0 += A2_KV_BLOCK) { + for (uint i = 0; i < A2_KV_BLOCK; ++i) { + const uint kv_pos = kv0 + i; + const bool valid = kv_pos < kv_seq_len; + for (uint d = tid_in_tg; d < head_dim; d += total_threads) { + sK[i * head_dim + d] = valid ? k_in[kv_pos * kv_stride + kv_head * head_dim + d] : (half)0; + sV[i * head_dim + d] = valid ? v_in[kv_pos * kv_stride + kv_head * head_dim + d] : (half)0; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + const ushort kv_col_off = sgitg * 8; + simdgroup_float8x8 qk_acc = make_filled_simdgroup_matrix(0.0f); + for (uint b = 0; b < d_blocks; ++b) { + simdgroup_half8x8 q_tile, k_tile; + simdgroup_load(q_tile, sQ + b * 8, head_dim); + simdgroup_load(k_tile, sK + kv_col_off * head_dim + b * 8, head_dim, ulong2(0, 0), true); + simdgroup_multiply_accumulate(qk_acc, q_tile, k_tile, qk_acc); + } + simdgroup_store(qk_acc, sScratch + kv_col_off, A2_KV_BLOCK); + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg < 4) { + const ushort row = sgitg * 4 + tiisg; + const uint q_pos = q_start + row; + float scores[A2_KV_BLOCK]; + #pragma unroll + for (ushort c = 0; c < A2_KV_BLOCK; ++c) { + const uint kv_pos = kv0 + c; + float s = sScratch[row * A2_KV_BLOCK + c] * scale; + if (kv_pos > position_offset + q_pos || q_pos >= seq_len || kv_pos >= kv_seq_len) { + s = -INFINITY; + } + scores[c] = s; + } + float row_max = scores[0]; + for (ushort c = 1; c < A2_KV_BLOCK; ++c) row_max = max(row_max, scores[c]); + + const float old_m = sM[row]; + const float new_m = max(old_m, row_max); + const float alpha = (old_m == -INFINITY) ? 0.0f : exp(old_m - new_m); + + float new_l = alpha * sL[row]; + for (ushort c = 0; c < A2_KV_BLOCK; ++c) { + const float p = (new_m == -INFINITY) ? 0.0f : exp(scores[c] - new_m); + sP[row * A2_KV_BLOCK + c] = (half)p; + new_l += p; + } + sM[row] = new_m; + sL[row] = new_l; + for (uint d = 0; d < head_dim; ++d) { + sO[row * head_dim + d] *= alpha; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + const uint d_off = sgitg * (head_dim / 2); + const uint sg_d_blocks = head_dim / 16; + for (uint b = 0; b < sg_d_blocks; ++b) { + simdgroup_float8x8 o_tile; + simdgroup_load(o_tile, sO + d_off + b * 8, head_dim); + simdgroup_half8x8 p0, p1, v0, v1; + simdgroup_load(p0, sP, A2_KV_BLOCK); + simdgroup_load(p1, sP + 8, A2_KV_BLOCK); + simdgroup_load(v0, sV + d_off + b * 8, head_dim); + simdgroup_load(v1, sV + 8 * head_dim + d_off + b * 8, head_dim); + simdgroup_multiply_accumulate(o_tile, p0, v0, o_tile); + simdgroup_multiply_accumulate(o_tile, p1, v1, o_tile); + simdgroup_store(o_tile, sO + d_off + b * 8, head_dim); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (uint j = 0; j < A2_Q_BLOCK; ++j) { + const uint q_pos = q_start + j; + if (q_pos >= seq_len) continue; + const float l = sL[j]; + const float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f; + for (uint d = tid_in_tg; d < head_dim; d += total_threads) { + o_out[q_pos * q_stride + q_head * head_dim + d] = (half)(sO[j * head_dim + d] * inv_l); + } + } } + )"; static id g_device = nil; static id g_queue = nil; static id g_dequant_pso = nil; +static id g_gemv_pso = nil; +static id g_attn_v2_pso = nil; +static id g_pending_cmd = nil; +static id g_last_committed = nil; +static NSMutableDictionary>* g_buffer_views = nil; +static id g_dequant_scratch = nil; +static size_t g_dequant_scratch_size = 0; +static NSMutableDictionary* g_mm_cache = nil; +static NSMutableDictionary* g_desc_cache = nil; +static NSMutableDictionary* g_dequant_mat_cache = nil; +static size_t g_page_size = 0; static dispatch_once_t g_once; +static id cactus_get_dequant_scratch(size_t bytes) { + if (g_dequant_scratch && g_dequant_scratch_size >= bytes) return g_dequant_scratch; + size_t want = bytes; + if (g_dequant_scratch_size > 0) { + size_t doubled = g_dequant_scratch_size * 2; + if (doubled > want) want = doubled; + } + g_dequant_scratch = [g_device newBufferWithLength:want options:MTLResourceStorageModeShared]; + g_dequant_scratch_size = want; + [g_dequant_mat_cache removeAllObjects]; + return g_dequant_scratch; +} + +static MPSMatrix* cactus_get_dequant_mat(uint32_t N, uint32_t K) { + uint64_t key = ((uint64_t)N << 32) | (uint64_t)K; + NSNumber* k = @(key); + MPSMatrix* m = g_dequant_mat_cache[k]; + if (!m) { + MPSMatrixDescriptor* d = [MPSMatrixDescriptor matrixDescriptorWithRows:N + columns:K + rowBytes:K * sizeof(__fp16) + dataType:MPSDataTypeFloat16]; + m = [[MPSMatrix alloc] initWithBuffer:g_dequant_scratch offset:0 descriptor:d]; + if (m) g_dequant_mat_cache[k] = m; + } + return m; +} + +static MPSMatrixDescriptor* cactus_get_desc(uint32_t rows, uint32_t cols) { + uint64_t key = ((uint64_t)rows << 32) | (uint64_t)cols; + NSNumber* k = @(key); + MPSMatrixDescriptor* d = g_desc_cache[k]; + if (!d) { + d = [MPSMatrixDescriptor matrixDescriptorWithRows:rows + columns:cols + rowBytes:cols * sizeof(__fp16) + dataType:MPSDataTypeFloat16]; + if (d) g_desc_cache[k] = d; + } + return d; +} + +static MPSMatrixMultiplication* cactus_get_mm(uint32_t M, uint32_t K, uint32_t N) { + uint64_t key = ((uint64_t)M << 42) | ((uint64_t)K << 21) | (uint64_t)N; + NSNumber* k = @(key); + MPSMatrixMultiplication* mm = g_mm_cache[k]; + if (!mm) { + mm = [[MPSMatrixMultiplication alloc] initWithDevice:g_device + transposeLeft:NO transposeRight:YES resultRows:M resultColumns:N interiorColumns:K alpha:1.0 beta:0.0]; + if (mm) g_mm_cache[k] = mm; + } + return mm; +} + static void cactus_mps_init() { dispatch_once(&g_once, ^{ g_device = MTLCreateSystemDefaultDevice(); if (!g_device) return; g_queue = [g_device newCommandQueue]; + g_buffer_views = [NSMutableDictionary new]; + g_mm_cache = [NSMutableDictionary new]; + g_desc_cache = [NSMutableDictionary new]; + g_dequant_mat_cache = [NSMutableDictionary new]; + g_page_size = sysconf(_SC_PAGESIZE); NSError* err = nil; id lib = [g_device newLibraryWithSource:kCactusMSL options:nil error:&err]; if (!lib) return; - id fn = [lib newFunctionWithName:@"cactus_dequant_int4"]; - if (!fn) return; - g_dequant_pso = [g_device newComputePipelineStateWithFunction:fn error:&err]; + id dq = [lib newFunctionWithName:@"cactus_dequant_int4"]; + if (dq) g_dequant_pso = [g_device newComputePipelineStateWithFunction:dq error:&err]; + id gv = [lib newFunctionWithName:@"cactus_gemv_int4"]; + if (gv) g_gemv_pso = [g_device newComputePipelineStateWithFunction:gv error:&err]; + id at2 = [lib newFunctionWithName:@"cactus_flash_attn_f16_v2"]; + if (at2) g_attn_v2_pso = [g_device newComputePipelineStateWithFunction:at2 error:&err]; }); } +static id cactus_buffer_view(const void* ptr, size_t len, size_t* out_offset) { + uintptr_t addr = (uintptr_t)ptr; + uintptr_t aligned_addr = addr & ~(g_page_size - 1); + *out_offset = addr - aligned_addr; + size_t aligned_len = (*out_offset + len + g_page_size - 1) & ~(g_page_size - 1); + NSValue* key = [NSValue valueWithPointer:(void*)aligned_addr]; + @synchronized(g_buffer_views) { + id buf = g_buffer_views[key]; + if (buf && [buf length] >= aligned_len) return buf; + buf = [g_device newBufferWithBytesNoCopy:(void*)aligned_addr + length:aligned_len + options:MTLResourceStorageModeShared + deallocator:nil]; + if (buf) g_buffer_views[key] = buf; + return buf; + } +} + +static bool g_mps_enabled = true; + bool cactus_mps_available() { cactus_mps_init(); return g_device != nil && g_queue != nil; } +void cactus_mps_set_enabled(bool enabled) { g_mps_enabled = enabled; } +bool cactus_mps_enabled() { return g_mps_enabled; } + +static id cactus_mps_active_cmd() { + if (!g_pending_cmd) { + g_pending_cmd = [g_queue commandBuffer]; + } + return g_pending_cmd; +} + +void cactus_mps_flush() { + if (g_pending_cmd) { + [g_pending_cmd commit]; + g_last_committed = g_pending_cmd; + g_pending_cmd = nil; + } +} + +void cactus_mps_synchronize() { + if (g_pending_cmd) { + [g_pending_cmd commit]; + g_last_committed = g_pending_cmd; + g_pending_cmd = nil; + } + if (g_last_committed) { + [g_last_committed waitUntilCompleted]; + g_last_committed = nil; + } +} + void cactus_matmul_f16_mps(const __fp16* A, const __fp16* B_T, __fp16* C, size_t M, size_t K, size_t N) { cactus_mps_init(); @@ -67,27 +398,57 @@ void cactus_matmul_f16_mps(const __fp16* A, const __fp16* B_T, __fp16* C, @autoreleasepool { const size_t fp16 = sizeof(__fp16); - id bufA = [g_device newBufferWithBytes:A length:M*K*fp16 options:MTLResourceStorageModeShared]; - id bufB = [g_device newBufferWithBytes:B_T length:N*K*fp16 options:MTLResourceStorageModeShared]; - id bufC = [g_device newBufferWithLength:M*N*fp16 options:MTLResourceStorageModeShared]; + size_t offA, offB, offC; + id bufA = cactus_buffer_view(A, M*K*fp16, &offA); + id bufB = cactus_buffer_view(B_T, N*K*fp16, &offB); + id bufC = cactus_buffer_view(C, M*N*fp16, &offC); - MPSMatrixDescriptor* dA = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:K rowBytes:K*fp16 dataType:MPSDataTypeFloat16]; - MPSMatrixDescriptor* dB = [MPSMatrixDescriptor matrixDescriptorWithRows:N columns:K rowBytes:K*fp16 dataType:MPSDataTypeFloat16]; - MPSMatrixDescriptor* dC = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:N rowBytes:N*fp16 dataType:MPSDataTypeFloat16]; + MPSMatrixDescriptor* dA = cactus_get_desc((uint32_t)M, (uint32_t)K); + MPSMatrixDescriptor* dB = cactus_get_desc((uint32_t)N, (uint32_t)K); + MPSMatrixDescriptor* dC = cactus_get_desc((uint32_t)M, (uint32_t)N); - MPSMatrix* mA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:dA]; - MPSMatrix* mB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:dB]; - MPSMatrix* mC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:dC]; + MPSMatrix* mA = [[MPSMatrix alloc] initWithBuffer:bufA offset:offA descriptor:dA]; + MPSMatrix* mB = [[MPSMatrix alloc] initWithBuffer:bufB offset:offB descriptor:dB]; + MPSMatrix* mC = [[MPSMatrix alloc] initWithBuffer:bufC offset:offC descriptor:dC]; - MPSMatrixMultiplication* mm = [[MPSMatrixMultiplication alloc] initWithDevice:g_device - transposeLeft:NO transposeRight:YES resultRows:M resultColumns:N interiorColumns:K alpha:1.0 beta:0.0]; + MPSMatrixMultiplication* mm = cactus_get_mm((uint32_t)M, (uint32_t)K, (uint32_t)N); - id cmd = [g_queue commandBuffer]; + id cmd = cactus_mps_active_cmd(); [mm encodeToCommandBuffer:cmd leftMatrix:mA rightMatrix:mB resultMatrix:mC]; - [cmd commit]; - [cmd waitUntilCompleted]; + } +} + +void cactus_gemv_int4_mps(const __fp16* A, const int8_t* B_packed, const __fp16* B_scales, + __fp16* C, size_t K, size_t N, size_t group_size) { + cactus_mps_init(); + if (!g_device || !g_queue || !g_gemv_pso) return; + if (N % 4 != 0 || K % group_size != 0) return; + + @autoreleasepool { + const size_t fp16 = sizeof(__fp16); + const size_t packed_bytes = (N / 4) * K * 2; + const size_t num_groups = K / group_size; + const size_t scales_bytes = (N / 4) * num_groups * 4 * fp16; + + size_t offA, offBp, offBs, offC; + id bufA = cactus_buffer_view(A, K*fp16, &offA); + id bufBp = cactus_buffer_view(B_packed, packed_bytes, &offBp); + id bufBs = cactus_buffer_view(B_scales, scales_bytes, &offBs); + id bufC = cactus_buffer_view(C, N*fp16, &offC); - memcpy(C, [bufC contents], M*N*fp16); + id cmd = cactus_mps_active_cmd(); + id enc = [cmd computeCommandEncoder]; + [enc setComputePipelineState:g_gemv_pso]; + [enc setBuffer:bufA offset:offA atIndex:0]; + [enc setBuffer:bufBp offset:offBp atIndex:1]; + [enc setBuffer:bufBs offset:offBs atIndex:2]; + [enc setBuffer:bufC offset:offC atIndex:3]; + uint32_t Ku = (uint32_t)K, Nu = (uint32_t)N, Gu = (uint32_t)group_size; + [enc setBytes:&Ku length:sizeof(Ku) atIndex:4]; + [enc setBytes:&Nu length:sizeof(Nu) atIndex:5]; + [enc setBytes:&Gu length:sizeof(Gu) atIndex:6]; + [enc dispatchThreadgroups:MTLSizeMake(N / 4, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [enc endEncoding]; } } @@ -103,39 +464,213 @@ void cactus_matmul_int4_mps(const __fp16* A, const int8_t* B_packed, const __fp1 const size_t num_groups = K / group_size; const size_t scales_bytes = (N / 4) * num_groups * 4 * fp16; - id bufA = [g_device newBufferWithBytes:A length:M*K*fp16 options:MTLResourceStorageModeShared]; - id bufBp = [g_device newBufferWithBytes:B_packed length:packed_bytes options:MTLResourceStorageModeShared]; - id bufBs = [g_device newBufferWithBytes:B_scales length:scales_bytes options:MTLResourceStorageModeShared]; - id bufBd = [g_device newBufferWithLength:N*K*fp16 options:MTLResourceStorageModeShared]; - id bufC = [g_device newBufferWithLength:M*N*fp16 options:MTLResourceStorageModeShared]; + size_t offA, offBp, offBs, offC; + id bufA = cactus_buffer_view(A, M*K*fp16, &offA); + id bufBp = cactus_buffer_view(B_packed, packed_bytes, &offBp); + id bufBs = cactus_buffer_view(B_scales, scales_bytes, &offBs); + id bufBd = cactus_get_dequant_scratch(N*K*fp16); + id bufC = cactus_buffer_view(C, M*N*fp16, &offC); - id cmd = [g_queue commandBuffer]; + id cmd = cactus_mps_active_cmd(); id enc = [cmd computeCommandEncoder]; [enc setComputePipelineState:g_dequant_pso]; - [enc setBuffer:bufBp offset:0 atIndex:0]; - [enc setBuffer:bufBs offset:0 atIndex:1]; + [enc setBuffer:bufBp offset:offBp atIndex:0]; + [enc setBuffer:bufBs offset:offBs atIndex:1]; [enc setBuffer:bufBd offset:0 atIndex:2]; uint32_t Ku = (uint32_t)K; uint32_t Gu = (uint32_t)group_size; [enc setBytes:&Ku length:sizeof(Ku) atIndex:3]; [enc setBytes:&Gu length:sizeof(Gu) atIndex:4]; - [enc dispatchThreads:MTLSizeMake(K, N, 1) threadsPerThreadgroup:MTLSizeMake(32, 8, 1)]; + [enc dispatchThreads:MTLSizeMake(K / 8, N, 1) threadsPerThreadgroup:MTLSizeMake(32, 8, 1)]; [enc endEncoding]; - MPSMatrixDescriptor* dA = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:K rowBytes:K*fp16 dataType:MPSDataTypeFloat16]; - MPSMatrixDescriptor* dB = [MPSMatrixDescriptor matrixDescriptorWithRows:N columns:K rowBytes:K*fp16 dataType:MPSDataTypeFloat16]; - MPSMatrixDescriptor* dC = [MPSMatrixDescriptor matrixDescriptorWithRows:M columns:N rowBytes:N*fp16 dataType:MPSDataTypeFloat16]; - MPSMatrix* mA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:dA]; - MPSMatrix* mB = [[MPSMatrix alloc] initWithBuffer:bufBd descriptor:dB]; - MPSMatrix* mC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:dC]; - MPSMatrixMultiplication* mm = [[MPSMatrixMultiplication alloc] initWithDevice:g_device - transposeLeft:NO transposeRight:YES resultRows:M resultColumns:N interiorColumns:K alpha:1.0 beta:0.0]; + MPSMatrixDescriptor* dA = cactus_get_desc((uint32_t)M, (uint32_t)K); + MPSMatrixDescriptor* dC = cactus_get_desc((uint32_t)M, (uint32_t)N); + MPSMatrix* mA = [[MPSMatrix alloc] initWithBuffer:bufA offset:offA descriptor:dA]; + MPSMatrix* mB = cactus_get_dequant_mat((uint32_t)N, (uint32_t)K); + MPSMatrix* mC = [[MPSMatrix alloc] initWithBuffer:bufC offset:offC descriptor:dC]; + MPSMatrixMultiplication* mm = cactus_get_mm((uint32_t)M, (uint32_t)K, (uint32_t)N); [mm encodeToCommandBuffer:cmd leftMatrix:mA rightMatrix:mB resultMatrix:mC]; + } +} + +void cactus_attention_f16_mps(const __fp16* Q, const __fp16* K, const __fp16* V, __fp16* O, + size_t seq_len, size_t kv_seq_len, + size_t num_q_heads, size_t num_kv_heads, + size_t head_dim, float scale, size_t position_offset) { + cactus_mps_init(); + if (!g_device || !g_queue || !g_attn_v2_pso) return; + if (head_dim > 256 || head_dim % 16 != 0) return; + + @autoreleasepool { + const size_t fp16 = sizeof(__fp16); + const size_t q_bytes = seq_len * num_q_heads * head_dim * fp16; + const size_t kv_bytes = kv_seq_len * num_kv_heads * head_dim * fp16; + const size_t o_bytes = seq_len * num_q_heads * head_dim * fp16; + + size_t offQ, offK, offV, offO; + id bufQ = cactus_buffer_view(Q, q_bytes, &offQ); + id bufK = cactus_buffer_view(K, kv_bytes, &offK); + id bufV = cactus_buffer_view(V, kv_bytes, &offV); + id bufO = cactus_buffer_view(O, o_bytes, &offO); + + id cmd = cactus_mps_active_cmd(); + id enc = [cmd computeCommandEncoder]; + [enc setComputePipelineState:g_attn_v2_pso]; + [enc setBuffer:bufQ offset:offQ atIndex:0]; + [enc setBuffer:bufK offset:offK atIndex:1]; + [enc setBuffer:bufV offset:offV atIndex:2]; + [enc setBuffer:bufO offset:offO atIndex:3]; + uint32_t SL = (uint32_t)seq_len, KL = (uint32_t)kv_seq_len; + uint32_t QH = (uint32_t)num_q_heads, KH = (uint32_t)num_kv_heads; + uint32_t HD = (uint32_t)head_dim; + uint32_t PO = (uint32_t)position_offset; + [enc setBytes:&SL length:sizeof(SL) atIndex:4]; + [enc setBytes:&KL length:sizeof(KL) atIndex:5]; + [enc setBytes:&QH length:sizeof(QH) atIndex:6]; + [enc setBytes:&KH length:sizeof(KH) atIndex:7]; + [enc setBytes:&HD length:sizeof(HD) atIndex:8]; + [enc setBytes:&scale length:sizeof(scale) atIndex:9]; + [enc setBytes:&PO length:sizeof(PO) atIndex:10]; + const size_t Q_BLOCK_V2 = 8, KV_BLOCK_V2 = 16; + const size_t shmem_bytes = + (Q_BLOCK_V2 + KV_BLOCK_V2 + KV_BLOCK_V2) * head_dim * sizeof(__fp16) + + Q_BLOCK_V2 * KV_BLOCK_V2 * sizeof(__fp16) + + Q_BLOCK_V2 * head_dim * sizeof(float) + + (Q_BLOCK_V2 + Q_BLOCK_V2 + Q_BLOCK_V2 * KV_BLOCK_V2) * sizeof(float); + [enc setThreadgroupMemoryLength:shmem_bytes atIndex:0]; + const size_t q_blocks = (seq_len + Q_BLOCK_V2 - 1) / Q_BLOCK_V2; + [enc dispatchThreadgroups:MTLSizeMake(q_blocks, num_q_heads, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + [enc endEncoding]; + } +} + +API_AVAILABLE(macos(14.0), ios(17.0)) +@interface CactusSDPAGraph : NSObject +@property (nonatomic, strong) MPSGraph* graph; +@property (nonatomic, strong) MPSGraphTensor* qPh; +@property (nonatomic, strong) MPSGraphTensor* kPh; +@property (nonatomic, strong) MPSGraphTensor* vPh; +@property (nonatomic, strong) MPSGraphTensor* outT; +@end +@implementation CactusSDPAGraph +@end + +static NSMutableDictionary* g_sdpa_graph_cache = nil; + +API_AVAILABLE(macos(14.0), ios(17.0)) +static CactusSDPAGraph* cactus_get_sdpa_graph(size_t seq_len, size_t kv_seq_len, + size_t num_q_heads, size_t num_kv_heads, + size_t head_dim, float scale, size_t position_offset) { + if (!g_sdpa_graph_cache) g_sdpa_graph_cache = [NSMutableDictionary new]; + NSString* key = [NSString stringWithFormat:@"%zu_%zu_%zu_%zu_%zu_%zu_%a", + seq_len, kv_seq_len, num_q_heads, num_kv_heads, + head_dim, position_offset, (double)scale]; + CactusSDPAGraph* cached = g_sdpa_graph_cache[key]; + if (cached) return cached; + + MPSGraph* graph = [[MPSGraph alloc] init]; + + MPSGraphTensor* qPh = [graph placeholderWithShape:@[@(seq_len), @(num_q_heads), @(head_dim)] + dataType:MPSDataTypeFloat16 name:nil]; + MPSGraphTensor* kPh = [graph placeholderWithShape:@[@(kv_seq_len), @(num_kv_heads), @(head_dim)] + dataType:MPSDataTypeFloat16 name:nil]; + MPSGraphTensor* vPh = [graph placeholderWithShape:@[@(kv_seq_len), @(num_kv_heads), @(head_dim)] + dataType:MPSDataTypeFloat16 name:nil]; + + MPSGraphTensor* qT = [graph transposeTensor:qPh permutation:@[@1, @0, @2] name:nil]; + MPSGraphTensor* kT = [graph transposeTensor:kPh permutation:@[@1, @0, @2] name:nil]; + MPSGraphTensor* vT = [graph transposeTensor:vPh permutation:@[@1, @0, @2] name:nil]; + + if (num_q_heads != num_kv_heads) { + const NSUInteger groups = num_q_heads / num_kv_heads; + kT = [graph reshapeTensor:kT withShape:@[@(num_kv_heads), @1, @(kv_seq_len), @(head_dim)] name:nil]; + kT = [graph broadcastTensor:kT toShape:@[@(num_kv_heads), @(groups), @(kv_seq_len), @(head_dim)] name:nil]; + kT = [graph reshapeTensor:kT withShape:@[@(num_q_heads), @(kv_seq_len), @(head_dim)] name:nil]; + vT = [graph reshapeTensor:vT withShape:@[@(num_kv_heads), @1, @(kv_seq_len), @(head_dim)] name:nil]; + vT = [graph broadcastTensor:vT toShape:@[@(num_kv_heads), @(groups), @(kv_seq_len), @(head_dim)] name:nil]; + vT = [graph reshapeTensor:vT withShape:@[@(num_q_heads), @(kv_seq_len), @(head_dim)] name:nil]; + } + + std::vector<__fp16> mask_data(seq_len * kv_seq_len); + const __fp16 neg_inf = (__fp16)(-65504.0f); + for (size_t i = 0; i < seq_len; ++i) { + for (size_t j = 0; j < kv_seq_len; ++j) { + bool allowed = (j <= position_offset + i); + mask_data[i * kv_seq_len + j] = allowed ? (__fp16)0.0f : neg_inf; + } + } + NSData* maskNSData = [NSData dataWithBytes:mask_data.data() length:mask_data.size() * sizeof(__fp16)]; + MPSGraphTensor* maskTensor = [graph constantWithData:maskNSData + shape:@[@(seq_len), @(kv_seq_len)] + dataType:MPSDataTypeFloat16]; + + MPSGraphTensor* outT = [graph scaledDotProductAttentionWithQueryTensor:qT + keyTensor:kT + valueTensor:vT + maskTensor:maskTensor + scale:scale + name:nil]; + outT = [graph transposeTensor:outT permutation:@[@1, @0, @2] name:nil]; + + CactusSDPAGraph* entry = [CactusSDPAGraph new]; + entry.graph = graph; + entry.qPh = qPh; + entry.kPh = kPh; + entry.vPh = vPh; + entry.outT = outT; + g_sdpa_graph_cache[key] = entry; + return entry; +} + +void cactus_attention_f16_mpsgraph(const __fp16* Q, const __fp16* K, const __fp16* V, __fp16* O, + size_t seq_len, size_t kv_seq_len, + size_t num_q_heads, size_t num_kv_heads, + size_t head_dim, float scale, size_t position_offset) { + cactus_mps_init(); + if (!g_device || !g_queue) return; + + @autoreleasepool { + const size_t fp16 = sizeof(__fp16); + const size_t q_bytes = seq_len * num_q_heads * head_dim * fp16; + const size_t kv_bytes = kv_seq_len * num_kv_heads * head_dim * fp16; + const size_t o_bytes = seq_len * num_q_heads * head_dim * fp16; + + cactus_mps_synchronize(); + + id bufQ = [g_device newBufferWithBytes:Q length:q_bytes options:MTLResourceStorageModeShared]; + id bufK = [g_device newBufferWithBytes:K length:kv_bytes options:MTLResourceStorageModeShared]; + id bufV = [g_device newBufferWithBytes:V length:kv_bytes options:MTLResourceStorageModeShared]; + id bufO = [g_device newBufferWithLength:o_bytes options:MTLResourceStorageModeShared]; + + CactusSDPAGraph* entry = cactus_get_sdpa_graph(seq_len, kv_seq_len, + num_q_heads, num_kv_heads, + head_dim, scale, position_offset); + + MPSGraphTensorData* qData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bufQ + shape:@[@(seq_len), @(num_q_heads), @(head_dim)] + dataType:MPSDataTypeFloat16]; + MPSGraphTensorData* kData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bufK + shape:@[@(kv_seq_len), @(num_kv_heads), @(head_dim)] + dataType:MPSDataTypeFloat16]; + MPSGraphTensorData* vData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bufV + shape:@[@(kv_seq_len), @(num_kv_heads), @(head_dim)] + dataType:MPSDataTypeFloat16]; + MPSGraphTensorData* oData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bufO + shape:@[@(seq_len), @(num_q_heads), @(head_dim)] + dataType:MPSDataTypeFloat16]; + + NSDictionary* feeds = @{entry.qPh: qData, entry.kPh: kData, entry.vPh: vData}; + NSDictionary* results = @{entry.outT: oData}; + + [entry.graph runWithMTLCommandQueue:g_queue + feeds:feeds + targetOperations:nil + resultsDictionary:results]; - [cmd commit]; - [cmd waitUntilCompleted]; - memcpy(C, [bufC contents], M*N*fp16); + memcpy(O, [bufO contents], o_bytes); } } diff --git a/tests/test_kernel.cpp b/tests/test_kernel.cpp index 29231b71a..6f1d679f9 100644 --- a/tests/test_kernel.cpp +++ b/tests/test_kernel.cpp @@ -508,6 +508,7 @@ bool test_mps_matmul_f16_correctness() { } cactus_matmul_f16_mps(A.data(), BT.data(), C_mps.data(), M, K, N); + cactus_mps_synchronize(); for (size_t m = 0; m < M; ++m) { for (size_t n = 0; n < N; ++n) { @@ -600,6 +601,7 @@ bool test_mps_matmul_int4_correctness() { std::vector<__fp16> C_mps(M * N); cactus_matmul_int4_mps(A.data(), reinterpret_cast(B_packed.data()), B_scales_inter.data(), C_mps.data(), M, K, N, group_size); + cactus_mps_synchronize(); float max_err = 0.0f; for (size_t m = 0; m < M; ++m) { @@ -617,6 +619,81 @@ bool test_mps_matmul_int4_correctness() { std::cout << " MPS INT4 matmul max abs error: " << max_err << std::endl; return max_err < 0.1f; } + +bool test_mps_attention_f16_correctness() { + if (!cactus_mps_available()) return true; + const size_t seq_len = 64, kv_seq_len = 64; + const size_t num_q_heads = 4, num_kv_heads = 2, head_dim = 64; + const float scale = 1.0f / sqrtf(static_cast(head_dim)); + + std::mt19937 gen(13); + std::uniform_real_distribution dis(-0.5f, 0.5f); + + std::vector<__fp16> Q(seq_len * num_q_heads * head_dim); + std::vector<__fp16> K(kv_seq_len * num_kv_heads * head_dim); + std::vector<__fp16> V(kv_seq_len * num_kv_heads * head_dim); + for (size_t i = 0; i < Q.size(); ++i) Q[i] = static_cast<__fp16>(dis(gen)); + for (size_t i = 0; i < K.size(); ++i) K[i] = static_cast<__fp16>(dis(gen)); + for (size_t i = 0; i < V.size(); ++i) V[i] = static_cast<__fp16>(dis(gen)); + + std::vector<__fp16> O_mps(seq_len * num_q_heads * head_dim); + std::vector<__fp16> O_ref(seq_len * num_q_heads * head_dim); + + cactus_attention_f16_mps(Q.data(), K.data(), V.data(), O_mps.data(), + seq_len, kv_seq_len, num_q_heads, num_kv_heads, + head_dim, scale, 0); + cactus_mps_synchronize(); + + cactus_attention_f16(Q.data(), K.data(), V.data(), O_ref.data(), + 1, seq_len, kv_seq_len, num_q_heads, num_kv_heads, + head_dim, scale, nullptr, 0, 0, true, false, false, head_dim, 0.0f); + + float max_err = 0.0f; + for (size_t i = 0; i < O_mps.size(); ++i) { + float err = std::abs(static_cast(O_mps[i]) - static_cast(O_ref[i])); + if (err > max_err) max_err = err; + } + + std::cout << " MPS attention F16 max abs error: " << max_err << std::endl; + return max_err < 0.05f; +} + +bool test_mpsgraph_attention_f16_correctness() { + if (!cactus_mps_available()) return true; + const size_t seq_len = 64, kv_seq_len = 64; + const size_t num_q_heads = 4, num_kv_heads = 2, head_dim = 64; + const float scale = 1.0f / sqrtf(static_cast(head_dim)); + + std::mt19937 gen(17); + std::uniform_real_distribution dis(-0.5f, 0.5f); + + std::vector<__fp16> Q(seq_len * num_q_heads * head_dim); + std::vector<__fp16> K(kv_seq_len * num_kv_heads * head_dim); + std::vector<__fp16> V(kv_seq_len * num_kv_heads * head_dim); + for (size_t i = 0; i < Q.size(); ++i) Q[i] = static_cast<__fp16>(dis(gen)); + for (size_t i = 0; i < K.size(); ++i) K[i] = static_cast<__fp16>(dis(gen)); + for (size_t i = 0; i < V.size(); ++i) V[i] = static_cast<__fp16>(dis(gen)); + + std::vector<__fp16> O_g(seq_len * num_q_heads * head_dim); + std::vector<__fp16> O_ref(seq_len * num_q_heads * head_dim); + + cactus_attention_f16_mpsgraph(Q.data(), K.data(), V.data(), O_g.data(), + seq_len, kv_seq_len, num_q_heads, num_kv_heads, + head_dim, scale, 0); + + cactus_attention_f16(Q.data(), K.data(), V.data(), O_ref.data(), + 1, seq_len, kv_seq_len, num_q_heads, num_kv_heads, + head_dim, scale, nullptr, 0, 0, true, false, false, head_dim, 0.0f); + + float max_err = 0.0f; + for (size_t i = 0; i < O_g.size(); ++i) { + float err = std::abs(static_cast(O_g[i]) - static_cast(O_ref[i])); + if (err > max_err) max_err = err; + } + + std::cout << " MPSGraph SDPA F16 max abs error: " << max_err << std::endl; + return max_err < 0.05f; +} #endif int main() { @@ -638,6 +715,8 @@ int main() { #ifdef __APPLE__ runner.run_test("Kernel MPS FP16 MatMul Correctness", test_mps_matmul_f16_correctness()); runner.run_test("Kernel MPS INT4 MatMul Correctness", test_mps_matmul_int4_correctness()); + runner.run_test("Kernel MPS Attention F16 Correctness", test_mps_attention_f16_correctness()); + runner.run_test("Kernel MPSGraph SDPA F16 Correctness", test_mpsgraph_attention_f16_correctness()); #endif runner.print_summary(); From 4b0979871978c696e42d97c16501bc02e851fc91 Mon Sep 17 00:00:00 2001 From: ParkiratS Date: Tue, 5 May 2026 09:13:26 -0700 Subject: [PATCH 4/4] Benchmark metal + RAM and CPU --- cactus/graph/graph.h | 7 + cactus/graph/graph_io.cpp | 140 +++++++++++++++++-- cactus/kernel/kernel_mps.mm | 161 +++++++++++++++++++++- python/bench_gemma4_compare.py | 237 +++++++++++++++++++++++++++++++++ python/src/cli.py | 6 + tests/CMakeLists.txt | 5 + 6 files changed, 541 insertions(+), 15 deletions(-) create mode 100644 python/bench_gemma4_compare.py diff --git a/cactus/graph/graph.h b/cactus/graph/graph.h index 00da94c15..b90d20a58 100644 --- a/cactus/graph/graph.h +++ b/cactus/graph/graph.h @@ -763,6 +763,11 @@ namespace GraphFile { void prefetch_pages(); private: + enum class StorageMode { + MappedFile, + OwnedRam, + }; + int fd_; void* mapped_data_; size_t file_size_, data_offset_; @@ -777,9 +782,11 @@ namespace GraphFile { bool is_interleaved_ = false; size_t original_N_ = 0; + StorageMode storage_mode_ = StorageMode::MappedFile; void parse_header(); void apply_madvise_hints(); + void load_file_into_ram(); }; } diff --git a/cactus/graph/graph_io.cpp b/cactus/graph/graph_io.cpp index 8eded5e94..eea378ad3 100644 --- a/cactus/graph/graph_io.cpp +++ b/cactus/graph/graph_io.cpp @@ -1,5 +1,10 @@ #include "graph.h" #include "graph_param_io.h" +#include "../kernel/kernel.h" +#include +#include +#include +#include #include #include #include @@ -28,6 +33,31 @@ namespace { return offset + (alignment - remainder); } + std::string normalize_storage_mode(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + return value; + } + + bool should_avoid_mmap_for_tensor_files() { + const char* env_mode = std::getenv("CACTUS_WEIGHT_STORAGE"); + if (env_mode && *env_mode) { + const std::string mode = normalize_storage_mode(env_mode); + if (mode == "ram" || mode == "copy" || mode == "heap") { + return true; + } + if (mode == "mmap" || mode == "map" || mode == "mapped") { + return false; + } + } + +#ifdef __APPLE__ + return cactus_mps_enabled() && cactus_mps_available(); +#else + return false; +#endif + } + inline void write_u32(std::ostream& out, uint32_t v) { out.write(reinterpret_cast(&v), sizeof(v)); } @@ -598,24 +628,49 @@ MappedFile::MappedFile(const std::string& filename) } file_size_ = static_cast(st.st_size); - mapped_data_ = mmap(nullptr, file_size_, PROT_READ, MAP_SHARED, fd_, 0); - if (mapped_data_ == MAP_FAILED) { - close(fd_); - throw std::runtime_error("Cannot map file: " + filename); - } + try { + if (should_avoid_mmap_for_tensor_files()) { + storage_mode_ = StorageMode::OwnedRam; + load_file_into_ram(); + } else { + mapped_data_ = mmap(nullptr, file_size_, PROT_READ, MAP_SHARED, fd_, 0); + if (mapped_data_ == MAP_FAILED) { + throw std::runtime_error("Cannot map file: " + filename); + } + storage_mode_ = StorageMode::MappedFile; + } - close(fd_); - fd_ = -1; + close(fd_); + fd_ = -1; - parse_header(); - apply_madvise_hints(); + parse_header(); + apply_madvise_hints(); + } catch (...) { + if (fd_ != -1) { + close(fd_); + fd_ = -1; + } + if (storage_mode_ == StorageMode::MappedFile && + mapped_data_ != nullptr && mapped_data_ != MAP_FAILED) { + munmap(mapped_data_, file_size_); + mapped_data_ = nullptr; + } else if (storage_mode_ == StorageMode::OwnedRam && mapped_data_ != nullptr) { + free(mapped_data_); + mapped_data_ = nullptr; + } + throw; + } } MappedFile::~MappedFile() { - if (mapped_data_ != nullptr && mapped_data_ != MAP_FAILED) { + if (storage_mode_ == StorageMode::MappedFile && + mapped_data_ != nullptr && mapped_data_ != MAP_FAILED) { madvise(mapped_data_, file_size_, MADV_DONTNEED); munmap(mapped_data_, file_size_); mapped_data_ = nullptr; + } else if (storage_mode_ == StorageMode::OwnedRam && mapped_data_ != nullptr) { + free(mapped_data_); + mapped_data_ = nullptr; } if (fd_ != -1) { close(fd_); @@ -631,18 +686,23 @@ MappedFile::MappedFile(MappedFile&& other) noexcept scales_offset_(other.scales_offset_), scales_bytes_(other.scales_bytes_), alignment_(other.alignment_), is_interleaved_(other.is_interleaved_), - original_N_(other.original_N_) { + original_N_(other.original_N_), + storage_mode_(other.storage_mode_) { other.fd_ = -1; other.mapped_data_ = nullptr; other.file_size_ = 0; other.is_interleaved_ = false; other.original_N_ = 0; + other.storage_mode_ = StorageMode::MappedFile; } MappedFile& MappedFile::operator=(MappedFile&& other) noexcept { if (this != &other) { - if (mapped_data_ != nullptr && mapped_data_ != MAP_FAILED) { + if (storage_mode_ == StorageMode::MappedFile && + mapped_data_ != nullptr && mapped_data_ != MAP_FAILED) { munmap(mapped_data_, file_size_); + } else if (storage_mode_ == StorageMode::OwnedRam && mapped_data_ != nullptr) { + free(mapped_data_); } if (fd_ != -1) { close(fd_); @@ -662,11 +722,13 @@ MappedFile& MappedFile::operator=(MappedFile&& other) noexcept { alignment_ = other.alignment_; is_interleaved_ = other.is_interleaved_; original_N_ = other.original_N_; + storage_mode_ = other.storage_mode_; other.fd_ = -1; other.mapped_data_ = nullptr; other.file_size_ = 0; other.is_interleaved_ = false; other.original_N_ = 0; + other.storage_mode_ = StorageMode::MappedFile; } return *this; } @@ -771,6 +833,10 @@ void MappedFile::parse_header() { } void MappedFile::apply_madvise_hints() { + if (storage_mode_ != StorageMode::MappedFile) { + return; + } + if (scales_bytes_ > 0 && scales_offset_ > 0) { madvise(static_cast(mapped_data_) + scales_offset_, scales_bytes_, MADV_WILLNEED); } @@ -783,7 +849,10 @@ void MappedFile::apply_madvise_hints() { } void MappedFile::release_pages() { - if (mapped_data_ == nullptr || mapped_data_ == MAP_FAILED) return; + if (storage_mode_ != StorageMode::MappedFile || + mapped_data_ == nullptr || mapped_data_ == MAP_FAILED) { + return; + } if (scales_bytes_ > 0 && scales_offset_ > 0) { madvise(static_cast(mapped_data_) + scales_offset_, scales_bytes_, MADV_DONTNEED); @@ -792,7 +861,10 @@ void MappedFile::release_pages() { } void MappedFile::prefetch_pages() { - if (mapped_data_ == nullptr || mapped_data_ == MAP_FAILED) return; + if (storage_mode_ != StorageMode::MappedFile || + mapped_data_ == nullptr || mapped_data_ == MAP_FAILED) { + return; + } if (scales_bytes_ > 0 && scales_offset_ > 0) { madvise(static_cast(mapped_data_) + scales_offset_, scales_bytes_, MADV_WILLNEED); @@ -800,6 +872,46 @@ void MappedFile::prefetch_pages() { madvise(static_cast(mapped_data_) + data_offset_, byte_size_, MADV_WILLNEED); } +void MappedFile::load_file_into_ram() { + if (file_size_ == 0) { + throw std::runtime_error("Cannot load empty tensor file into RAM"); + } + + long page_size = sysconf(_SC_PAGESIZE); + if (page_size <= 0) { + page_size = 4096; + } + + void* buffer = nullptr; + const int alloc_rc = posix_memalign(&buffer, static_cast(page_size), file_size_); + if (alloc_rc != 0 || buffer == nullptr) { + throw std::runtime_error("Cannot allocate RAM buffer for tensor file"); + } + + char* dst = static_cast(buffer); + size_t total_read = 0; + constexpr size_t kMaxReadChunk = static_cast(1) << 30; // Keep macOS read() calls well below INT_MAX. + while (total_read < file_size_) { + size_t remaining = file_size_ - total_read; + size_t chunk_size = std::min(remaining, kMaxReadChunk); + ssize_t bytes_read = read(fd_, dst + total_read, chunk_size); + if (bytes_read < 0) { + if (errno == EINTR) { + continue; + } + free(buffer); + throw std::runtime_error("Cannot read tensor file into RAM"); + } + if (bytes_read == 0) { + free(buffer); + throw std::runtime_error("Unexpected EOF while reading tensor file into RAM"); + } + total_read += static_cast(bytes_read); + } + + mapped_data_ = buffer; +} + template const int8_t* MappedFile::typed_data() const; template const float* MappedFile::typed_data() const; template const uint16_t* MappedFile::typed_data() const; diff --git a/cactus/kernel/kernel_mps.mm b/cactus/kernel/kernel_mps.mm index 3bb53a91b..f03edbf85 100644 --- a/cactus/kernel/kernel_mps.mm +++ b/cactus/kernel/kernel_mps.mm @@ -3,8 +3,13 @@ #import #import #import +#include #include "kernel.h" +#include +#include +#include #include +#include #include #include @@ -356,13 +361,130 @@ static void cactus_mps_init() { static bool g_mps_enabled = true; +struct CactusMPSTraceCounters { + std::atomic matmul_f16{0}; + std::atomic gemv_int4{0}; + std::atomic matmul_int4{0}; + std::atomic attention_f16{0}; + std::atomic attention_graph{0}; +}; + +static CactusMPSTraceCounters g_mps_trace_counters; +static std::atomic g_mps_trace_event_index{0}; + +static bool parse_env_bool(const char* name, bool* out_value) { + const char* raw = std::getenv(name); + if (!raw || !*raw || !out_value) { + return false; + } + + std::string value(raw); + for (char& ch : value) { + ch = static_cast(std::tolower(static_cast(ch))); + } + + if (value == "1" || value == "true" || value == "yes" || value == "on") { + *out_value = true; + return true; + } + if (value == "0" || value == "false" || value == "no" || value == "off") { + *out_value = false; + return true; + } + + return false; +} + +static bool cactus_mps_env_enabled() { + bool enabled = true; + if (parse_env_bool("CACTUS_MPS", &enabled)) { + return enabled; + } + + bool disabled = false; + if (parse_env_bool("CACTUS_DISABLE_MPS", &disabled)) { + return !disabled; + } + + return true; +} + +static bool cactus_mps_trace_enabled() { + bool enabled = false; + return parse_env_bool("CACTUS_MPS_TRACE", &enabled) && enabled; +} + +static bool cactus_mps_trace_summary_enabled() { + bool enabled = false; + if (parse_env_bool("CACTUS_MPS_TRACE_SUMMARY", &enabled)) { + return enabled; + } + return cactus_mps_trace_enabled(); +} + +static void cactus_mps_trace_dump_summary() { + if (!cactus_mps_trace_summary_enabled()) { + return; + } + + cactus_mps_init(); + const int enabled = (g_mps_enabled && cactus_mps_env_enabled()) ? 1 : 0; + const int available = (g_device != nil && g_queue != nil) ? 1 : 0; + + const unsigned long long matmul_f16 = + static_cast(g_mps_trace_counters.matmul_f16.load(std::memory_order_relaxed)); + const unsigned long long gemv_int4 = + static_cast(g_mps_trace_counters.gemv_int4.load(std::memory_order_relaxed)); + const unsigned long long matmul_int4 = + static_cast(g_mps_trace_counters.matmul_int4.load(std::memory_order_relaxed)); + const unsigned long long attention_f16 = + static_cast(g_mps_trace_counters.attention_f16.load(std::memory_order_relaxed)); + const unsigned long long attention_graph = + static_cast(g_mps_trace_counters.attention_graph.load(std::memory_order_relaxed)); + const unsigned long long total = + matmul_f16 + gemv_int4 + matmul_int4 + attention_f16 + attention_graph; + + std::fprintf(stderr, + "[MPS_TRACE_SUMMARY] enabled=%d available=%d total=%llu matmul_f16=%llu gemv_int4=%llu " + "matmul_int4=%llu attention_f16=%llu attention_graph=%llu\n", + enabled, available, total, matmul_f16, gemv_int4, matmul_int4, attention_f16, attention_graph); + std::fflush(stderr); +} + +static void cactus_mps_trace_register_summary() { + static dispatch_once_t once; + dispatch_once(&once, ^{ + std::atexit(cactus_mps_trace_dump_summary); + }); +} + +static void cactus_mps_trace_log(const char* kernel_name, const std::string& details) { + if (!cactus_mps_trace_enabled() && !cactus_mps_trace_summary_enabled()) { + return; + } + + cactus_mps_trace_register_summary(); + + if (!cactus_mps_trace_enabled()) { + return; + } + + const unsigned long long event_index = + static_cast(1 + g_mps_trace_event_index.fetch_add(1, std::memory_order_relaxed)); + std::fprintf(stderr, "[MPS_TRACE] #%llu %s %s\n", event_index, kernel_name, details.c_str()); + std::fflush(stderr); +} + bool cactus_mps_available() { + if (cactus_mps_trace_summary_enabled()) { + cactus_mps_trace_register_summary(); + } cactus_mps_init(); return g_device != nil && g_queue != nil; } void cactus_mps_set_enabled(bool enabled) { g_mps_enabled = enabled; } -bool cactus_mps_enabled() { return g_mps_enabled; } +bool cactus_mps_enabled() { return g_mps_enabled && cactus_mps_env_enabled(); } static id cactus_mps_active_cmd() { if (!g_pending_cmd) { @@ -396,6 +518,12 @@ void cactus_matmul_f16_mps(const __fp16* A, const __fp16* B_T, __fp16* C, cactus_mps_init(); if (!g_device || !g_queue) return; + g_mps_trace_counters.matmul_f16.fetch_add(1, std::memory_order_relaxed); + cactus_mps_trace_log("matmul_f16", + "M=" + std::to_string(M) + + " K=" + std::to_string(K) + + " N=" + std::to_string(N)); + @autoreleasepool { const size_t fp16 = sizeof(__fp16); size_t offA, offB, offC; @@ -424,6 +552,12 @@ void cactus_gemv_int4_mps(const __fp16* A, const int8_t* B_packed, const __fp16* if (!g_device || !g_queue || !g_gemv_pso) return; if (N % 4 != 0 || K % group_size != 0) return; + g_mps_trace_counters.gemv_int4.fetch_add(1, std::memory_order_relaxed); + cactus_mps_trace_log("gemv_int4", + "K=" + std::to_string(K) + + " N=" + std::to_string(N) + + " group_size=" + std::to_string(group_size)); + @autoreleasepool { const size_t fp16 = sizeof(__fp16); const size_t packed_bytes = (N / 4) * K * 2; @@ -458,6 +592,13 @@ void cactus_matmul_int4_mps(const __fp16* A, const int8_t* B_packed, const __fp1 if (!g_device || !g_queue || !g_dequant_pso) return; if (N % 4 != 0 || K % group_size != 0) return; + g_mps_trace_counters.matmul_int4.fetch_add(1, std::memory_order_relaxed); + cactus_mps_trace_log("matmul_int4", + "M=" + std::to_string(M) + + " K=" + std::to_string(K) + + " N=" + std::to_string(N) + + " group_size=" + std::to_string(group_size)); + @autoreleasepool { const size_t fp16 = sizeof(__fp16); const size_t packed_bytes = (N / 4) * K * 2; @@ -503,6 +644,15 @@ void cactus_attention_f16_mps(const __fp16* Q, const __fp16* K, const __fp16* V, if (!g_device || !g_queue || !g_attn_v2_pso) return; if (head_dim > 256 || head_dim % 16 != 0) return; + g_mps_trace_counters.attention_f16.fetch_add(1, std::memory_order_relaxed); + cactus_mps_trace_log("attention_f16", + "seq_len=" + std::to_string(seq_len) + + " kv_seq_len=" + std::to_string(kv_seq_len) + + " num_q_heads=" + std::to_string(num_q_heads) + + " num_kv_heads=" + std::to_string(num_kv_heads) + + " head_dim=" + std::to_string(head_dim) + + " position_offset=" + std::to_string(position_offset)); + @autoreleasepool { const size_t fp16 = sizeof(__fp16); const size_t q_bytes = seq_len * num_q_heads * head_dim * fp16; @@ -632,6 +782,15 @@ void cactus_attention_f16_mpsgraph(const __fp16* Q, const __fp16* K, const __fp1 cactus_mps_init(); if (!g_device || !g_queue) return; + g_mps_trace_counters.attention_graph.fetch_add(1, std::memory_order_relaxed); + cactus_mps_trace_log("attention_graph", + "seq_len=" + std::to_string(seq_len) + + " kv_seq_len=" + std::to_string(kv_seq_len) + + " num_q_heads=" + std::to_string(num_q_heads) + + " num_kv_heads=" + std::to_string(num_kv_heads) + + " head_dim=" + std::to_string(head_dim) + + " position_offset=" + std::to_string(position_offset)); + @autoreleasepool { const size_t fp16 = sizeof(__fp16); const size_t q_bytes = seq_len * num_q_heads * head_dim * fp16; diff --git a/python/bench_gemma4_compare.py b/python/bench_gemma4_compare.py new file mode 100644 index 000000000..fb24d245b --- /dev/null +++ b/python/bench_gemma4_compare.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import statistics +import subprocess +import sys +import time +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +DEFAULT_MODEL = "weights/gemma-4-e2b-it" +DEFAULT_PROMPT = "Write one friendly sentence about local AI." +RESULT_PREFIX = "RESULT_JSON=" +CASE_ENVS = [ + ("mps_ram", {"CACTUS_MPS": "1", "CACTUS_WEIGHT_STORAGE": "ram"}), + ("mps_mmap", {"CACTUS_MPS": "1", "CACTUS_WEIGHT_STORAGE": "mmap"}), + ("cpu_ram", {"CACTUS_MPS": "0", "CACTUS_WEIGHT_STORAGE": "ram"}), + ("cpu_mmap", {"CACTUS_MPS": "0", "CACTUS_WEIGHT_STORAGE": "mmap"}), +] +LONG_CONTEXT_BLOCK = ( + "Edge inference keeps AI workloads close to the user, which reduces latency, " + "improves privacy, cuts bandwidth use, keeps applications responsive offline, " + "and avoids repeated round trips to remote servers. " + "It also gives product teams more predictable performance, because the model can " + "respond immediately from local state instead of waiting on network conditions. " + "When the workload is interactive, this can make the difference between a tool " + "feeling instant and a tool feeling sluggish." +) + + +def build_prompt(preset): + if preset == "short": + return DEFAULT_PROMPT + if preset == "long_context": + sections = [ + "Summarize the repeated ideas below in one concise paragraph.", + *(f"Passage {i + 1}: {LONG_CONTEXT_BLOCK}" for i in range(16)), + ] + return "\n".join(sections) + raise ValueError(f"Unknown prompt preset: {preset}") + + +def _average_metrics(rows): + keys = [ + "time_to_first_token_ms", + "total_time_ms", + "prefill_tps", + "decode_tps", + "ram_usage_mb", + "prefill_tokens", + "decode_tokens", + "total_tokens", + ] + avg = {} + for key in keys: + values = [float(row.get(key, 0.0)) for row in rows] + avg[key] = statistics.fmean(values) if values else 0.0 + return avg + + +def _child_main(args): + from python.src.cactus import cactus_complete, cactus_destroy, cactus_init, cactus_reset + + prompt = args.prompt if args.prompt is not None else build_prompt(args.prompt_preset) + messages = json.dumps([{"role": "user", "content": prompt}]) + options = json.dumps({ + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_tokens": args.max_tokens, + }) + + init_start = time.perf_counter() + model = cactus_init(args.model, None, False) + init_ms = (time.perf_counter() - init_start) * 1000.0 + + def run_once(): + raw = cactus_complete(model, messages, options, None, None) + parsed = json.loads(raw) + if not parsed.get("success", False): + raise RuntimeError(parsed.get("error") or "cactus_complete failed") + return parsed + + for _ in range(args.warmup_runs): + cactus_reset(model) + run_once() + + runs = [] + for _ in range(args.runs): + cactus_reset(model) + runs.append(run_once()) + + cactus_destroy(model) + + payload = { + "label": args.label, + "model": args.model, + "prompt": prompt, + "prompt_preset": args.prompt_preset, + "env": { + "CACTUS_MPS": os.getenv("CACTUS_MPS", ""), + "CACTUS_DISABLE_MPS": os.getenv("CACTUS_DISABLE_MPS", ""), + "CACTUS_WEIGHT_STORAGE": os.getenv("CACTUS_WEIGHT_STORAGE", ""), + "CACTUS_MPS_TRACE": os.getenv("CACTUS_MPS_TRACE", ""), + "CACTUS_MPS_TRACE_SUMMARY": os.getenv("CACTUS_MPS_TRACE_SUMMARY", ""), + }, + "init_ms": init_ms, + "avg": _average_metrics(runs), + "last_response": runs[-1].get("response", "") if runs else "", + "runs": runs, + } + print(RESULT_PREFIX + json.dumps(payload)) + + +def _run_case(script_path, base_args, label, extra_env): + child_args = [ + sys.executable, + str(script_path), + "--child", + "--label", + label, + "--model", + base_args.model, + "--prompt-preset", + base_args.prompt_preset, + "--runs", + str(base_args.runs), + "--warmup-runs", + str(base_args.warmup_runs), + "--max-tokens", + str(base_args.max_tokens), + ] + if base_args.prompt is not None: + child_args.extend(["--prompt", base_args.prompt]) + + env = os.environ.copy() + env.update(extra_env) + if base_args.trace_mps and label.startswith("mps_"): + env["CACTUS_MPS_TRACE"] = "1" + env["CACTUS_MPS_TRACE_SUMMARY"] = "1" + proc = subprocess.run( + child_args, + cwd=Path(__file__).resolve().parents[1], + env=env, + text=True, + capture_output=True, + check=True, + ) + + payload = None + for line in proc.stdout.splitlines(): + if line.startswith(RESULT_PREFIX): + payload = json.loads(line[len(RESULT_PREFIX):]) + if payload is None: + raise RuntimeError(f"No benchmark payload found for {label}.\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}") + + return { + "payload": payload, + "stdout": proc.stdout, + "stderr": proc.stderr, + } + + +def _format_row(label, payload): + avg = payload["avg"] + return ( + f"{label:12} " + f"prefill_tokens={avg['prefill_tokens']:7.1f} " + f"init={payload['init_ms']:7.1f} ms " + f"ttft={avg['time_to_first_token_ms']:7.1f} ms " + f"total={avg['total_time_ms']:7.1f} ms " + f"prefill={avg['prefill_tps']:7.2f} tok/s " + f"decode={avg['decode_tps']:7.2f} tok/s " + f"ram={avg['ram_usage_mb']:8.1f} MB" + ) + + +def _parent_main(args): + script_path = Path(__file__).resolve() + cases = CASE_ENVS + + if args.case != "all": + cases = [case for case in cases if case[0] == args.case] + + results = [] + for label, env in cases: + results.append((label, _run_case(script_path, args, label, env))) + + prompt_preview = results[0][1]["payload"]["prompt"] if results else (args.prompt or build_prompt(args.prompt_preset)) + print(f"Model: {args.model}") + print(f"Prompt preset: {args.prompt_preset}") + if args.prompt is not None: + print("Prompt source: explicit --prompt") + print(f"Prompt preview: {prompt_preview[:160].replace(chr(10), ' ')}{'...' if len(prompt_preview) > 160 else ''}") + print(f"Runs: {args.runs} measured, {args.warmup_runs} warmup") + print("") + for label, result in results: + print(_format_row(label, result["payload"])) + print("") + for label, result in results: + print(f"[{label}] env={result['payload']['env']}") + response = result["payload"].get("last_response", "").strip() + if response: + print(f"[{label}] response={response}") + stderr = result["stderr"].strip() + if stderr: + print(f"[{label}] stderr:") + print(stderr) + + +def main(): + parser = argparse.ArgumentParser(description="Compare Gemma 4 E2B benchmark settings.") + case_choices = ["all", *(label for label, _ in CASE_ENVS)] + parser.add_argument("--child", action="store_true", help=argparse.SUPPRESS) + parser.add_argument("--label", default="run") + parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument("--prompt", default=None) + parser.add_argument("--prompt-preset", choices=["short", "long_context"], default="short") + parser.add_argument("--runs", type=int, default=3) + parser.add_argument("--warmup-runs", type=int, default=1) + parser.add_argument("--max-tokens", type=int, default=32) + parser.add_argument("--case", choices=case_choices, default="all") + parser.add_argument("--trace-mps", action="store_true") + args = parser.parse_args() + + if args.child: + _child_main(args) + else: + _parent_main(args) + + +if __name__ == "__main__": + main() diff --git a/python/src/cli.py b/python/src/cli.py index be47370c9..36f4bab42 100644 --- a/python/src/cli.py +++ b/python/src/cli.py @@ -819,6 +819,9 @@ def cmd_build(args): str(vendored_curl), "-framework", "Accelerate", "-framework", "CoreML", + "-framework", "Metal", + "-framework", "MetalPerformanceShaders", + "-framework", "MetalPerformanceShadersGraph", "-framework", "Foundation", "-framework", "Security", "-framework", "SystemConfiguration", @@ -866,6 +869,9 @@ def cmd_build(args): str(vendored_curl), "-framework", "Accelerate", "-framework", "CoreML", + "-framework", "Metal", + "-framework", "MetalPerformanceShaders", + "-framework", "MetalPerformanceShadersGraph", "-framework", "Foundation", "-framework", "Security", "-framework", "SystemConfiguration", diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3e37d0334..b96be9f8f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -48,6 +48,7 @@ else() find_library(COREML_FRAMEWORK CoreML REQUIRED) find_library(METAL_FRAMEWORK Metal REQUIRED) find_library(MPS_FRAMEWORK MetalPerformanceShaders REQUIRED) + find_library(MPSGRAPH_FRAMEWORK MetalPerformanceShadersGraph REQUIRED) find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED) find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) find_library(SECURITY_FRAMEWORK Security REQUIRED) @@ -99,6 +100,7 @@ foreach(TEST_FILE ${TEST_SOURCES}) ${COREML_FRAMEWORK} ${METAL_FRAMEWORK} ${MPS_FRAMEWORK} + ${MPSGRAPH_FRAMEWORK} ${FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK} ${SYSTEMCONFIGURATION_FRAMEWORK} @@ -133,6 +135,9 @@ foreach(APP_NAME chat asr) target_link_libraries(${APP_NAME} PRIVATE ${ACCELERATE_FRAMEWORK} ${COREML_FRAMEWORK} + ${METAL_FRAMEWORK} + ${MPS_FRAMEWORK} + ${MPSGRAPH_FRAMEWORK} ${FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK} ${SYSTEMCONFIGURATION_FRAMEWORK}