
Commit 6e87b9e

WillFroom authored and Google-ML-Automation committed
[JAX:Sparse] Implement CSR sparse kernel
PiperOrigin-RevId: 750576197
1 parent b0251a3 commit 6e87b9e

File tree: 10 files changed, +370 −4 lines


jax/experimental/sparse/bcsr.py (+7 −3)
@@ -145,7 +145,7 @@ def _bcsr_to_bcoo(indices: jax.Array, indptr: jax.Array, *,
 
 
 def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int],
-                  index_dtype: DTypeLike = jnp.int32) -> tuple[Array, Array]:
+                  index_dtype: DTypeLike) -> tuple[Array, Array]:
   """Given BCOO (indices), return BCSR (indices, indptr).
 
   Note: this assumes that ``indices`` are lexicographically sorted within each batch.
@@ -238,7 +238,9 @@ def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
     raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.")
   bcoo_mat = bcoo.bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype,
                                  n_dense=n_dense, n_batch=n_batch)
-  indices, indptr = _bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape)
+  indices, indptr = _bcoo_to_bcsr(
+      bcoo_mat.indices, shape=mat.shape, index_dtype=index_dtype
+  )
   return bcoo_mat.data, indices, indptr
 
 
@@ -867,7 +869,9 @@ def from_bcoo(cls, arr: bcoo.BCOO) -> BCSR:
       raise NotImplementedError(f"BSCR.from_bcoo requires n_sparse=2; got {arr.n_sparse=}")
     if not arr.indices_sorted:
       arr = arr.sort_indices()
-    indices, indptr = _bcoo_to_bcsr(arr.indices, shape=arr.shape)
+    indices, indptr = _bcoo_to_bcsr(
+        arr.indices, shape=arr.shape, index_dtype=arr.indices.dtype
+    )
     return cls((arr.data, indices, indptr), shape=arr.shape)
 
   @classmethod
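
The net effect of this change is that the BCOO→BCSR conversion now preserves whatever index dtype the caller chose, rather than silently falling back to the old jnp.int32 default. A minimal sketch of the behavior through the public jax.experimental.sparse API (the specific matrix and dtype below are illustrative, not from the commit):

import jax.numpy as jnp
from jax.experimental import sparse

# A small dense matrix with two non-zeros.
mat = jnp.zeros((4, 4)).at[0, 1].set(1.0).at[2, 3].set(2.0)

# index_dtype now flows through _bcoo_to_bcsr, so indices and indptr
# come back with the requested dtype.
m_bcsr = sparse.BCSR.fromdense(mat, index_dtype=jnp.int32)
assert m_bcsr.indices.dtype == jnp.int32
assert m_bcsr.indptr.dtype == jnp.int32

# BCSR.from_bcoo reuses the dtype of the source BCOO's indices.
m_bcoo = sparse.BCOO.fromdense(mat, index_dtype=jnp.int32)
assert sparse.BCSR.from_bcoo(m_bcoo).indptr.dtype == jnp.int32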

jaxlib/BUILD (+2)
@@ -52,6 +52,7 @@ package_group(
 py_library_providing_imports_info(
     name = "jaxlib",
     srcs = [
+        "cpu_sparse.py",
         "gpu_common_utils.py",
         "gpu_linalg.py",
         "gpu_prng.py",
@@ -76,6 +77,7 @@ py_library_providing_imports_info(
         "//jaxlib:_jax",
         "//jaxlib:xla_client",
         "//jaxlib/cpu:_lapack",
+        "//jaxlib/cpu:_sparse",
         "//jaxlib/mlir",
         "//jaxlib/mlir:arithmetic_dialect",
         "//jaxlib/mlir:builtin_dialect",

jaxlib/cpu/BUILD (+33)
@@ -85,9 +85,42 @@ cc_library(
     deps = [
         ":lapack_kernels",
         ":lapack_kernels_using_lapack",
+        ":sparse_kernels",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/ffi/api:ffi",
         "@xla//xla/service:custom_call_target_registry",
     ],
     alwayslink = 1,
 )
+
+cc_library(
+    name = "sparse_kernels",
+    srcs = ["sparse_kernels.cc"],
+    hdrs = ["sparse_kernels.h"],
+    deps = [
+        "@eigen_archive//:eigen3",
+        "@xla//xla/ffi/api:ffi",
+    ],
+)
+
+nanobind_extension(
+    name = "_sparse",
+    srcs = ["sparse.cc"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    enable_stub_generation = False,
+    features = ["-use_header_modules"],
+    module_name = "_sparse",
+    pytype_srcs = [
+        "_sparse/__init__.pyi",
+    ],
+    deps = [
+        ":sparse_kernels",
+        "//jaxlib:kernel_nanobind_helpers",
+        "@com_google_absl//absl/base",
+        "@nanobind",
+        "@xla//xla/ffi/api:ffi",
+    ],
+)

jaxlib/cpu/_sparse/__init__.pyi (+15)
@@ -0,0 +1,15 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+def registrations() -> dict: ...
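
The stub exposes a single registrations() function returning a mapping from custom-call target names to handler capsules. A rough sketch of how such a dict is typically consumed on the Python side; the actual wiring lives in the new cpu_sparse.py, which is not shown in this diff, and jax.ffi.register_ffi_target here is an assumption based on how other jaxlib CPU modules are hooked up:

import jax
from jaxlib.cpu import _sparse

# Each entry maps a target name (e.g. "cpu_csr_sparse_dense_ffi") to a
# PyCapsule wrapping the typed FFI handler.
for name, capsule in _sparse.registrations().items():
    jax.ffi.register_ffi_target(name, capsule, platform="cpu")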

jaxlib/cpu/cpu_kernels.cc (+3)
@@ -19,6 +19,7 @@ limitations under the License.
 #include <complex>
 
 #include "jaxlib/cpu/lapack_kernels.h"
+#include "jaxlib/cpu/sparse_kernels.h"
 #include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/ffi.h"
 #include "xla/service/custom_call_target_registry.h"
@@ -110,6 +111,8 @@ JAX_CPU_REGISTER_HANDLER(lapack_dgtsv_ffi);
 JAX_CPU_REGISTER_HANDLER(lapack_cgtsv_ffi);
 JAX_CPU_REGISTER_HANDLER(lapack_zgtsv_ffi);
 
+JAX_CPU_REGISTER_HANDLER(cpu_csr_sparse_dense_ffi);
+
 #undef JAX_CPU_REGISTER_HANDLER
 
 }  // namespace

jaxlib/cpu/sparse.cc (+37)
@@ -0,0 +1,37 @@
+/* Copyright 2025 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "nanobind/nanobind.h"
+#include "jaxlib/cpu/sparse_kernels.h"
+#include "jaxlib/kernel_nanobind_helpers.h"
+
+namespace jax {
+namespace {
+
+namespace nb = nanobind;
+
+nb::dict Registrations() {
+  nb::dict dict;
+
+  dict["cpu_csr_sparse_dense_ffi"] =
+      EncapsulateFunction(cpu_csr_sparse_dense_ffi);
+
+  return dict;
+}
+
+NB_MODULE(_sparse, m) { m.def("registrations", &Registrations); }
+
+}  // namespace
+}  // namespace jax
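
With the handler registered, the kernel is reachable from Python as an FFI custom call taking (data, outer indices, inner indices, rhs) buffers and producing the dense output; the thread-pool context is supplied by the runtime. A hypothetical direct invocation, assuming the target has been registered (see the registrations() sketch above); the real dispatch path is JAX's sparse machinery, not this helper:

import jax
import jax.numpy as jnp

def csr_matmul(data, indptr, indices, rhs, out_rows):
    # indptr is the CSR "outer" index array, indices the "inner" one,
    # matching the argument order bound in sparse_kernels.cc.
    call = jax.ffi.ffi_call(
        "cpu_csr_sparse_dense_ffi",
        jax.ShapeDtypeStruct((out_rows, rhs.shape[1]), rhs.dtype),
    )
    return call(data, indptr, indices, rhs)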

jaxlib/cpu/sparse_kernels.cc (+216)
@@ -0,0 +1,216 @@
+/* Copyright 2025 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/cpu/sparse_kernels.h"
+
+#include <algorithm>
+#include <complex>
+#include <cstdint>
+#include <vector>
+
+#include "Eigen/Core"
+#include "Eigen/SparseCore"
+#include "xla/ffi/api/ffi.h"
+
+namespace jax {
+
+template <typename ElementType, typename StorageType>
+using SparseMatrixType =
+    Eigen::SparseMatrix<ElementType, Eigen::RowMajor, StorageType>;
+template <typename ElementType>
+using DenseMatrixType =
+    Eigen::Matrix<ElementType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+
+template <typename MatrixT>
+using InputMap = Eigen::Map<const MatrixT, Eigen::Aligned32>;
+template <typename MatrixT>
+using OutputMap = Eigen::Map<MatrixT, Eigen::Aligned32>;
+
+template <typename ElementType, typename StorageType>
+static ::xla::ffi::Future CsrSparseDenseKernelImpl(
+    const InputMap<SparseMatrixType<ElementType, StorageType>>& lhs_matrix,
+    const InputMap<DenseMatrixType<ElementType>>& rhs_matrix,
+    OutputMap<DenseMatrixType<ElementType>>& out_matrix,
+    ::xla::ffi::ThreadPool& thread_pool) {
+  // Rule of thumb to give each task at least 100k cycles to hide the cost of
+  // task scheduling.
+  // TODO(willfroom) Do we want to make this configurable?
+  constexpr int64_t kTargetCyclesPerTask = 100'000;
+  // Based on AVX (CPI 0.5 -> 2 IPC).
+  constexpr int64_t kScalarProductsPerCycle = 2 * 32 / sizeof(ElementType);
+  constexpr int64_t kTaskSize = kTargetCyclesPerTask * kScalarProductsPerCycle;
+
+  if (lhs_matrix.nonZeros() * rhs_matrix.cols() <= kTaskSize ||
+      thread_pool.num_threads() == 0) {
+    out_matrix.noalias() = lhs_matrix * rhs_matrix;
+
+    ::xla::ffi::Promise promise;
+    promise.SetAvailable();
+    return ::xla::ffi::Future(promise);
+  } else {
+    std::vector<int64_t> batch_sizes;
+    {
+      int64_t running_batch_nnz = 0;
+      int64_t running_number_rows = 0;
+      for (int row = 0; row < lhs_matrix.rows(); ++row) {
+        int64_t row_nnz = lhs_matrix.outerIndexPtr()[row + 1] -
+                          lhs_matrix.outerIndexPtr()[row];
+        // Even when a row has no non-zero elements, the task still needs to
+        // write out a zero row, so we give each row a non-zero contribution
+        // to avoid the pathological case of one task having to write many
+        // rows where there is a large block of zero inputs.
+        running_batch_nnz += std::max(row_nnz, static_cast<int64_t>(1));
+        running_number_rows++;
+        if (running_batch_nnz * rhs_matrix.cols() > kTaskSize) {
+          batch_sizes.push_back(running_number_rows);
+          running_batch_nnz = 0;
+          running_number_rows = 0;
+        } else if (row == lhs_matrix.rows() - 1 && running_number_rows > 0) {
+          batch_sizes.push_back(running_number_rows);
+        }
+      }
+    }
+
+    ::xla::ffi::CountDownPromise promise(batch_sizes.size());
+    ::xla::ffi::Future future(promise);
+    int64_t batch_start = 0;
+    for (int64_t size : batch_sizes) {
+      thread_pool.Schedule([out_matrix, lhs_matrix, rhs_matrix, batch_start,
+                            size, promise]() mutable {
+        out_matrix.middleRows(batch_start, size).noalias() =
+            lhs_matrix.middleRows(batch_start, size) * rhs_matrix;
+        promise.CountDown();
+      });
+      batch_start += size;
+    }
+    return future;
+  }
+}
+
+template <typename ElementType, typename StorageType>
+static ::xla::ffi::Future CsrSparseDenseKernelTypedDispatch(
+    ::xla::ffi::AnyBuffer lhs_data, ::xla::ffi::AnyBuffer lhs_outer_indicies,
+    ::xla::ffi::AnyBuffer lhs_inner_indicies, ::xla::ffi::AnyBuffer rhs,
+    ::xla::ffi::Result<::xla::ffi::AnyBuffer> out,
+    ::xla::ffi::ThreadPool thread_pool) {
+  ::xla::ffi::Span<const int64_t> rhs_shape = rhs.dimensions();
+  ::xla::ffi::Span<const int64_t> out_shape = out->dimensions();
+
+  InputMap<SparseMatrixType<ElementType, StorageType>> lhs_matrix(
+      out_shape[0], rhs_shape[0], lhs_data.element_count(),
+      lhs_outer_indicies.reinterpret_data<StorageType>(),
+      lhs_inner_indicies.reinterpret_data<StorageType>(),
+      lhs_data.reinterpret_data<ElementType>());
+
+  InputMap<DenseMatrixType<ElementType>> rhs_matrix(
+      rhs.reinterpret_data<ElementType>(), rhs_shape[0],
+      rhs_shape.size() > 1 ? rhs_shape[1] : 1);
+  OutputMap<DenseMatrixType<ElementType>> out_matrix(
+      out->reinterpret_data<ElementType>(), lhs_matrix.rows(),
+      rhs_matrix.cols());
+
+  return CsrSparseDenseKernelImpl<ElementType, StorageType>(
+      lhs_matrix, rhs_matrix, out_matrix, thread_pool);
+}
+
+template <typename ElementType>
+static ::xla::ffi::Future CsrSparseDenseKernelTypedDispatch(
+    ::xla::ffi::AnyBuffer lhs_data, ::xla::ffi::AnyBuffer lhs_outer_indicies,
+    ::xla::ffi::AnyBuffer lhs_inner_indicies, ::xla::ffi::AnyBuffer rhs,
+    ::xla::ffi::Result<::xla::ffi::AnyBuffer> out,
+    ::xla::ffi::ThreadPool thread_pool) {
+  if (lhs_outer_indicies.element_type() != lhs_inner_indicies.element_type()) {
+    ::xla::ffi::Promise promise;
+    promise.SetError(::xla::ffi::Error(::xla::ffi::ErrorCode::kInvalidArgument,
+                                       "Sparse index type mismatch"));
+    return ::xla::ffi::Future(promise);
+  }
+
+  switch (lhs_outer_indicies.element_type()) {
+    case ::xla::ffi::DataType::S32:
+      return CsrSparseDenseKernelTypedDispatch<ElementType, int32_t>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ::xla::ffi::DataType::S64:
+      return CsrSparseDenseKernelTypedDispatch<ElementType, int64_t>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    default:
+      ::xla::ffi::Promise promise;
+      promise.SetError(::xla::ffi::Error(
+          ::xla::ffi::ErrorCode::kInvalidArgument, "Invalid index data type"));
+      return ::xla::ffi::Future(promise);
+  }
+}
+
+static ::xla::ffi::Future CsrSparseDenseKernelDispatch(
+    ::xla::ffi::AnyBuffer lhs_data, ::xla::ffi::AnyBuffer lhs_outer_indicies,
+    ::xla::ffi::AnyBuffer lhs_inner_indicies, ::xla::ffi::AnyBuffer rhs,
+    ::xla::ffi::Result<::xla::ffi::AnyBuffer> out,
+    ::xla::ffi::ThreadPool thread_pool) {
+  if (lhs_data.element_type() != rhs.element_type() ||
+      lhs_data.element_type() != out->element_type()) {
+    ::xla::ffi::Promise promise;
+    promise.SetError(::xla::ffi::Error(::xla::ffi::ErrorCode::kInvalidArgument,
+                                       "Element type mismatch"));
+    return ::xla::ffi::Future(promise);
+  }
+
+  switch (lhs_data.element_type()) {
+    case ::xla::ffi::DataType::S32:
+      return CsrSparseDenseKernelTypedDispatch<int32_t>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ::xla::ffi::DataType::S64:
+      return CsrSparseDenseKernelTypedDispatch<int64_t>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ::xla::ffi::DataType::F32:
+      return CsrSparseDenseKernelTypedDispatch<float>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ::xla::ffi::DataType::F64:
+      return CsrSparseDenseKernelTypedDispatch<double>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ::xla::ffi::DataType::C64:
+      return CsrSparseDenseKernelTypedDispatch<std::complex<float>>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    case ::xla::ffi::DataType::C128:
+      return CsrSparseDenseKernelTypedDispatch<std::complex<double>>(
+          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
+          thread_pool);
+    default:
+      ::xla::ffi::Promise promise;
+      promise.SetError(::xla::ffi::Error(
+          ::xla::ffi::ErrorCode::kInvalidArgument, "Invalid data type"));
+      return ::xla::ffi::Future(promise);
+  }
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    cpu_csr_sparse_dense_ffi, CsrSparseDenseKernelDispatch,
+    (::xla::ffi::Ffi::Bind()
+         .Arg<::xla::ffi::AnyBuffer>(/*lhs_data*/)
+         .Arg<::xla::ffi::AnyBuffer>(/*lhs_outer_indicies*/)
+         .Arg<::xla::ffi::AnyBuffer>(/*lhs_inner_indicies*/)
+         .Arg<::xla::ffi::AnyBuffer>(/*rhs*/)
+         .Ret<::xla::ffi::AnyBuffer>(/*out*/)
+         .Ctx<::xla::ffi::ThreadPool>(/*thread_pool*/)));
+
+}  // namespace jax
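
The interesting part of the kernel is the work partitioning: rows are grouped greedily until a batch accounts for roughly kTaskSize scalar products, counting at least one unit per row so that a long run of empty rows cannot pile into a single task. A standalone Python sketch of that heuristic; names mirror the C++, and the inputs in the usage line are made up for illustration:

def batch_rows(indptr, rhs_cols, task_size):
    """Greedily split CSR rows into batches of ~task_size scalar products."""
    batch_sizes = []
    running_nnz = 0
    running_rows = 0
    n_rows = len(indptr) - 1
    for row in range(n_rows):
        row_nnz = indptr[row + 1] - indptr[row]
        running_nnz += max(row_nnz, 1)  # empty rows still cost a write
        running_rows += 1
        if running_nnz * rhs_cols > task_size:
            batch_sizes.append(running_rows)
            running_nnz = 0
            running_rows = 0
        elif row == n_rows - 1 and running_rows > 0:
            batch_sizes.append(running_rows)
    return batch_sizes

# For float32, kTaskSize = 100_000 * (2 * 32 / 4) scalar products; the tiny
# numbers here just make the split visible.
print(batch_rows([0, 3, 3, 10, 12], rhs_cols=8, task_size=50))  # -> [3, 1]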
