From 4f72c669112970e63af0e67f9518f6646d3ce284 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 30 Oct 2024 19:30:54 -0700
Subject: [PATCH] improvements to scatter / gather (#1541)

---
 benchmarks/python/scatter_bench.py   |  12 +-
 mlx/backend/metal/CMakeLists.txt     |   4 +-
 mlx/backend/metal/indexing.cpp       | 244 +++++++++++++--------------
 mlx/backend/metal/jit/indexing.h     |  41 ++---
 mlx/backend/metal/kernels/gather.h   |  12 +-
 mlx/backend/metal/kernels/indexing.h |   1 +
 mlx/backend/metal/kernels/scatter.h  | 103 +++++------
 python/mlx/nn/layers/upsample.py     |  14 +-
 python/tests/test_ops.py             |   6 +-
 9 files changed, 192 insertions(+), 245 deletions(-)

diff --git a/benchmarks/python/scatter_bench.py b/benchmarks/python/scatter_bench.py
index d2fd569acf..655cf0b033 100644
--- a/benchmarks/python/scatter_bench.py
+++ b/benchmarks/python/scatter_bench.py
@@ -9,7 +9,7 @@
 
 def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[*idx] = x
+        dst[tuple(idx)] = x
         mx.eval(dst)
 
     idx = []
@@ -23,8 +23,8 @@ def scatter(dst, x, idx):
 
 
 def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
-    def gather(dst, x, idx, device):
-        dst[*idx] = x
+    def scatter(dst, x, idx, device):
+        dst[tuple(idx)] = x
 
         if device == torch.device("mps"):
             torch.mps.synchronize()
@@ -34,7 +34,7 @@ def gather(dst, x, idx, device):
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
 
-    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
     print(f"PyTorch: {runtime:.3f}ms")
 
@@ -54,7 +54,7 @@ def gather(dst, x, idx, device):
         (100_000, 64),
         (1_000_000, 64),
         (100_000,),
-        (2_000_00,),
+        (200_000,),
         (20_000_000,),
         (10000, 64),
         (100, 64),
@@ -91,6 +91,6 @@ def gather(dst, x, idx, device):
 
     for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
         print("=" * 20)
-        print(f"X {x_shape}, Indices {idx_shape}")
+        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
         benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
         benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index 3e88e18d1d..4b14ebb553 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -26,8 +26,8 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
 make_jit_source(binary_ops)
 make_jit_source(ternary_ops)
 make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
-make_jit_source(scatter)
-make_jit_source(gather)
+make_jit_source(scatter kernels/indexing.h)
+make_jit_source(gather kernels/indexing.h)
 make_jit_source(hadamard)
 
 if(MLX_METAL_JIT)
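Note on the benchmark change above: `dst[*idx] = x` uses star-unpacking inside a subscript, which is only valid syntax on Python 3.11 and newer; `dst[tuple(idx)]` expresses the same fancy indexing portably, presumably the reason for the switch. A minimal sketch of the equivalence (shapes here are illustrative, not taken from the benchmark):

    import mlx.core as mx

    dst = mx.zeros((4, 4))
    x = mx.ones((2,))
    idx = [mx.array([0, 2]), mx.array([1, 3])]

    dst[tuple(idx)] = x  # equivalent to dst[idx[0], idx[1]] = x
    # dst[*idx] = x      # same thing, but a SyntaxError before Python 3.11
    mx.eval(dst)         # force evaluation, as the benchmark does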
diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index 7820b0272b..75a5323465 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -113,17 +113,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Collect all idx shapes and strides into one place
   std::vector<int> idx_shapes;
   std::vector<size_t> idx_strides;
-
+  std::vector<char> idx_contigs;
   for (int i = 0; i < nidx; ++i) {
     idx_shapes.insert(
         idx_shapes.end(),
         inputs[i + 1].shape().begin(),
         inputs[i + 1].shape().end());
-
     idx_strides.insert(
         idx_strides.end(),
         inputs[i + 1].strides().begin(),
         inputs[i + 1].strides().end());
+    idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
   }
 
   // Set all the buffers
@@ -131,21 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 1);
 
   // Set source info
-  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
-  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
+  set_vector_bytes(compute_encoder, src.shape(), 2);
+  set_vector_bytes(compute_encoder, src.strides(), 3);
   compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
-  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
-  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
+  set_vector_bytes(compute_encoder, slice_sizes_, 5);
+  set_vector_bytes(compute_encoder, axes_, 6);
 
   // Set index info
   //
   // We don't need to check for empty idx_shapes because gather has a
   // idx_ndim == 0 specialization
-  compute_encoder->setBytes(
-      idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
-  compute_encoder->setBytes(
-      idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
+  set_vector_bytes(compute_encoder, idx_shapes, 7);
+  set_vector_bytes(compute_encoder, idx_strides, 8);
+  set_vector_bytes(compute_encoder, idx_contigs, 9);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
 
   // Set index buffers
   for (int i = 0; i < nidx; ++i) {
@@ -172,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   // Copy src into out
-  auto copy_type =
-      inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
+  CopyType copy_type;
+  if (inputs[0].data_size() == 1) {
+    copy_type = CopyType::Scalar;
+  } else if (inputs[0].flags().row_contiguous) {
+    copy_type = CopyType::Vector;
+  } else {
+    copy_type = CopyType::General;
+  }
   copy_gpu(inputs[0], out, copy_type);
 
+  auto& upd = inputs.back();
+
   // Empty update
-  if (inputs.back().size() == 0) {
+  if (upd.size() == 0) {
     return;
   }
 
@@ -186,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
 
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  bool index_nd1_specialization = (idx_ndim == 1);
-
-  // Bail from fast path (1d index specialization) if scatter dims aren't
-  // the outermost dims and contiguous since update access won't be raster
-  // order.
-  for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
-    index_nd1_specialization &= (axes_[i] == i);
-  }
-
-  // Bail from fast path (1d index specialization) if any of the dims are
-  // broadcasted, since we can't rely on linear indexing in that case.
-  for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
-    index_nd1_specialization &= inputs[i].flags().row_contiguous;
+  size_t idx_size = nidx ? inputs[1].size() : 1;
+
+  auto idx_to_out = idx_size / out.size();
+  int nwork;
+  if (idx_ndim <= 1 || idx_to_out < 1) {
+    nwork = 1;
+  } else if (idx_to_out <= 4) {
+    nwork = 4;
+  } else if (idx_to_out < 16) {
+    nwork = 8;
+  } else if (idx_to_out < 32) {
+    nwork = 16;
+  } else {
+    nwork = 32;
   }
 
   std::string lib_name;
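The nwork heuristic above amortizes per-thread addressing and atomic overhead by having each thread scatter several index rows once the indices outnumber the output elements. A plain-Python restatement of the same thresholds (the function name is ours, for illustration):

    def pick_nwork(idx_ndim: int, idx_size: int, out_size: int) -> int:
        # idx_to_out uses integer division, matching the C++ code, so a
        # value of 0 means the indices do not cover the output even once.
        idx_to_out = idx_size // out_size
        if idx_ndim <= 1 or idx_to_out < 1:
            return 1
        elif idx_to_out <= 4:
            return 4
        elif idx_to_out < 16:
            return 8
        elif idx_to_out < 32:
            return 16
        return 32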
"updc_true" : "updc_false") << "_nwork" << nwork; lib_name = kname.str(); kernel_name = kname.str(); } - auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::reduce_utils() @@ -274,14 +278,15 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { op_type, nidx, idx_args, - idx_arr); + idx_arr, + upd_contig, + nwork); return kernel_source.str(); }); auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kernel_name, lib); - auto& upd = inputs.back(); size_t nthreads = upd.size(); compute_encoder->setComputePipelineState(kernel); @@ -291,109 +296,86 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set update info - uint upd_ndim = upd.ndim(); + size_t upd_ndim = upd.ndim(); size_t upd_size = 1; for (int i = idx_ndim; i < upd.ndim(); ++i) { upd_size *= upd.shape(i); } - if (index_nd1_specialization) { - compute_encoder->setBytes( - out.shape().data(), out.shape().size() * sizeof(int), 3); - compute_encoder->setBytes( - out.strides().data(), out.strides().size() * sizeof(size_t), 4); - - size_t out_ndim = out.ndim(); - compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5); - if (upd_ndim <= 1) { - // Placeholder so Metal doesn't compalain - int shape_ = 0; - compute_encoder->setBytes(&shape_, sizeof(int), 6); - } else { - compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6); - } - compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7); - compute_encoder->setBytes(&upd_size, sizeof(size_t), 8); - - // Set index buffers - for (int i = 0; i < nidx; ++i) { - compute_encoder.set_input_array(inputs[i + 1], 20 + i); - } - - // Launch grid - MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); - MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + // Collect all idx shapes and strides into one place + std::vector idx_shapes; + std::vector idx_strides; + // To access .data() use char instead of bool + // bool is 1 byte in Metal so this is safe + std::vector idx_contigs; + for (int i = 0; i < nidx; ++i) { + idx_shapes.insert( + idx_shapes.end(), + inputs[i + 1].shape().begin(), + inputs[i + 1].shape().end()); + idx_strides.insert( + idx_strides.end(), + inputs[i + 1].strides().begin(), + inputs[i + 1].strides().end()); + idx_contigs.push_back(inputs[i + 1].flags().row_contiguous); + } + if (upd_ndim == 0) { + // Need placeholders so Metal doesn't compalain + int shape_ = 0; + size_t stride_ = 0; + compute_encoder->setBytes(&shape_, sizeof(int), 3); + compute_encoder->setBytes(&stride_, sizeof(size_t), 4); } else { - // Collect all idx shapes and strides into one place - std::vector idx_shapes; - std::vector idx_strides; - - for (int i = 0; i < nidx; ++i) { - idx_shapes.insert( - idx_shapes.end(), - inputs[i + 1].shape().begin(), - inputs[i + 1].shape().end()); - - idx_strides.insert( - idx_strides.end(), - inputs[i + 1].strides().begin(), - inputs[i + 1].strides().end()); - } + set_vector_bytes(compute_encoder, upd.shape(), 3); + set_vector_bytes(compute_encoder, upd.strides(), 4); + } + compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); + compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); + + // Set output info + size_t out_ndim = out.ndim(); + if (out_ndim == 0) { + // Need placeholders so Metal doesn't compalain + int shape_ = 0; + size_t stride_ = 0; + compute_encoder->setBytes(&shape_, sizeof(int), 7); + 
+    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
+  } else {
+    set_vector_bytes(compute_encoder, out.shape(), 7);
+    set_vector_bytes(compute_encoder, out.strides(), 8);
+  }
+  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
+  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
 
-    if (upd_ndim == 0) {
-      // Need placeholders so Metal doesn't compalain
-      int shape_ = 0;
-      size_t stride_ = 0;
-      compute_encoder->setBytes(&shape_, sizeof(int), 3);
-      compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
-    } else {
-      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
-      compute_encoder->setBytes(
-          upd.strides().data(), upd_ndim * sizeof(size_t), 4);
-    }
-    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
-    compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
-
-    // Set output info
-    size_t out_ndim = out.ndim();
-    if (out_ndim == 0) {
-      // Need placeholders so Metal doesn't compalain
-      int shape_ = 0;
-      size_t stride_ = 0;
-      compute_encoder->setBytes(&shape_, sizeof(int), 7);
-      compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
-    } else {
-      compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
-      compute_encoder->setBytes(
-          out.strides().data(), out_ndim * sizeof(size_t), 8);
-    }
-    compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
-    compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
-
-    // Set index info
-    if (idx_ndim == 0) {
-      // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
-      // error in the metal API.
-      idx_shapes.push_back(0);
-      idx_strides.push_back(0);
-    }
-    compute_encoder->setBytes(
-        idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
-    compute_encoder->setBytes(
-        idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
-    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
-
-    // Set index buffers
-    for (int i = 0; i < nidx; ++i) {
-      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
-    }
+  // Set index info
+  if (idx_ndim == 0) {
+    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
+    // error in the metal API.
+    idx_shapes.push_back(0);
+    idx_strides.push_back(0);
+    idx_contigs.push_back(false);
+  }
+  set_vector_bytes(compute_encoder, idx_shapes, 11);
+  set_vector_bytes(compute_encoder, idx_strides, 12);
+  set_vector_bytes(compute_encoder, idx_contigs, 13);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
+  compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
+
+  // Set index buffers
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
+  }
 
-    // Launch grid
-    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+  // Launch grid
+  auto grid_y = (nthreads / upd_size);
+  grid_y = (grid_y + nwork - 1) / nwork;
+  MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
+  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  if (thread_group_size != 1024) {
+    throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
+  }
+  MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core
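With nwork index rows handled per thread, the y extent of the dispatch grid shrinks by a ceil-divided factor of nwork. A restatement of the launch-grid arithmetic above (the helper name is ours, for illustration):

    def scatter_grid(nthreads: int, upd_size: int, nwork: int) -> tuple[int, int]:
        # grid.x walks one slice of the update; grid.y walks index rows,
        # each thread covering up to nwork rows (hence the ceil division).
        grid_x = upd_size
        grid_y = (nthreads // upd_size + nwork - 1) // nwork
        return grid_x, grid_y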
diff --git a/mlx/backend/metal/jit/indexing.h b/mlx/backend/metal/jit/indexing.h
index 9c5ec62137..77e9541a7b 100644
--- a/mlx/backend/metal/jit/indexing.h
+++ b/mlx/backend/metal/jit/indexing.h
@@ -11,12 +11,13 @@ constexpr std::string_view gather_kernels = R"(
     const constant int* axes [[buffer(6)]],
     const constant int* idx_shapes [[buffer(7)]],
    const constant size_t* idx_strides [[buffer(8)]],
-    const constant int& idx_ndim [[buffer(9)]],
+    const constant bool* idx_contigs [[buffer(9)]],
+    const constant int& idx_ndim [[buffer(10)]],
     {4}
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {{
   Indices<{2}, {3}> idxs{{
-      {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
+      {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
 
   return gather_impl<{1}, {2}, {3}, {6}>(
       src,
@@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
 )";
 
 constexpr std::string_view scatter_kernels = R"(
-[[kernel]] void scatter_1d_index{0}_{4}(
-    const device {1}* updates [[buffer(1)]],
-    device mlx_atomic<{1}>* out [[buffer(2)]],
-    const constant int* out_shape [[buffer(3)]],
-    const constant size_t* out_strides [[buffer(4)]],
-    const constant size_t& out_ndim [[buffer(5)]],
-    const constant int* upd_shape [[buffer(6)]],
-    const constant size_t& upd_ndim [[buffer(7)]],
-    const constant size_t& upd_size [[buffer(8)]],
-    {5}
-    uint2 gid [[thread_position_in_grid]]) {{
-  const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
-  return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
-      updates,
-      out,
-      out_shape,
-      out_strides,
-      out_ndim,
-      upd_shape,
-      upd_ndim,
-      upd_size,
-      idx_buffers,
-      gid);
-}}
-
-[[kernel]] void scatter{0}_{4}(
+[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
     const device {1}* updates [[buffer(1)]],
     device mlx_atomic<{1}>* out [[buffer(2)]],
     const constant int* upd_shape [[buffer(3)]],
@@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
     const constant int* axes [[buffer(10)]],
     const constant int* idx_shapes [[buffer(11)]],
     const constant size_t* idx_strides [[buffer(12)]],
-    const constant int& idx_ndim [[buffer(13)]],
+    const constant bool* idx_contigs [[buffer(13)]],
+    const constant int& idx_ndim [[buffer(14)]],
+    const constant size_t& idx_size [[buffer(15)]],
     {5}
     uint2 gid [[thread_position_in_grid]]) {{
-  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
+  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
 
-  return scatter_impl<{1}, {2}, {3}, {4}>(
+  return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
       updates,
       out,
       upd_shape,
       upd_strides,
       upd_ndim,
       upd_size,
       out_shape,
       out_strides,
       out_ndim,
       axes,
+      idx_size,
       idxs,
       gid);
 }}

diff --git a/mlx/backend/metal/kernels/gather.h b/mlx/backend/metal/kernels/gather.h
index 4ee5299747..4d3997ad8c 100644
--- a/mlx/backend/metal/kernels/gather.h
+++ b/mlx/backend/metal/kernels/gather.h
@@ -25,11 +25,13 @@ METAL_FUNC void gather_impl(
       idx_loc = index.x * indices.strides[indices.ndim * i];
     } else {
       idx_loc = index.x * indices.strides[indices.ndim * i];
-      idx_loc += elem_to_loc(
-          index.y,
-          &indices.shapes[indices.ndim * i + 1],
-          &indices.strides[indices.ndim * i + 1],
-          indices.ndim - 1);
+      idx_loc += indices.row_contiguous[i]
+          ? index.y
+          : elem_to_loc(
+                index.y,
+                &indices.shapes[indices.ndim * i + 1],
+                &indices.strides[indices.ndim * i + 1],
+                indices.ndim - 1);
     }
     auto ax = axes[i];
     auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);

diff --git a/mlx/backend/metal/kernels/indexing.h b/mlx/backend/metal/kernels/indexing.h
index 9f76e4771c..ca4158df6b 100644
--- a/mlx/backend/metal/kernels/indexing.h
+++ b/mlx/backend/metal/kernels/indexing.h
@@ -9,6 +9,7 @@ struct Indices {
   const array<const device IdxT*, NIDX> buffers;
   const constant int* shapes;
   const constant size_t* strides;
+  const constant bool* row_contiguous;
   const int ndim;
 };
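The row_contiguous flags added to Indices above are used the same way in gather.h and in scatter.h below: when an index array is row-contiguous, the general stride walk of elem_to_loc collapses to the linear element index itself, so the kernel skips it entirely. A rough restatement of that equivalence (the Python helper is ours, for illustration):

    def elem_to_loc(elem: int, shape: list[int], strides: list[int]) -> int:
        # General case: peel off one coordinate per dimension, innermost first.
        loc = 0
        for dim in reversed(range(len(shape))):
            loc += (elem % shape[dim]) * strides[dim]
            elem //= shape[dim]
        return loc

    # For a row-contiguous array the strides are exactly the row-major
    # strides of its shape, so elem_to_loc(i, shape, strides) == i and
    # the per-dimension walk is wasted work.
    assert elem_to_loc(7, [2, 3, 4], [12, 4, 1]) == 7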
diff --git a/mlx/backend/metal/kernels/scatter.h b/mlx/backend/metal/kernels/scatter.h
index b4c6f00618..9a38e62b11 100644
--- a/mlx/backend/metal/kernels/scatter.h
+++ b/mlx/backend/metal/kernels/scatter.h
@@ -4,73 +4,54 @@
 
 #include "mlx/backend/metal/kernels/indexing.h"
 
-template <typename T, typename IdxT, typename Op, int NIDX>
-METAL_FUNC void scatter_1d_index_impl(
-    const device T* updates [[buffer(1)]],
-    device mlx_atomic<T>* out [[buffer(2)]],
-    const constant int* out_shape [[buffer(3)]],
-    const constant size_t* out_strides [[buffer(4)]],
-    const constant size_t& out_ndim [[buffer(5)]],
-    const constant int* upd_shape [[buffer(6)]],
-    const constant size_t& upd_ndim [[buffer(7)]],
-    const constant size_t& upd_size [[buffer(8)]],
-    const thread array<const device IdxT*, NIDX>& idx_buffers,
-    uint2 gid [[thread_position_in_grid]]) {
-  Op op;
-
-  size_t out_idx = 0;
-  for (int i = 0; i < NIDX; i++) {
-    auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
-    out_idx += idx_val * out_strides[i];
-  }
-
-  if (upd_ndim > 1) {
-    auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
-    out_idx += out_offset;
-  } else {
-    out_idx += gid.x;
-  }
-
-  op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
-}
-
-template <typename T, typename IdxT, typename Op, int NIDX>
+template <
+    typename T,
+    typename IdxT,
+    typename Op,
+    int NIDX,
+    bool UPD_ROW_CONTIG,
+    int NWORK>
 METAL_FUNC void scatter_impl(
-    const device T* updates [[buffer(1)]],
-    device mlx_atomic<T>* out [[buffer(2)]],
-    const constant int* upd_shape [[buffer(3)]],
-    const constant size_t* upd_strides [[buffer(4)]],
-    const constant size_t& upd_ndim [[buffer(5)]],
-    const constant size_t& upd_size [[buffer(6)]],
-    const constant int* out_shape [[buffer(7)]],
-    const constant size_t* out_strides [[buffer(8)]],
-    const constant size_t& out_ndim [[buffer(9)]],
-    const constant int* axes [[buffer(10)]],
+    const device T* updates,
+    device mlx_atomic<T>* out,
+    const constant int* upd_shape,
+    const constant size_t* upd_strides,
+    const constant size_t& upd_ndim,
+    const constant size_t& upd_size,
+    const constant int* out_shape,
+    const constant size_t* out_strides,
+    const constant size_t& out_ndim,
+    const constant int* axes,
+    const constant size_t& idx_size,
     const thread Indices<IdxT, NIDX>& indices,
     uint2 gid [[thread_position_in_grid]]) {
   Op op;
-  auto ind_idx = gid.y;
-  auto ind_offset = gid.x;
-
-  size_t out_idx = 0;
-  for (int i = 0; i < NIDX; ++i) {
-    auto idx_loc = elem_to_loc(
-        ind_idx,
-        &indices.shapes[indices.ndim * i],
-        &indices.strides[indices.ndim * i],
-        indices.ndim);
-    auto ax = axes[i];
-    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
-    out_idx += idx_val * out_strides[ax];
-  }
+  auto ind_idx = gid.y * NWORK;
+  size_t out_offset = 0;
   if (upd_size > 1) {
-    auto out_offset = elem_to_loc(
-        ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
-    out_idx += out_offset;
+    out_offset =
+        elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
   }
-
-  auto upd_idx =
-      elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
-  op.atomic_update(out, updates[upd_idx], out_idx);
+  for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
+    size_t out_idx = out_offset;
+    for (int i = 0; i < NIDX; ++i) {
+      auto idx_loc = indices.row_contiguous[i]
+          ? ind_idx
+          : elem_to_loc(
+                ind_idx,
+                &indices.shapes[indices.ndim * i],
+                &indices.strides[indices.ndim * i],
+                indices.ndim);
+      auto ax = axes[i];
+      auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
+      out_idx += idx_val * out_strides[ax];
+    }
+    auto upd_idx = ind_idx * upd_size + gid.x;
+    if constexpr (!UPD_ROW_CONTIG) {
+      upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
+    }
+    op.atomic_update(out, updates[upd_idx], out_idx);
+  }
 }

diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py
index 6f813ba3f2..49417a9138 100644
--- a/python/mlx/nn/layers/upsample.py
+++ b/python/mlx/nn/layers/upsample.py
@@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
 
 
 def _nearest_indices(N, scale, dim, ndims):
-    return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
+    return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
 
 
 def _linear_indices(N, scale, align_corners, dim, ndims):
@@ -37,8 +37,8 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
     weight = mx.expand_dims(weight, -1)
 
     return (
-        (indices_l.astype(mx.int32), 1 - weight),
-        (indices_r.astype(mx.int32), weight),
+        (indices_l.astype(mx.uint32), 1 - weight),
+        (indices_r.astype(mx.uint32), weight),
     )
 
@@ -73,10 +73,10 @@ def _get_weight(ind, grid, dist):
     indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
 
     return (
-        (indices_l1.astype(mx.int32), weight_l1),
-        (indices_r1.astype(mx.int32), weight_r1),
-        (indices_l2.astype(mx.int32), weight_l2),
-        (indices_r2.astype(mx.int32), weight_r2),
+        (indices_l1.astype(mx.uint32), weight_l1),
+        (indices_r1.astype(mx.uint32), weight_r1),
+        (indices_l2.astype(mx.uint32), weight_l2),
+        (indices_r2.astype(mx.uint32), weight_r2),
     )
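The test change below likely pairs with the atomic scatter path: np.put_along_axis applies duplicate indices sequentially (last write wins), while a GPU scatter resolves them in whatever order the atomics land, so the test now draws unique indices (np.random.permutation, and np.random.choice with replace=False) to keep the comparison deterministic. A small illustration of the hazard:

    import numpy as np

    a = np.zeros(4)
    idx = np.array([1, 1])  # duplicate destination
    vals = np.array([10.0, 20.0])

    np.put_along_axis(a, idx, vals, axis=0)
    print(a)  # [ 0. 20.  0.  0.]  sequential semantics: last write wins
    # With atomic GPU updates, either 10.0 or 20.0 could land last, so
    # the updated test avoids duplicate indices altogether.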
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 0bb34cd872..99d253dd89 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1089,12 +1089,14 @@ def test_put_along_axis(self):
             a_mlx = mx.array(a_np)
 
             if ax == None:
-                idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
+                idx_np = np.random.permutation(a_np.size)
                 values_np = np.random.randint(low=0, high=100, size=(16,))
             else:
                 shape = list(a_np.shape)
                 shape[ax] = 2
-                idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
+                idx_np = np.random.choice(a_np.shape[ax], replace=False, size=(2,))
+                idx_np = np.expand_dims(idx_np, list(range(1, 2 - ax + 1)))
+                idx_np = np.broadcast_to(idx_np, shape)
                 values_np = np.random.randint(low=0, high=100, size=shape)
 
             idx_np.astype(np.int32)
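As a quick end-to-end check of the new scatter path, the pattern from scatter_bench.py can be reproduced standalone. The shapes below are illustrative, not the benchmark's; with ten times as many indices as output elements, the nwork heuristic above would select nwork = 8:

    import time
    import mlx.core as mx

    dst = mx.zeros((100_000,))
    idx = [mx.random.randint(0, 100_000, (1_000_000,))]
    x = mx.random.normal((1_000_000,))

    dst[tuple(idx)] = x  # scatter along axis 0
    mx.eval(dst)         # warm up: compiles and caches the kernel

    tic = time.perf_counter()
    for _ in range(100):
        dst[tuple(idx)] = x
        mx.eval(dst)
    ms = (time.perf_counter() - tic) * 1000 / 100
    print(f"{ms:.3f} ms per scatter")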