Skip to content

Commit 671d2b5

Browse files
WillFroomGoogle-ML-Automation
authored andcommitted
[JAX:Sparse] Implement CSR sparse kernel
PiperOrigin-RevId: 750576197
1 parent 603f730 commit 671d2b5

File tree

9 files changed

+360
-1
lines changed

9 files changed

+360
-1
lines changed

jaxlib/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ package_group(
5252
py_library_providing_imports_info(
5353
name = "jaxlib",
5454
srcs = [
55+
"cpu_sparse.py",
5556
"gpu_common_utils.py",
5657
"gpu_linalg.py",
5758
"gpu_prng.py",
@@ -76,6 +77,7 @@ py_library_providing_imports_info(
7677
"//jaxlib:_jax",
7778
"//jaxlib:xla_client",
7879
"//jaxlib/cpu:_lapack",
80+
"//jaxlib/cpu:_sparse",
7981
"//jaxlib/mlir",
8082
"//jaxlib/mlir:arithmetic_dialect",
8183
"//jaxlib/mlir:builtin_dialect",

jaxlib/cpu/BUILD

+33
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,42 @@ cc_library(
8585
deps = [
8686
":lapack_kernels",
8787
":lapack_kernels_using_lapack",
88+
":sparse_kernels",
8889
"@xla//xla/ffi/api:c_api",
8990
"@xla//xla/ffi/api:ffi",
9091
"@xla//xla/service:custom_call_target_registry",
9192
],
9293
alwayslink = 1,
9394
)
95+
96+
cc_library(
97+
name = "sparse_kernels",
98+
srcs = ["sparse_kernels.cc"],
99+
hdrs = ["sparse_kernels.h"],
100+
deps = [
101+
"@eigen_archive//:eigen3",
102+
"@xla//xla/ffi/api:ffi",
103+
],
104+
)
105+
106+
nanobind_extension(
107+
name = "_sparse",
108+
srcs = ["sparse.cc"],
109+
copts = [
110+
"-fexceptions",
111+
"-fno-strict-aliasing",
112+
],
113+
enable_stub_generation = False,
114+
features = ["-use_header_modules"],
115+
module_name = "_sparse",
116+
pytype_srcs = [
117+
"_sparse/__init__.pyi",
118+
],
119+
deps = [
120+
":sparse_kernels",
121+
"//jaxlib:kernel_nanobind_helpers",
122+
"@com_google_absl//absl/base",
123+
"@nanobind",
124+
"@xla//xla/ffi/api:ffi",
125+
],
126+
)

jaxlib/cpu/_sparse/__init__.pyi

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
def registrations() -> dict: ...

jaxlib/cpu/cpu_kernels.cc

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <complex>
2020

2121
#include "jaxlib/cpu/lapack_kernels.h"
22+
#include "jaxlib/cpu/sparse_kernels.h"
2223
#include "xla/ffi/api/c_api.h"
2324
#include "xla/ffi/api/ffi.h"
2425
#include "xla/service/custom_call_target_registry.h"
@@ -110,6 +111,8 @@ JAX_CPU_REGISTER_HANDLER(lapack_dgtsv_ffi);
110111
JAX_CPU_REGISTER_HANDLER(lapack_cgtsv_ffi);
111112
JAX_CPU_REGISTER_HANDLER(lapack_zgtsv_ffi);
112113

114+
JAX_CPU_REGISTER_HANDLER(cpu_csr_sparse_dense_ffi);
115+
113116
#undef JAX_CPU_REGISTER_HANDLER
114117

115118
} // namespace

jaxlib/cpu/sparse.cc

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/* Copyright 2025 The JAX Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "nanobind/nanobind.h"
17+
#include "jaxlib/cpu/sparse_kernels.h"
18+
#include "jaxlib/kernel_nanobind_helpers.h"
19+
20+
namespace jax {
21+
namespace {
22+
23+
namespace nb = nanobind;
24+
25+
nb::dict Registrations() {
26+
nb::dict dict;
27+
28+
dict["cpu_csr_sparse_dense_ffi"] =
29+
EncapsulateFunction(cpu_csr_sparse_dense_ffi);
30+
31+
return dict;
32+
}
33+
34+
NB_MODULE(_sparse, m) { m.def("registrations", &Registrations); }
35+
36+
} // namespace
37+
} // namespace jax

jaxlib/cpu/sparse_kernels.cc

+215
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
/* Copyright 2025 The JAX Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "jaxlib/cpu/sparse_kernels.h"
17+
18+
#include <algorithm>
19+
#include <complex>
20+
#include <cstdint>
21+
#include <vector>
22+
23+
#include "Eigen/Core"
24+
#include "Eigen/SparseCore"
25+
#include "xla/ffi/api/ffi.h"
26+
27+
namespace ffi = xla::ffi;
28+
29+
namespace jax {
30+
31+
template <typename ElementType, typename StorageType>
32+
using SparseMatrixType =
33+
Eigen::SparseMatrix<ElementType, Eigen::RowMajor, StorageType>;
34+
template <typename ElementType>
35+
using DenseMatrixType =
36+
Eigen::Matrix<ElementType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
37+
38+
template <typename MatrixT>
39+
using InputMap = Eigen::Map<const MatrixT, Eigen::Aligned32>;
40+
template <typename MatrixT>
41+
using OutputMap = Eigen::Map<MatrixT, Eigen::Aligned32>;
42+
43+
template <typename ElementType, typename StorageType>
44+
static ffi::Future CsrSparseDenseKernelImpl(
45+
const InputMap<SparseMatrixType<ElementType, StorageType>>& lhs_matrix,
46+
const InputMap<DenseMatrixType<ElementType>>& rhs_matrix,
47+
OutputMap<DenseMatrixType<ElementType>>& out_matrix,
48+
ffi::ThreadPool& thread_pool) {
49+
// Rule of thumb to give each task at least 100k cycles to hide the cost of
50+
// task scheduling.
51+
// TODO(willfroom) Do we want to make this configurable?
52+
constexpr int64_t kTargetCyclesPerTask = 100'000;
53+
// Based on AVX (CPI 0.5 -> 2 IPC)
54+
constexpr int64_t kScalarProductsPerCycle = 2 * 32 / sizeof(ElementType);
55+
constexpr int64_t kTaskSize = kTargetCyclesPerTask * kScalarProductsPerCycle;
56+
57+
if (lhs_matrix.nonZeros() * rhs_matrix.cols() <= kTaskSize ||
58+
thread_pool.num_threads() == 0) {
59+
out_matrix.noalias() = lhs_matrix * rhs_matrix;
60+
61+
ffi::Promise promise;
62+
promise.SetAvailable();
63+
return ffi::Future(promise);
64+
} else {
65+
std::vector<int64_t> batch_sizes;
66+
{
67+
int64_t running_batch_nnz = 0;
68+
int64_t running_number_rows = 0;
69+
for (int row = 0; row < lhs_matrix.rows(); ++row) {
70+
int64_t row_nnz = lhs_matrix.outerIndexPtr()[row + 1] -
71+
lhs_matrix.outerIndexPtr()[row];
72+
// If there is no non-zero elements in a row the task still needs to
73+
// write out a zero row we give each row a non-zero contribution to
74+
// avoid the pathological case of a task having to write many rows where
75+
// there is a large block of zero inputs.
76+
running_batch_nnz += std::max(row_nnz, static_cast<int64_t>(1));
77+
running_number_rows++;
78+
if (running_batch_nnz * rhs_matrix.cols() > kTaskSize) {
79+
batch_sizes.push_back(running_number_rows);
80+
running_batch_nnz = 0;
81+
running_number_rows = 0;
82+
} else if (row == lhs_matrix.rows() - 1 && running_number_rows > 0) {
83+
batch_sizes.push_back(running_number_rows);
84+
}
85+
}
86+
}
87+
88+
ffi::CountDownPromise promise(batch_sizes.size());
89+
ffi::Future future(promise);
90+
int64_t batch_start = 0;
91+
for (int64_t size : batch_sizes) {
92+
thread_pool.Schedule([out_matrix, lhs_matrix, rhs_matrix, batch_start,
93+
size, promise]() mutable {
94+
out_matrix.middleRows(batch_start, size).noalias() =
95+
lhs_matrix.middleRows(batch_start, size) * rhs_matrix;
96+
promise.CountDown();
97+
});
98+
batch_start += size;
99+
}
100+
return future;
101+
}
102+
}
103+
104+
template <typename ElementType, typename StorageType>
105+
static ffi::Future CsrSparseDenseKernelTypedDispatch(
106+
ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies,
107+
ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs,
108+
ffi::Result<ffi::AnyBuffer> out, ffi::ThreadPool thread_pool) {
109+
ffi::Span<const int64_t> rhs_shape = rhs.dimensions();
110+
ffi::Span<const int64_t> out_shape = out->dimensions();
111+
112+
InputMap<SparseMatrixType<ElementType, StorageType>> lhs_matrix(
113+
out_shape[0], rhs_shape[0], lhs_data.element_count(),
114+
lhs_outer_indicies.reinterpret_data<StorageType>(),
115+
lhs_inner_indicies.reinterpret_data<StorageType>(),
116+
lhs_data.reinterpret_data<ElementType>());
117+
118+
InputMap<DenseMatrixType<ElementType>> rhs_matrix(
119+
rhs.reinterpret_data<ElementType>(), rhs_shape[0],
120+
rhs_shape.size() > 1 ? rhs_shape[1] : 1);
121+
OutputMap<DenseMatrixType<ElementType>> out_matrix(
122+
out->reinterpret_data<ElementType>(), lhs_matrix.rows(),
123+
rhs_matrix.cols());
124+
125+
return CsrSparseDenseKernelImpl<ElementType, StorageType>(
126+
lhs_matrix, rhs_matrix, out_matrix, thread_pool);
127+
}
128+
129+
template <typename ElementType>
130+
static ffi::Future CsrSparseDenseKernelTypedDispatch(
131+
ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies,
132+
ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs,
133+
ffi::Result<ffi::AnyBuffer> out, ffi::ThreadPool thread_pool) {
134+
if (lhs_outer_indicies.element_type() != lhs_inner_indicies.element_type()) {
135+
ffi::Promise promise;
136+
promise.SetError(ffi::Error(ffi::ErrorCode::kInvalidArgument,
137+
"Sparse index type mismatch"));
138+
return ffi::Future(promise);
139+
}
140+
141+
switch (lhs_outer_indicies.element_type()) {
142+
case ffi::DataType::S32:
143+
return CsrSparseDenseKernelTypedDispatch<ElementType, int32_t>(
144+
lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
145+
thread_pool);
146+
case ffi::DataType::S64:
147+
return CsrSparseDenseKernelTypedDispatch<ElementType, int64_t>(
148+
lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
149+
thread_pool);
150+
default:
151+
ffi::Promise promise;
152+
promise.SetError(ffi::Error(ffi::ErrorCode::kInvalidArgument,
153+
"Invalid index data type"));
154+
return ffi::Future(promise);
155+
}
156+
}
157+
158+
static ffi::Future CsrSparseDenseKernelDispatch(
159+
ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies,
160+
ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs,
161+
ffi::Result<ffi::AnyBuffer> out, ffi::ThreadPool thread_pool) {
162+
if (lhs_data.element_type() != rhs.element_type() ||
163+
lhs_data.element_type() != out->element_type()) {
164+
ffi::Promise promise;
165+
promise.SetError(
166+
ffi::Error(ffi::ErrorCode::kInvalidArgument, "Element type mismatch"));
167+
return ffi::Future(promise);
168+
}
169+
170+
switch (lhs_data.element_type()) {
171+
case ffi::DataType::S32:
172+
return CsrSparseDenseKernelTypedDispatch<int32_t>(
173+
lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
174+
thread_pool);
175+
case ffi::DataType::S64:
176+
return CsrSparseDenseKernelTypedDispatch<int64_t>(
177+
lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
178+
thread_pool);
179+
case ffi::DataType::F32:
180+
return CsrSparseDenseKernelTypedDispatch<float>(
181+
lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
182+
thread_pool);
183+
case ffi::DataType::F64:
184+
return CsrSparseDenseKernelTypedDispatch<double>(
185+
lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
186+
thread_pool);
187+
case ffi::DataType::C64:
188+
return CsrSparseDenseKernelTypedDispatch<std::complex<float>>(
189+
lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
190+
thread_pool);
191+
case ffi::DataType::C128:
192+
return CsrSparseDenseKernelTypedDispatch<std::complex<double>>(
193+
lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
194+
thread_pool);
195+
default:
196+
ffi::Promise promise;
197+
promise.SetError(
198+
ffi::Error(ffi::ErrorCode::kInvalidArgument, "Invalid data type"));
199+
return ffi::Future(promise);
200+
}
201+
}
202+
203+
XLA_FFI_DEFINE_HANDLER_SYMBOL(cpu_csr_sparse_dense_ffi,
204+
CsrSparseDenseKernelDispatch,
205+
(ffi::Ffi::Bind()
206+
.Arg<ffi::AnyBuffer>(/*lhs_data*/)
207+
.Arg<ffi::AnyBuffer>(
208+
/*lhs_outer_indicies*/)
209+
.Arg<ffi::AnyBuffer>(
210+
/*lhs_inner_indicies*/)
211+
.Arg<ffi::AnyBuffer>(/*rhs*/)
212+
.Ret<ffi::AnyBuffer>(/*out*/)
213+
.Ctx<ffi::ThreadPool>(/*thread_pool*/)));
214+
215+
} // namespace jax

jaxlib/cpu/sparse_kernels.h

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright 2025 The JAX Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_
17+
#define THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_
18+
19+
#include "xla/ffi/api/ffi.h"
20+
21+
namespace jax {
22+
23+
XLA_FFI_DECLARE_HANDLER_SYMBOL(cpu_csr_sparse_dense_ffi);
24+
25+
} // namespace jax
26+
27+
#endif // THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_

0 commit comments

Comments
 (0)