|
| 1 | +/* Copyright 2025 The JAX Authors. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | + |
| 16 | +#include "jaxlib/cpu/sparse_kernels.h" |
| 17 | + |
| 18 | +#include <algorithm> |
| 19 | +#include <complex> |
| 20 | +#include <cstdint> |
| 21 | +#include <vector> |
| 22 | + |
| 23 | +#include "Eigen/Core" |
| 24 | +#include "Eigen/SparseCore" |
| 25 | +#include "xla/ffi/api/ffi.h" |
| 26 | + |
| 27 | +namespace jax { |
| 28 | + |
| 29 | +template <typename ElementType, typename StorageType> |
| 30 | +using SparseMatrixType = |
| 31 | + Eigen::SparseMatrix<ElementType, Eigen::RowMajor, StorageType>; |
| 32 | +template <typename ElementType> |
| 33 | +using DenseMatrixType = |
| 34 | + Eigen::Matrix<ElementType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; |
| 35 | + |
| 36 | +template <typename MatrixT> |
| 37 | +using InputMap = Eigen::Map<const MatrixT, Eigen::Aligned32>; |
| 38 | +template <typename MatrixT> |
| 39 | +using OutputMap = Eigen::Map<MatrixT, Eigen::Aligned32>; |
| 40 | + |
| 41 | +template <typename ElementType, typename StorageType> |
| 42 | +static ::xla::ffi::Future CsrSparseDenseKernelImpl( |
| 43 | + const InputMap<SparseMatrixType<ElementType, StorageType>>& lhs_matrix, |
| 44 | + const InputMap<DenseMatrixType<ElementType>>& rhs_matrix, |
| 45 | + OutputMap<DenseMatrixType<ElementType>>& out_matrix, |
| 46 | + ::xla::ffi::ThreadPool& thread_pool) { |
| 47 | + // Rule of thumb to give each task at least 100k cycles to hide the cost of |
| 48 | + // task scheduling. |
| 49 | + // TODO(willfroom) Do we want to make this configurable? |
| 50 | + constexpr int64_t kTargetCyclesPerTask = 100'000; |
| 51 | + // Based on AVX (CPI 0.5 -> 2 IPC) |
| 52 | + constexpr int64_t kScalarProductsPerCycle = 2 * 32 / sizeof(ElementType); |
| 53 | + constexpr int64_t kTaskSize = kTargetCyclesPerTask * kScalarProductsPerCycle; |
| 54 | + |
| 55 | + if (lhs_matrix.nonZeros() * rhs_matrix.cols() <= kTaskSize || |
| 56 | + thread_pool.num_threads() == 0) { |
| 57 | + out_matrix.noalias() = lhs_matrix * rhs_matrix; |
| 58 | + |
| 59 | + ::xla::ffi::Promise promise; |
| 60 | + promise.SetAvailable(); |
| 61 | + return ::xla::ffi::Future(promise); |
| 62 | + } else { |
| 63 | + std::vector<int64_t> batch_sizes; |
| 64 | + { |
| 65 | + int64_t running_batch_nnz = 0; |
| 66 | + int64_t running_number_rows = 0; |
| 67 | + for (int row = 0; row < lhs_matrix.rows(); ++row) { |
| 68 | + int64_t row_nnz = lhs_matrix.outerIndexPtr()[row + 1] - |
| 69 | + lhs_matrix.outerIndexPtr()[row]; |
| 70 | + // If there is no non-zero elements in a row the task still needs to |
| 71 | + // write out a zero row we give each row a non-zero contribution to |
| 72 | + // avoid the pathological case of a task having to write many rows where |
| 73 | + // there is a large block of zero inputs. |
| 74 | + running_batch_nnz += std::max(row_nnz, static_cast<int64_t>(1)); |
| 75 | + running_number_rows++; |
| 76 | + if (running_batch_nnz * rhs_matrix.cols() > kTaskSize) { |
| 77 | + batch_sizes.push_back(running_number_rows); |
| 78 | + running_batch_nnz = 0; |
| 79 | + running_number_rows = 0; |
| 80 | + } else if (row == lhs_matrix.rows() - 1 && running_number_rows > 0) { |
| 81 | + batch_sizes.push_back(running_number_rows); |
| 82 | + } |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + ::xla::ffi::CountDownPromise promise(batch_sizes.size()); |
| 87 | + ::xla::ffi::Future future(promise); |
| 88 | + int64_t batch_start = 0; |
| 89 | + for (int64_t size : batch_sizes) { |
| 90 | + thread_pool.Schedule([out_matrix, lhs_matrix, rhs_matrix, batch_start, |
| 91 | + size, promise]() mutable { |
| 92 | + out_matrix.middleRows(batch_start, size).noalias() = |
| 93 | + lhs_matrix.middleRows(batch_start, size) * rhs_matrix; |
| 94 | + promise.CountDown(); |
| 95 | + }); |
| 96 | + batch_start += size; |
| 97 | + } |
| 98 | + return future; |
| 99 | + } |
| 100 | +} |
| 101 | + |
| 102 | +template <typename ElementType, typename StorageType> |
| 103 | +static ::xla::ffi::Future CsrSparseDenseKernelTypedDispatch( |
| 104 | + ::xla::ffi::AnyBuffer lhs_data, ::xla::ffi::AnyBuffer lhs_outer_indicies, |
| 105 | + ::xla::ffi::AnyBuffer lhs_inner_indicies, ::xla::ffi::AnyBuffer rhs, |
| 106 | + ::xla::ffi::Result<::xla::ffi::AnyBuffer> out, |
| 107 | + ::xla::ffi::ThreadPool thread_pool) { |
| 108 | + ::xla::ffi::Span<const int64_t> rhs_shape = rhs.dimensions(); |
| 109 | + ::xla::ffi::Span<const int64_t> out_shape = out->dimensions(); |
| 110 | + |
| 111 | + InputMap<SparseMatrixType<ElementType, StorageType>> lhs_matrix( |
| 112 | + out_shape[0], rhs_shape[0], lhs_data.element_count(), |
| 113 | + lhs_outer_indicies.reinterpret_data<StorageType>(), |
| 114 | + lhs_inner_indicies.reinterpret_data<StorageType>(), |
| 115 | + lhs_data.reinterpret_data<ElementType>()); |
| 116 | + |
| 117 | + InputMap<DenseMatrixType<ElementType>> rhs_matrix( |
| 118 | + rhs.reinterpret_data<ElementType>(), rhs_shape[0], |
| 119 | + rhs_shape.size() > 1 ? rhs_shape[1] : 1); |
| 120 | + OutputMap<DenseMatrixType<ElementType>> out_matrix( |
| 121 | + out->reinterpret_data<ElementType>(), lhs_matrix.rows(), |
| 122 | + rhs_matrix.cols()); |
| 123 | + |
| 124 | + return CsrSparseDenseKernelImpl<ElementType, StorageType>( |
| 125 | + lhs_matrix, rhs_matrix, out_matrix, thread_pool); |
| 126 | +} |
| 127 | + |
| 128 | +template <typename ElementType> |
| 129 | +static ::xla::ffi::Future CsrSparseDenseKernelTypedDispatch( |
| 130 | + ::xla::ffi::AnyBuffer lhs_data, ::xla::ffi::AnyBuffer lhs_outer_indicies, |
| 131 | + ::xla::ffi::AnyBuffer lhs_inner_indicies, ::xla::ffi::AnyBuffer rhs, |
| 132 | + ::xla::ffi::Result<::xla::ffi::AnyBuffer> out, |
| 133 | + ::xla::ffi::ThreadPool thread_pool) { |
| 134 | + if (lhs_outer_indicies.element_type() != lhs_inner_indicies.element_type()) { |
| 135 | + ::xla::ffi::Promise promise; |
| 136 | + promise.SetError(::xla::ffi::Error(::xla::ffi::ErrorCode::kInvalidArgument, |
| 137 | + "Sparse index type mismatch")); |
| 138 | + return ::xla::ffi::Future(promise); |
| 139 | + } |
| 140 | + |
| 141 | + switch (lhs_outer_indicies.element_type()) { |
| 142 | + case ::xla::ffi::DataType::S32: |
| 143 | + return CsrSparseDenseKernelTypedDispatch<ElementType, int32_t>( |
| 144 | + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, |
| 145 | + thread_pool); |
| 146 | + case ::xla::ffi::DataType::S64: |
| 147 | + return CsrSparseDenseKernelTypedDispatch<ElementType, int64_t>( |
| 148 | + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, |
| 149 | + thread_pool); |
| 150 | + default: |
| 151 | + ::xla::ffi::Promise promise; |
| 152 | + promise.SetError(::xla::ffi::Error( |
| 153 | + ::xla::ffi::ErrorCode::kInvalidArgument, "Invalid index data type")); |
| 154 | + return ::xla::ffi::Future(promise); |
| 155 | + } |
| 156 | +} |
| 157 | + |
| 158 | +static ::xla::ffi::Future CsrSparseDenseKernelDispatch( |
| 159 | + ::xla::ffi::AnyBuffer lhs_data, ::xla::ffi::AnyBuffer lhs_outer_indicies, |
| 160 | + ::xla::ffi::AnyBuffer lhs_inner_indicies, ::xla::ffi::AnyBuffer rhs, |
| 161 | + ::xla::ffi::Result<::xla::ffi::AnyBuffer> out, |
| 162 | + ::xla::ffi::ThreadPool thread_pool) { |
| 163 | + if (lhs_data.element_type() != rhs.element_type() || |
| 164 | + lhs_data.element_type() != out->element_type()) { |
| 165 | + ::xla::ffi::Promise promise; |
| 166 | + promise.SetError(::xla::ffi::Error(::xla::ffi::ErrorCode::kInvalidArgument, |
| 167 | + "Element type mismatch")); |
| 168 | + return ::xla::ffi::Future(promise); |
| 169 | + } |
| 170 | + |
| 171 | + switch (lhs_data.element_type()) { |
| 172 | + case ::xla::ffi::DataType::S32: |
| 173 | + return CsrSparseDenseKernelTypedDispatch<int32_t>( |
| 174 | + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, |
| 175 | + thread_pool); |
| 176 | + case ::xla::ffi::DataType::S64: |
| 177 | + return CsrSparseDenseKernelTypedDispatch<int64_t>( |
| 178 | + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, |
| 179 | + thread_pool); |
| 180 | + case ::xla::ffi::DataType::F32: |
| 181 | + return CsrSparseDenseKernelTypedDispatch<float>( |
| 182 | + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, |
| 183 | + thread_pool); |
| 184 | + case ::xla::ffi::DataType::F64: |
| 185 | + return CsrSparseDenseKernelTypedDispatch<double>( |
| 186 | + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, |
| 187 | + thread_pool); |
| 188 | + case ::xla::ffi::DataType::C64: |
| 189 | + return CsrSparseDenseKernelTypedDispatch<std::complex<float>>( |
| 190 | + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, |
| 191 | + thread_pool); |
| 192 | + case ::xla::ffi::DataType::C128: |
| 193 | + return CsrSparseDenseKernelTypedDispatch<std::complex<double>>( |
| 194 | + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, |
| 195 | + thread_pool); |
| 196 | + default: |
| 197 | + ::xla::ffi::Promise promise; |
| 198 | + promise.SetError(::xla::ffi::Error( |
| 199 | + ::xla::ffi::ErrorCode::kInvalidArgument, "Invalid data type")); |
| 200 | + return ::xla::ffi::Future(promise); |
| 201 | + } |
| 202 | +} |
| 203 | + |
| 204 | +XLA_FFI_DEFINE_HANDLER_SYMBOL( |
| 205 | + cpu_csr_sparse_dense_ffi, CsrSparseDenseKernelDispatch, |
| 206 | + (::xla::ffi::Ffi::Bind() |
| 207 | + .Arg<::xla::ffi::AnyBuffer>(/*lhs_data*/) |
| 208 | + .Arg<::xla::ffi::AnyBuffer>( |
| 209 | + /*lhs_outer_indicies*/) |
| 210 | + .Arg<::xla::ffi::AnyBuffer>( |
| 211 | + /*lhs_inner_indicies*/) |
| 212 | + .Arg<::xla::ffi::AnyBuffer>(/*rhs*/) |
| 213 | + .Ret<::xla::ffi::AnyBuffer>(/*out*/) |
| 214 | + .Ctx<::xla::ffi::ThreadPool>(/*thread_pool*/))); |
| 215 | + |
| 216 | +} // namespace jax |
0 commit comments