Skip to content

Commit

Permalink
improvements to scatter / gather (#1541)
Browse files Browse the repository at this point in the history
  • Loading branch information
awni authored Oct 31, 2024
1 parent 960e3f0 commit 4f72c66
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 245 deletions.
12 changes: 6 additions & 6 deletions benchmarks/python/scatter_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[*idx] = x
dst[tuple(idx)] = x
mx.eval(dst)

idx = []
Expand All @@ -23,8 +23,8 @@ def scatter(dst, x, idx):


def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def gather(dst, x, idx, device):
dst[*idx] = x
def scatter(dst, x, idx, device):
dst[tuple(idx)] = x
if device == torch.device("mps"):
torch.mps.synchronize()

Expand All @@ -34,7 +34,7 @@ def gather(dst, x, idx, device):
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")


Expand All @@ -54,7 +54,7 @@ def gather(dst, x, idx, device):
(100_000, 64),
(1_000_000, 64),
(100_000,),
(2_000_00,),
(200_000,),
(20_000_000,),
(10000, 64),
(100, 64),
Expand Down Expand Up @@ -91,6 +91,6 @@ def gather(dst, x, idx, device):

for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
4 changes: 2 additions & 2 deletions mlx/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(hadamard)

if(MLX_METAL_JIT)
Expand Down
244 changes: 113 additions & 131 deletions mlx/backend/metal/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,39 +113,38 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;

std::vector<char> idx_contigs;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());

idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
}

// Set all the buffers
compute_encoder.set_input_array(src, 0);
compute_encoder.set_output_array(out, 1);

// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
set_vector_bytes(compute_encoder, src.shape(), 2);
set_vector_bytes(compute_encoder, src.strides(), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
set_vector_bytes(compute_encoder, slice_sizes_, 5);
set_vector_bytes(compute_encoder, axes_, 6);

// Set index info
//
// We don't need to check for empty idx_shapes because gather has an
// idx_ndim == 0 specialization
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, idx_shapes, 7);
set_vector_bytes(compute_encoder, idx_strides, 8);
set_vector_bytes(compute_encoder, idx_contigs, 9);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);

// Set index buffers
for (int i = 0; i < nidx; ++i) {
Expand All @@ -172,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}

// Copy src into out
auto copy_type =
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
CopyType copy_type;
if (inputs[0].data_size() == 1) {
copy_type = CopyType::Scalar;
} else if (inputs[0].flags().row_contiguous) {
copy_type = CopyType::Vector;
} else {
copy_type = CopyType::General;
}
copy_gpu(inputs[0], out, copy_type);

auto& upd = inputs.back();

// Empty update
if (inputs.back().size() == 0) {
if (upd.size() == 0) {
return;
}

Expand All @@ -186,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);

int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);

// Bail from fast path (1d index specialization) if scatter dims aren't
// the outermost dims and contiguous since update access won't be raster
// order.
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= (axes_[i] == i);
}

// Bail from fast path (1d index specialization) if any of the dims are
// broadcasted, since we can't rely on linear indexing in that case.
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
size_t idx_size = nidx ? inputs[1].size() : 1;

auto idx_to_out = idx_size / out.size();
int nwork;
if (idx_ndim <= 1 || idx_to_out < 1) {
nwork = 1;
} else if (idx_to_out <= 4) {
nwork = 4;
} else if (idx_to_out < 16) {
nwork = 8;
} else if (idx_to_out < 32) {
nwork = 16;
} else {
nwork = 32;
}

std::string lib_name;
Expand All @@ -222,19 +230,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op_name = "min";
break;
}

auto upd_contig = upd.flags().row_contiguous;
{
std::ostringstream kname;
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
kname << "_" << op_name << "_" << nidx;
kname << "scatter" << type_to_name(out) << idx_type_name;
kname << "_" << op_name << "_" << nidx << "_"
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
lib_name = kname.str();
kernel_name = kname.str();
}

auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
Expand Down Expand Up @@ -274,14 +278,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op_type,
nidx,
idx_args,
idx_arr);
idx_arr,
upd_contig,
nwork);
return kernel_source.str();
});

auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);

auto& upd = inputs.back();
size_t nthreads = upd.size();

compute_encoder->setComputePipelineState(kernel);
Expand All @@ -291,109 +296,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2);

// Set update info
uint upd_ndim = upd.ndim();
size_t upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (index_nd1_specialization) {
compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);

size_t out_ndim = out.ndim();
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
if (upd_ndim <= 1) {
// Placeholder so Metal doesn't complain
int shape_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 6);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);

// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}

// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
// To access .data() use char instead of bool
// bool is 1 byte in Metal so this is safe
std::vector<char> idx_contigs;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
}

if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;

for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());

idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
}
set_vector_bytes(compute_encoder, upd.shape(), 3);
set_vector_bytes(compute_encoder, upd.strides(), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
set_vector_bytes(compute_encoder, out.shape(), 7);
set_vector_bytes(compute_encoder, out.strides(), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);

if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(
out.strides().data(), out_ndim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);

// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);

// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
idx_contigs.push_back(false);
}
set_vector_bytes(compute_encoder, idx_shapes, 11);
set_vector_bytes(compute_encoder, idx_strides, 12);
set_vector_bytes(compute_encoder, idx_contigs, 13);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);

// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}

// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Launch grid
auto grid_y = (nthreads / upd_size);
grid_y = (grid_y + nwork - 1) / nwork;
MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
}
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}

} // namespace mlx::core
Loading

0 comments on commit 4f72c66

Please sign in to comment.